From 2245355d1d6f6993061eeae877cea14452505750 Mon Sep 17 00:00:00 2001 From: Alexander Wang Date: Thu, 17 Oct 2024 08:32:54 -0600 Subject: [PATCH] d2lsp: implement GetImportRanges --- d2ast/d2ast.go | 136 +++++++++++++++++++++++++++++++++++++++++++- d2lsp/d2lsp.go | 79 +++++++++++++++++-------- d2lsp/d2lsp_test.go | 55 +++++++++++++----- 3 files changed, 231 insertions(+), 39 deletions(-) diff --git a/d2ast/d2ast.go b/d2ast/d2ast.go index f82da677d..711316b41 100644 --- a/d2ast/d2ast.go +++ b/d2ast/d2ast.go @@ -49,8 +49,7 @@ type Node interface { // GetRange returns the range a node occupies in its file. GetRange() Range - // TODO: add Children() for walking AST - // Children() []Node + Children() []Node } var _ Node = &Comment{} @@ -432,6 +431,139 @@ func (s *DoubleQuotedString) scalar() {} func (s *SingleQuotedString) scalar() {} func (s *BlockString) scalar() {} +func (c *Comment) Children() []Node { return nil } +func (c *BlockComment) Children() []Node { return nil } +func (n *Null) Children() []Node { return nil } +func (b *Boolean) Children() []Node { return nil } +func (n *Number) Children() []Node { return nil } +func (s *SingleQuotedString) Children() []Node { return nil } +func (s *BlockString) Children() []Node { return nil } +func (ei *EdgeIndex) Children() []Node { return nil } + +func (s *UnquotedString) Children() []Node { + var children []Node + for _, box := range s.Value { + if box.Substitution != nil { + children = append(children, box.Substitution) + } + } + return children +} + +func (s *DoubleQuotedString) Children() []Node { + var children []Node + for _, box := range s.Value { + if box.Substitution != nil { + children = append(children, box.Substitution) + } + } + return children +} + +func (s *Substitution) Children() []Node { + var children []Node + for _, sb := range s.Path { + if sb != nil { + if child := sb.Unbox(); child != nil { + children = append(children, child) + } + } + } + return children +} + +func (i *Import) Children() []Node { + var children []Node + for _, sb := range i.Path { + if sb != nil { + if child := sb.Unbox(); child != nil { + children = append(children, child) + } + } + } + return children +} + +func (a *Array) Children() []Node { + var children []Node + for _, box := range a.Nodes { + if child := box.Unbox(); child != nil { + children = append(children, child) + } + } + return children +} + +func (m *Map) Children() []Node { + var children []Node + for _, box := range m.Nodes { + if child := box.Unbox(); child != nil { + children = append(children, child) + } + } + return children +} + +func (k *Key) Children() []Node { + var children []Node + if k.Key != nil { + children = append(children, k.Key) + } + for _, edge := range k.Edges { + if edge != nil { + children = append(children, edge) + } + } + if k.EdgeIndex != nil { + children = append(children, k.EdgeIndex) + } + if k.EdgeKey != nil { + children = append(children, k.EdgeKey) + } + if scalar := k.Primary.Unbox(); scalar != nil { + children = append(children, scalar) + } + if value := k.Value.Unbox(); value != nil { + children = append(children, value) + } + return children +} + +func (kp *KeyPath) Children() []Node { + var children []Node + for _, sb := range kp.Path { + if sb != nil { + if child := sb.Unbox(); child != nil { + children = append(children, child) + } + } + } + return children +} + +func (e *Edge) Children() []Node { + var children []Node + if e.Src != nil { + children = append(children, e.Src) + } + if e.Dst != nil { + children = append(children, e.Dst) + } + return children +} + +func Walk(node Node, fn func(Node) bool) { + if node == nil { + return + } + if !fn(node) { + return + } + for _, child := range node.Children() { + Walk(child, fn) + } +} + // TODO: mistake, move into parse.go func (n *Null) ScalarString() string { return "" } func (b *Boolean) ScalarString() string { return strconv.FormatBool(b.Value) } diff --git a/d2lsp/d2lsp.go b/d2lsp/d2lsp.go index 0605b4709..5a345afae 100644 --- a/d2lsp/d2lsp.go +++ b/d2lsp/d2lsp.go @@ -3,24 +3,17 @@ package d2lsp import ( "fmt" + "path/filepath" "strings" + "oss.terrastruct.com/d2/d2ast" "oss.terrastruct.com/d2/d2ir" "oss.terrastruct.com/d2/d2parser" "oss.terrastruct.com/d2/lib/memfs" ) -func GetRefs(path string, fs map[string]string, key string, boardPath []string) (refs []d2ir.Reference, _ error) { - if _, ok := fs[path]; !ok { - return nil, fmt.Errorf(`"%s" not found`, path) - } - r := strings.NewReader(fs[path]) - ast, err := d2parser.Parse(path, r, nil) - if err != nil { - return nil, err - } - - mfs, err := memfs.New(fs) +func GetRefs(path string, fs map[string]string, boardPath []string, key string) (refs []d2ir.Reference, _ error) { + m, err := getBoardMap(path, fs, boardPath) if err != nil { return nil, err } @@ -33,18 +26,6 @@ func GetRefs(path string, fs map[string]string, key string, boardPath []string) return nil, fmt.Errorf(`"%s" is invalid`, key) } - m, _, err := d2ir.Compile(ast, &d2ir.CompileOptions{ - FS: mfs, - }) - if err != nil { - return nil, err - } - - m = m.FindBoardRoot(boardPath) - if m == nil { - return nil, fmt.Errorf(`board "%v" not found`, boardPath) - } - var f *d2ir.Field if mk.Key != nil { for _, p := range mk.Key.Path { @@ -78,3 +59,55 @@ func GetRefs(path string, fs map[string]string, key string, boardPath []string) } return refs, nil } + +func GetImportRanges(path string, fs map[string]string, importPath string) (ranges []d2ast.Range, _ error) { + if _, ok := fs[path]; !ok { + return nil, fmt.Errorf(`"%s" not found`, path) + } + r := strings.NewReader(fs[path]) + ast, err := d2parser.Parse(path, r, nil) + if err != nil { + return nil, err + } + + d2ast.Walk(ast, func(n d2ast.Node) bool { + switch t := n.(type) { + case *d2ast.Import: + if (filepath.Join(filepath.Dir(path), t.PathWithPre()) + ".d2") == importPath { + ranges = append(ranges, t.Range) + } + } + return true + }) + + return ranges, nil +} + +func getBoardMap(path string, fs map[string]string, boardPath []string) (*d2ir.Map, error) { + if _, ok := fs[path]; !ok { + return nil, fmt.Errorf(`"%s" not found`, path) + } + r := strings.NewReader(fs[path]) + ast, err := d2parser.Parse(path, r, nil) + if err != nil { + return nil, err + } + + mfs, err := memfs.New(fs) + if err != nil { + return nil, err + } + + m, _, err := d2ir.Compile(ast, &d2ir.CompileOptions{ + FS: mfs, + }) + if err != nil { + return nil, err + } + + m = m.FindBoardRoot(boardPath) + if m == nil { + return nil, fmt.Errorf(`board "%v" not found`, boardPath) + } + return m, nil +} diff --git a/d2lsp/d2lsp_test.go b/d2lsp/d2lsp_test.go index 6f965f943..57327711e 100644 --- a/d2lsp/d2lsp_test.go +++ b/d2lsp/d2lsp_test.go @@ -15,14 +15,14 @@ x -> y` fs := map[string]string{ "index.d2": script, } - refs, err := d2lsp.GetRefs("index.d2", fs, "x", nil) + refs, err := d2lsp.GetRefs("index.d2", fs, nil, "x") assert.Success(t, err) assert.Equal(t, 3, len(refs)) assert.Equal(t, 0, refs[0].AST().GetRange().Start.Line) assert.Equal(t, 1, refs[1].AST().GetRange().Start.Line) assert.Equal(t, 3, refs[2].AST().GetRange().Start.Line) - refs, err = d2lsp.GetRefs("index.d2", fs, "a.x", nil) + refs, err = d2lsp.GetRefs("index.d2", fs, nil, "a.x") assert.Success(t, err) assert.Equal(t, 1, len(refs)) assert.Equal(t, 2, refs[0].AST().GetRange().Start.Line) @@ -42,26 +42,26 @@ b: { fs := map[string]string{ "index.d2": script, } - refs, err := d2lsp.GetRefs("index.d2", fs, "x -> y", nil) + refs, err := d2lsp.GetRefs("index.d2", fs, nil, "x -> y") assert.Success(t, err) assert.Equal(t, 1, len(refs)) assert.Equal(t, 3, refs[0].AST().GetRange().Start.Line) - refs, err = d2lsp.GetRefs("index.d2", fs, "y -> z", nil) + refs, err = d2lsp.GetRefs("index.d2", fs, nil, "y -> z") assert.Success(t, err) assert.Equal(t, 1, len(refs)) assert.Equal(t, 4, refs[0].AST().GetRange().Start.Line) - refs, err = d2lsp.GetRefs("index.d2", fs, "x -> z", nil) + refs, err = d2lsp.GetRefs("index.d2", fs, nil, "x -> z") assert.Success(t, err) assert.Equal(t, 1, len(refs)) assert.Equal(t, 5, refs[0].AST().GetRange().Start.Line) - refs, err = d2lsp.GetRefs("index.d2", fs, "a -> b", nil) + refs, err = d2lsp.GetRefs("index.d2", fs, nil, "a -> b") assert.Success(t, err) assert.Equal(t, 0, len(refs)) - refs, err = d2lsp.GetRefs("index.d2", fs, "b.(x -> y)", nil) + refs, err = d2lsp.GetRefs("index.d2", fs, nil, "b.(x -> y)") assert.Success(t, err) assert.Equal(t, 1, len(refs)) assert.Equal(t, 7, refs[0].AST().GetRange().Start.Line) @@ -77,21 +77,21 @@ hi okay `, } - refs, err := d2lsp.GetRefs("index.d2", fs, "hi", nil) + refs, err := d2lsp.GetRefs("index.d2", fs, nil, "hi") assert.Success(t, err) assert.Equal(t, 1, len(refs)) assert.Equal(t, 2, refs[0].AST().GetRange().Start.Line) - refs, err = d2lsp.GetRefs("index.d2", fs, "okay", nil) + refs, err = d2lsp.GetRefs("index.d2", fs, nil, "okay") assert.Success(t, err) assert.Equal(t, 1, len(refs)) assert.Equal(t, "ok.d2", refs[0].AST().GetRange().Path) - refs, err = d2lsp.GetRefs("ok.d2", fs, "hi", nil) + refs, err = d2lsp.GetRefs("ok.d2", fs, nil, "hi") assert.Success(t, err) assert.Equal(t, 0, len(refs)) - refs, err = d2lsp.GetRefs("ok.d2", fs, "okay", nil) + refs, err = d2lsp.GetRefs("ok.d2", fs, nil, "okay") assert.Success(t, err) assert.Equal(t, 1, len(refs)) } @@ -107,15 +107,42 @@ layers: { } `, } - refs, err := d2lsp.GetRefs("index.d2", fs, "hello", []string{"x"}) + refs, err := d2lsp.GetRefs("index.d2", fs, []string{"x"}, "hello") assert.Success(t, err) assert.Equal(t, 1, len(refs)) assert.Equal(t, 4, refs[0].AST().GetRange().Start.Line) - refs, err = d2lsp.GetRefs("index.d2", fs, "hi", []string{"x"}) + refs, err = d2lsp.GetRefs("index.d2", fs, []string{"x"}, "hi") assert.Success(t, err) assert.Equal(t, 0, len(refs)) - _, err = d2lsp.GetRefs("index.d2", fs, "hello", []string{"y"}) + _, err = d2lsp.GetRefs("index.d2", fs, []string{"y"}, "hello") assert.Equal(t, `board "[y]" not found`, err.Error()) } + +func TestGetImportRanges(t *testing.T) { + fs := map[string]string{ + "yes/index.d2": ` +...@../fast/ok +hi +hey: { + ...@pok +} +`, + "fast/ok.d2": ` +okay +`, + "yes/pok.d2": ` +des +`, + } + ranges, err := d2lsp.GetImportRanges("yes/index.d2", fs, "fast/ok.d2") + assert.Success(t, err) + assert.Equal(t, 1, len(ranges)) + assert.Equal(t, 1, ranges[0].Start.Line) + + ranges, err = d2lsp.GetImportRanges("yes/index.d2", fs, "yes/pok.d2") + assert.Success(t, err) + assert.Equal(t, 1, len(ranges)) + assert.Equal(t, 4, ranges[0].Start.Line) +}