diff --git a/d2ir/d2ir.go b/d2ir/d2ir.go index 34ec4ae46..a55b1dc3d 100644 --- a/d2ir/d2ir.go +++ b/d2ir/d2ir.go @@ -379,7 +379,7 @@ func (eid *EdgeID) Match(eid2 *EdgeID) bool { return true } -func (eid *EdgeID) resolveUnderscores(m *Map) (*EdgeID, *Map, error) { +func (eid *EdgeID) resolveToCommon(m *Map) (*d2ast.KeyPath, *EdgeID, *Map, error) { eid = eid.Copy() maxUnderscores := go2.Max(countUnderscores(eid.SrcPath), countUnderscores(eid.DstPath)) for i := 0; i < maxUnderscores; i++ { @@ -397,23 +397,21 @@ func (eid *EdgeID) resolveUnderscores(m *Map) (*EdgeID, *Map, error) { } m = ParentMap(m) if m == nil { - return nil, nil, errors.New("invalid underscore") + return nil, nil, nil, errors.New("invalid underscore") } } - return eid, m, nil -} -func (eid *EdgeID) trimCommon() (common []string, _ *EdgeID) { - eid = eid.Copy() + common := &d2ast.KeyPath{} for len(eid.SrcPath) > 1 && len(eid.DstPath) > 1 { if !strings.EqualFold(eid.SrcPath[0], eid.DstPath[0]) { - return common, eid + return common, eid, m, nil } - common = append(common, eid.SrcPath[0]) + common.Path = append(common.Path, d2ast.MakeValueBox(d2ast.RawString(eid.SrcPath[0], true)).StringBox()) eid.SrcPath = eid.SrcPath[1:] eid.DstPath = eid.DstPath[1:] } - return common, eid + + return common, eid, m, nil } type Edge struct { @@ -732,13 +730,12 @@ func (m *Map) DeleteField(ida ...string) *Field { } func (m *Map) GetEdges(eid *EdgeID) []*Edge { - eid, m, err := eid.resolveUnderscores(m) + common, eid, m, err := eid.resolveToCommon(m) if err != nil { return nil } - common, eid := eid.trimCommon() - if len(common) > 0 { - f := m.GetField(common...) + if len(common.Path) > 0 { + f := m.GetField(common.IDA()...) if f == nil { return nil } @@ -762,17 +759,12 @@ func (m *Map) CreateEdge(eid *EdgeID, refctx *RefContext) (*Edge, error) { return nil, d2parser.Errorf(refctx.Edge, "cannot create edge inside edge") } - eid, m, err := eid.resolveUnderscores(m) + common, eid, m, err := eid.resolveToCommon(m) if err != nil { return nil, d2parser.Errorf(refctx.Edge, err.Error()) } - common, eid := eid.trimCommon() - if len(common) > 0 { - tmp := *refctx.Edge.Src - kp := &tmp - underscores := countUnderscores(kp.IDA()) - kp.Path = kp.Path[underscores : len(common)+underscores] - f, err := m.EnsureField(kp, nil) + if len(common.Path) > 0 { + f, err := m.EnsureField(common, nil) if err != nil { return nil, err }