diff --git a/d2lsp/d2lsp.go b/d2lsp/d2lsp.go index 6704fdd6d..f233d43b5 100644 --- a/d2lsp/d2lsp.go +++ b/d2lsp/d2lsp.go @@ -2,7 +2,6 @@ package d2lsp import ( - "errors" "fmt" "strings" @@ -11,11 +10,11 @@ import ( "oss.terrastruct.com/d2/lib/memfs" ) -func GetFieldRefs(path string, fs map[string]string, key string) (refs []d2ir.Reference, _ error) { - if _, ok := fs["index"]; !ok { - return nil, errors.New("index not found") +func GetFieldRefs(path, index string, fs map[string]string, key string) (refs []d2ir.Reference, _ error) { + if _, ok := fs[index]; !ok { + return nil, fmt.Errorf(`"%s" not found`, index) } - r := strings.NewReader(fs["index"]) + r := strings.NewReader(fs[index]) ast, err := d2parser.Parse(path, r, nil) if err != nil { return nil, err @@ -46,7 +45,7 @@ func GetFieldRefs(path string, fs map[string]string, key string) (refs []d2ir.Re for _, p := range mk.Key.Path { f = curr.GetField(p.Unbox().ScalarString()) if f == nil { - return nil, fmt.Errorf(`"%s" not found`, key) + return nil, nil } curr = f.Map() } diff --git a/d2lsp/d2lsp_test.go b/d2lsp/d2lsp_test.go index d5790ea25..2562339f5 100644 --- a/d2lsp/d2lsp_test.go +++ b/d2lsp/d2lsp_test.go @@ -13,17 +13,46 @@ x.a a.x x -> y` fs := map[string]string{ - "index": script, + "index.d2": script, } - refs, err := d2lsp.GetFieldRefs("", fs, "x") + refs, err := d2lsp.GetFieldRefs("", "index.d2", fs, "x") assert.Success(t, err) assert.Equal(t, 3, len(refs)) assert.Equal(t, 0, refs[0].AST().GetRange().Start.Line) assert.Equal(t, 1, refs[1].AST().GetRange().Start.Line) assert.Equal(t, 3, refs[2].AST().GetRange().Start.Line) - refs, err = d2lsp.GetFieldRefs("", fs, "a.x") + refs, err = d2lsp.GetFieldRefs("", "index.d2", fs, "a.x") assert.Success(t, err) assert.Equal(t, 1, len(refs)) assert.Equal(t, 2, refs[0].AST().GetRange().Start.Line) } + +func TestGetRefsImported(t *testing.T) { + fs := map[string]string{ + "index.d2": ` +...@ok +hi +`, + "ok.d2": ` +okay +`, + } + refs, err := d2lsp.GetFieldRefs("", "index.d2", fs, "hi") + assert.Success(t, err) + assert.Equal(t, 1, len(refs)) + assert.Equal(t, 2, refs[0].AST().GetRange().Start.Line) + + refs, err = d2lsp.GetFieldRefs("", "index.d2", fs, "okay") + assert.Success(t, err) + assert.Equal(t, 1, len(refs)) + assert.Equal(t, "ok.d2", refs[0].AST().GetRange().Path) + + refs, err = d2lsp.GetFieldRefs("", "ok.d2", fs, "hi") + assert.Success(t, err) + assert.Equal(t, 0, len(refs)) + + refs, err = d2lsp.GetFieldRefs("", "ok.d2", fs, "okay") + assert.Success(t, err) + assert.Equal(t, 1, len(refs)) +} diff --git a/lib/memfs/memfs.go b/lib/memfs/memfs.go index 140512b3a..ed594540b 100644 --- a/lib/memfs/memfs.go +++ b/lib/memfs/memfs.go @@ -4,6 +4,7 @@ package memfs import ( "errors" + "io" "io/fs" "path" "path/filepath" @@ -42,14 +43,35 @@ func (mfs *MemoryFS) addFile(p string, content []byte, isDir bool) { } } +type MemoryFileHandle struct { + *MemoryFile + offset int +} + func (mfs *MemoryFS) Open(name string) (fs.File, error) { file, ok := mfs.files[filepath.Clean(name)] if !ok { return nil, fs.ErrNotExist } - return file, nil + return &MemoryFileHandle{MemoryFile: file}, nil } +func (fh *MemoryFileHandle) Stat() (fs.FileInfo, error) { return fh.MemoryFile, nil } + +func (fh *MemoryFileHandle) Read(b []byte) (int, error) { + if fh.isDir { + return 0, errors.New("cannot read a directory") + } + if fh.offset >= len(fh.content) { + return 0, io.EOF + } + n := copy(b, fh.content[fh.offset:]) + fh.offset += n + return n, nil +} + +func (fh *MemoryFileHandle) Close() error { return nil } + func (mf *MemoryFile) Stat() (fs.FileInfo, error) { return mf, nil } func (mf *MemoryFile) Read(b []byte) (int, error) { if mf.isDir {