diff --git a/d2graph/serde.go b/d2graph/serde.go index 88c3abd67..6b1006c31 100644 --- a/d2graph/serde.go +++ b/d2graph/serde.go @@ -28,7 +28,7 @@ func DeserializeGraph(bytes []byte, g *Graph) error { } var root Object - convert(sg.Root, &root) + Convert(sg.Root, &root) g.Root = &root root.Graph = g g.RootLevel = sg.RootLevel @@ -38,7 +38,7 @@ func DeserializeGraph(bytes []byte, g *Graph) error { var objects []*Object for _, so := range sg.Objects { var o Object - if err := convert(so, &o); err != nil { + if err := Convert(so, &o); err != nil { return err } o.Graph = g @@ -67,7 +67,7 @@ func DeserializeGraph(bytes []byte, g *Graph) error { var edges []*Edge for _, se := range sg.Edges { var e Edge - if err := convert(se, &e); err != nil { + if err := Convert(se, &e); err != nil { return err } @@ -108,7 +108,7 @@ func SerializeGraph(g *Graph) ([]byte, error) { var sedges []SerializedEdge for _, e := range g.Edges { - se, err := toSerializedEdge(e) + se, err := ToSerializedEdge(e) if err != nil { return nil, err } @@ -121,7 +121,7 @@ func SerializeGraph(g *Graph) ([]byte, error) { func toSerializedObject(o *Object) (SerializedObject, error) { var so SerializedObject - if err := convert(o, &so); err != nil { + if err := Convert(o, &so); err != nil { return nil, err } @@ -138,9 +138,9 @@ func toSerializedObject(o *Object) (SerializedObject, error) { return so, nil } -func toSerializedEdge(e *Edge) (SerializedEdge, error) { +func ToSerializedEdge(e *Edge) (SerializedEdge, error) { var se SerializedEdge - if err := convert(e, &se); err != nil { + if err := Convert(e, &se); err != nil { return nil, err } @@ -154,7 +154,7 @@ func toSerializedEdge(e *Edge) (SerializedEdge, error) { return se, nil } -func convert[T, Q any](from T, to *Q) error { +func Convert[T, Q any](from T, to *Q) error { b, err := json.Marshal(from) if err != nil { return err diff --git a/d2plugin/exec.go b/d2plugin/exec.go index 4631ef036..a6776feed 100644 --- a/d2plugin/exec.go +++ b/d2plugin/exec.go @@ -201,3 +201,59 @@ func (p *execPlugin) PostProcess(ctx context.Context, in []byte) ([]byte, error) return stdout, nil } + +func (p *execPlugin) RouteEdges(ctx context.Context, g *d2graph.Graph, edges []*d2graph.Edge) error { + ctx, cancel := timelib.WithTimeout(ctx, time.Minute*2) + defer cancel() + + graphBytes, err := d2graph.SerializeGraph(g) + if err != nil { + return err + } + + var g2 d2graph.Graph + err = d2graph.DeserializeGraph(graphBytes, &g2) + if err != nil { + return fmt.Errorf("failed to unmarshal json: %w", err) + } + g2.Edges = edges + graphBytes2, err := d2graph.SerializeGraph(&g2) + if err != nil { + return err + } + + in := routeEdgesInput{ + g: graphBytes, + gedges: graphBytes2, + } + + b, err := json.Marshal(in) + if err != nil { + return err + } + + args := []string{"routeedges"} + for k, v := range p.opts { + args = append(args, fmt.Sprintf("--%s", k), v) + } + cmd := exec.CommandContext(ctx, p.path, args...) + + buffer := bytes.Buffer{} + buffer.Write(b) + cmd.Stdin = &buffer + + stdout, err := cmd.Output() + if err != nil { + ee := &exec.ExitError{} + if errors.As(err, &ee) && len(ee.Stderr) > 0 { + return fmt.Errorf("%v\nstderr:\n%s", ee, ee.Stderr) + } + return err + } + err = d2graph.DeserializeGraph(stdout, g) + if err != nil { + return fmt.Errorf("failed to unmarshal json: %w", err) + } + + return nil +} diff --git a/d2plugin/plugin.go b/d2plugin/plugin.go index f12868faf..070706219 100644 --- a/d2plugin/plugin.go +++ b/d2plugin/plugin.go @@ -85,6 +85,11 @@ type RoutingPlugin interface { RouteEdges(context.Context, *d2graph.Graph, []*d2graph.Edge) error } +type routeEdgesInput struct { + g []byte + gedges []byte +} + // PluginInfo is the current info information of a plugin. // note: The two fields Type and Path are not set by the plugin // itself but only in ListPlugins. diff --git a/d2plugin/serve.go b/d2plugin/serve.go index 9cae5ca2f..b7a13417d 100644 --- a/d2plugin/serve.go +++ b/d2plugin/serve.go @@ -58,6 +58,12 @@ func Serve(p Plugin) xmain.RunFunc { return layout(ctx, p, ms) case "postprocess": return postProcess(ctx, p, ms) + case "routeedges": + routingPlugin, ok := p.(RoutingPlugin) + if !ok { + return fmt.Errorf("plugin has routing feature but does not implement RoutingPlugin") + } + return routeEdges(ctx, routingPlugin, ms) default: return xmain.UsageErrorf("unrecognized command: %s", subcmd) } @@ -137,3 +143,39 @@ func postProcess(ctx context.Context, p Plugin, ms *xmain.State) error { } return nil } + +func routeEdges(ctx context.Context, p RoutingPlugin, ms *xmain.State) error { + inRaw, err := io.ReadAll(ms.Stdin) + if err != nil { + return err + } + + in := routeEdgesInput{} + + err = json.Unmarshal(inRaw, &in) + + var g d2graph.Graph + if err := d2graph.DeserializeGraph(in.g, &g); err != nil { + return fmt.Errorf("failed to unmarshal input graph to graph: %s", in) + } + + var gedges d2graph.Graph + if err := d2graph.DeserializeGraph(in.gedges, &gedges); err != nil { + return fmt.Errorf("failed to unmarshal input edges graph to graph: %s", in) + } + + err = p.RouteEdges(ctx, &g, gedges.Edges) + if err != nil { + return err + } + + b, err := d2graph.SerializeGraph(&g) + if err != nil { + return err + } + _, err = ms.Stdout.Write(b) + if err != nil { + return err + } + return nil +}