diff --git a/d2lsp/d2lsp.go b/d2lsp/d2lsp.go index 818e5dd13..0605b4709 100644 --- a/d2lsp/d2lsp.go +++ b/d2lsp/d2lsp.go @@ -10,7 +10,7 @@ import ( "oss.terrastruct.com/d2/lib/memfs" ) -func GetFieldRefs(path string, fs map[string]string, key string, boardPath []string) (refs []d2ir.Reference, _ error) { +func GetRefs(path string, fs map[string]string, key string, boardPath []string) (refs []d2ir.Reference, _ error) { if _, ok := fs[path]; !ok { return nil, fmt.Errorf(`"%s" not found`, path) } @@ -29,33 +29,52 @@ func GetFieldRefs(path string, fs map[string]string, key string, boardPath []str if err != nil { return nil, err } - if mk.Key == nil { + if mk.Key == nil && len(mk.Edges) == 0 { return nil, fmt.Errorf(`"%s" is invalid`, key) } - ir, _, err := d2ir.Compile(ast, &d2ir.CompileOptions{ + m, _, err := d2ir.Compile(ast, &d2ir.CompileOptions{ FS: mfs, }) if err != nil { return nil, err } - ir = ir.FindBoardRoot(boardPath) - if ir == nil { + m = m.FindBoardRoot(boardPath) + if m == nil { return nil, fmt.Errorf(`board "%v" not found`, boardPath) } var f *d2ir.Field - curr := ir - for _, p := range mk.Key.Path { - f = curr.GetField(p.Unbox().ScalarString()) - if f == nil { + if mk.Key != nil { + for _, p := range mk.Key.Path { + f = m.GetField(p.Unbox().ScalarString()) + if f == nil { + return nil, nil + } + m = f.Map() + } + } + + if len(mk.Edges) > 0 { + eids := d2ir.NewEdgeIDs(mk) + var edges []*d2ir.Edge + for _, eid := range eids { + edges = append(edges, m.GetEdges(eid, nil, nil)...) + } + if len(edges) == 0 { return nil, nil } - curr = f.Map() - } - for _, ref := range f.References { - refs = append(refs, ref) + for _, edge := range edges { + for _, ref := range edge.References { + refs = append(refs, ref) + } + } + return refs, nil + } else { + for _, ref := range f.References { + refs = append(refs, ref) + } } return refs, nil } diff --git a/d2lsp/d2lsp_test.go b/d2lsp/d2lsp_test.go index 7b915b190..6f965f943 100644 --- a/d2lsp/d2lsp_test.go +++ b/d2lsp/d2lsp_test.go @@ -7,7 +7,7 @@ import ( "oss.terrastruct.com/util-go/assert" ) -func TestGetRefs(t *testing.T) { +func TestGetFieldRefs(t *testing.T) { script := `x x.a a.x @@ -15,19 +15,58 @@ x -> y` fs := map[string]string{ "index.d2": script, } - refs, err := d2lsp.GetFieldRefs("index.d2", fs, "x", nil) + refs, err := d2lsp.GetRefs("index.d2", fs, "x", nil) assert.Success(t, err) assert.Equal(t, 3, len(refs)) assert.Equal(t, 0, refs[0].AST().GetRange().Start.Line) assert.Equal(t, 1, refs[1].AST().GetRange().Start.Line) assert.Equal(t, 3, refs[2].AST().GetRange().Start.Line) - refs, err = d2lsp.GetFieldRefs("index.d2", fs, "a.x", nil) + refs, err = d2lsp.GetRefs("index.d2", fs, "a.x", nil) assert.Success(t, err) assert.Equal(t, 1, len(refs)) assert.Equal(t, 2, refs[0].AST().GetRange().Start.Line) } +func TestGetEdgeRefs(t *testing.T) { + script := `x +x.a +a.x +x -> y +y -> z +x -> z +b: { + x -> y +} +` + fs := map[string]string{ + "index.d2": script, + } + refs, err := d2lsp.GetRefs("index.d2", fs, "x -> y", nil) + assert.Success(t, err) + assert.Equal(t, 1, len(refs)) + assert.Equal(t, 3, refs[0].AST().GetRange().Start.Line) + + refs, err = d2lsp.GetRefs("index.d2", fs, "y -> z", nil) + assert.Success(t, err) + assert.Equal(t, 1, len(refs)) + assert.Equal(t, 4, refs[0].AST().GetRange().Start.Line) + + refs, err = d2lsp.GetRefs("index.d2", fs, "x -> z", nil) + assert.Success(t, err) + assert.Equal(t, 1, len(refs)) + assert.Equal(t, 5, refs[0].AST().GetRange().Start.Line) + + refs, err = d2lsp.GetRefs("index.d2", fs, "a -> b", nil) + assert.Success(t, err) + assert.Equal(t, 0, len(refs)) + + refs, err = d2lsp.GetRefs("index.d2", fs, "b.(x -> y)", nil) + assert.Success(t, err) + assert.Equal(t, 1, len(refs)) + assert.Equal(t, 7, refs[0].AST().GetRange().Start.Line) +} + func TestGetRefsImported(t *testing.T) { fs := map[string]string{ "index.d2": ` @@ -38,21 +77,21 @@ hi okay `, } - refs, err := d2lsp.GetFieldRefs("index.d2", fs, "hi", nil) + refs, err := d2lsp.GetRefs("index.d2", fs, "hi", nil) assert.Success(t, err) assert.Equal(t, 1, len(refs)) assert.Equal(t, 2, refs[0].AST().GetRange().Start.Line) - refs, err = d2lsp.GetFieldRefs("index.d2", fs, "okay", nil) + refs, err = d2lsp.GetRefs("index.d2", fs, "okay", nil) assert.Success(t, err) assert.Equal(t, 1, len(refs)) assert.Equal(t, "ok.d2", refs[0].AST().GetRange().Path) - refs, err = d2lsp.GetFieldRefs("ok.d2", fs, "hi", nil) + refs, err = d2lsp.GetRefs("ok.d2", fs, "hi", nil) assert.Success(t, err) assert.Equal(t, 0, len(refs)) - refs, err = d2lsp.GetFieldRefs("ok.d2", fs, "okay", nil) + refs, err = d2lsp.GetRefs("ok.d2", fs, "okay", nil) assert.Success(t, err) assert.Equal(t, 1, len(refs)) } @@ -68,15 +107,15 @@ layers: { } `, } - refs, err := d2lsp.GetFieldRefs("index.d2", fs, "hello", []string{"x"}) + refs, err := d2lsp.GetRefs("index.d2", fs, "hello", []string{"x"}) assert.Success(t, err) assert.Equal(t, 1, len(refs)) assert.Equal(t, 4, refs[0].AST().GetRange().Start.Line) - refs, err = d2lsp.GetFieldRefs("index.d2", fs, "hi", []string{"x"}) + refs, err = d2lsp.GetRefs("index.d2", fs, "hi", []string{"x"}) assert.Success(t, err) assert.Equal(t, 0, len(refs)) - _, err = d2lsp.GetFieldRefs("index.d2", fs, "hello", []string{"y"}) + _, err = d2lsp.GetRefs("index.d2", fs, "hello", []string{"y"}) assert.Equal(t, `board "[y]" not found`, err.Error()) }