diff --git a/d2layouts/d2dagrelayout/layout.go b/d2layouts/d2dagrelayout/layout.go index 6d004e25f..b88e5e6b9 100644 --- a/d2layouts/d2dagrelayout/layout.go +++ b/d2layouts/d2dagrelayout/layout.go @@ -8,6 +8,7 @@ import ( "math" "regexp" "sort" + "strconv" "strings" "cdr.dev/slog" @@ -140,17 +141,13 @@ func Layout(ctx context.Context, g *d2graph.Graph, opts *ConfigurableOpts) (err return err } + mapper := NewObjectMapper() loadScript := "" - idToObj := make(map[string]*d2graph.Object) for _, obj := range g.Objects { - id := obj.AbsID() - idToObj[id] = obj - - width, height := obj.Width, obj.Height - - loadScript += generateAddNodeLine(id, int(width), int(height)) + mapper.Register(obj) + loadScript += mapper.generateAddNodeLine(obj, int(obj.Width), int(obj.Height)) if obj.Parent != g.Root { - loadScript += generateAddParentLine(id, obj.Parent.AbsID()) + loadScript += mapper.generateAddParentLine(obj, obj.Parent) } } @@ -178,7 +175,7 @@ func Layout(ctx context.Context, g *d2graph.Graph, opts *ConfigurableOpts) (err } } - loadScript += generateAddEdgeLine(src.AbsID(), dst.AbsID(), edge.AbsID(), width, height) + loadScript += mapper.generateAddEdgeLine(src, dst, edge.AbsID(), width, height) } if debugJS { @@ -209,7 +206,7 @@ func Layout(ctx context.Context, g *d2graph.Graph, opts *ConfigurableOpts) (err log.Debug(ctx, "graph", slog.F("json", dn)) } - obj := idToObj[dn.ID] + obj := mapper.ToObj(dn.ID) // dagre gives center of node obj.TopLeft = geo.NewPoint(math.Round(dn.X-dn.Width/2), math.Round(dn.Y-dn.Height/2)) @@ -415,6 +412,32 @@ func setGraphAttrs(attrs dagreOpts) string { ) } +type objectMapper struct { + objToID map[*d2graph.Object]string + idToObj map[string]*d2graph.Object +} + +func NewObjectMapper() *objectMapper { + return &objectMapper{ + objToID: make(map[*d2graph.Object]string), + idToObj: make(map[string]*d2graph.Object), + } +} + +func (c *objectMapper) Register(obj *d2graph.Object) { + id := strconv.Itoa(len(c.idToObj)) + c.idToObj[id] = obj + c.objToID[obj] = id +} + +func (c *objectMapper) ToID(obj *d2graph.Object) string { + return c.objToID[obj] +} + +func (c *objectMapper) ToObj(id string) *d2graph.Object { + return c.idToObj[id] +} + func escapeID(id string) string { // fixes \\ id = strings.ReplaceAll(id, "\\", `\\`) @@ -426,17 +449,20 @@ func escapeID(id string) string { return id } -func generateAddNodeLine(id string, width, height int) string { - id = escapeID(id) +func (c objectMapper) generateAddNodeLine(obj *d2graph.Object, width, height int) string { + id := c.ToID(obj) return fmt.Sprintf("g.setNode(`%s`, { id: `%s`, width: %d, height: %d });\n", id, id, width, height) } -func generateAddParentLine(childID, parentID string) string { - return fmt.Sprintf("g.setParent(`%s`, `%s`);\n", escapeID(childID), escapeID(parentID)) +func (c objectMapper) generateAddParentLine(child, parent *d2graph.Object) string { + return fmt.Sprintf("g.setParent(`%s`, `%s`);\n", c.ToID(child), c.ToID(parent)) } -func generateAddEdgeLine(fromID, toID, edgeID string, width, height int) string { - return fmt.Sprintf("g.setEdge({v:`%s`, w:`%s`, name:`%s`}, { width:%d, height:%d, labelpos: `c` });\n", escapeID(fromID), escapeID(toID), escapeID(edgeID), width, height) +func (c objectMapper) generateAddEdgeLine(from, to *d2graph.Object, edgeID string, width, height int) string { + return fmt.Sprintf( + "g.setEdge({v:`%s`, w:`%s`, name:`%s`}, { width:%d, height:%d, labelpos: `c` });\n", + c.ToID(from), c.ToID(to), escapeID(edgeID), width, height, + ) } // getLongestEdgeChainHead finds the longest chain in a container and gets its head