diff --git a/d2renderers/d2sketch/sketch.go b/d2renderers/d2sketch/sketch.go index 8425b2fa3..2ce7594d7 100644 --- a/d2renderers/d2sketch/sketch.go +++ b/d2renderers/d2sketch/sketch.go @@ -595,7 +595,7 @@ func ArrowheadJS(r *Runner, arrowhead d2target.Arrowhead, stroke string, strokeW return } -func Arrowheads(r *Runner, connection d2target.Connection) (string, error) { +func Arrowheads(r *Runner, connection d2target.Connection, srcAdj, dstAdj *geo.Point) (string, error) { arrowPaths := []string{} if connection.SrcArrow != d2target.NoArrowhead { @@ -608,13 +608,8 @@ func Arrowheads(r *Runner, connection d2target.Connection) (string, error) { startingVector := startingSegment.ToVector().Reverse() angle := startingVector.Degrees() - // TODO get src shape stroke width - srcStrokeWidth := 2 - distance := float64(connection.StrokeWidth) + (float64(connection.StrokeWidth)+float64(srcStrokeWidth))/2.0 - - sourceAdjustment := startingVector.Unit().Multiply(-distance).ToPoint() transform := fmt.Sprintf(`transform="translate(%f %f) rotate(%v)"`, - startingSegment.Start.X+sourceAdjustment.X, startingSegment.Start.Y+sourceAdjustment.Y, angle, + startingSegment.Start.X+srcAdj.X, startingSegment.Start.Y+srcAdj.Y, angle, ) roughPaths, err := computeRoughPaths(r, arrowJS) @@ -650,16 +645,8 @@ func Arrowheads(r *Runner, connection d2target.Connection) (string, error) { endingVector := endingSegment.ToVector() angle := endingVector.Degrees() - // TODO get dst shape stroke width - dstStrokeWidth := 2 - distance := (float64(connection.StrokeWidth) + float64(dstStrokeWidth)) / 2.0 - if connection.DstArrow != d2target.NoArrowhead { - distance += float64(connection.StrokeWidth) - } - - targetAdjustment := endingVector.Unit().Multiply(-distance).ToPoint() transform := fmt.Sprintf(`transform="translate(%f %f) rotate(%v)"`, - endingSegment.End.X+targetAdjustment.X, endingSegment.End.Y+targetAdjustment.Y, angle, + endingSegment.End.X+dstAdj.X, endingSegment.End.Y+dstAdj.Y, angle, ) roughPaths, err := computeRoughPaths(r, arrowJS) diff --git a/d2renderers/d2svg/d2svg.go b/d2renderers/d2svg/d2svg.go index 876a5b6fa..8082c033c 100644 --- a/d2renderers/d2svg/d2svg.go +++ b/d2renderers/d2svg/d2svg.go @@ -307,17 +307,25 @@ func arrowheadAdjustment(start, end *geo.Point, arrowhead d2target.Arrowhead, ed return v.Unit().Multiply(-distance).ToPoint() } -// returns the path's d attribute for the given connection -func pathData(connection d2target.Connection, idToShape map[string]d2target.Shape) string { - var path []string +func getArrowheadAdjustments(connection d2target.Connection, idToShape map[string]d2target.Shape) (srcAdj, dstAdj *geo.Point) { route := connection.Route srcShape := idToShape[connection.Src] dstShape := idToShape[connection.Dst] - sourceAdjustment := arrowheadAdjustment(route[0], route[1], connection.SrcArrow, connection.StrokeWidth, srcShape.StrokeWidth) + sourceAdjustment := arrowheadAdjustment(route[1], route[0], connection.SrcArrow, connection.StrokeWidth, srcShape.StrokeWidth) + + targetAdjustment := arrowheadAdjustment(route[len(route)-2], route[len(route)-1], connection.DstArrow, connection.StrokeWidth, dstShape.StrokeWidth) + return sourceAdjustment, targetAdjustment +} + +// returns the path's d attribute for the given connection +func pathData(connection d2target.Connection, srcAdj, dstAdj *geo.Point) string { + var path []string + route := connection.Route + path = append(path, fmt.Sprintf("M %f %f", - route[0].X-sourceAdjustment.X, - route[0].Y-sourceAdjustment.Y, + route[0].X+srcAdj.X, + route[0].Y+srcAdj.Y, )) if connection.IsCurve { @@ -330,12 +338,11 @@ func pathData(connection d2target.Connection, idToShape map[string]d2target.Shap )) } // final curve target adjustment - targetAdjustment := arrowheadAdjustment(route[i+1], route[i+2], connection.DstArrow, connection.StrokeWidth, dstShape.StrokeWidth) path = append(path, fmt.Sprintf("C %f %f %f %f %f %f", route[i].X, route[i].Y, route[i+1].X, route[i+1].Y, - route[i+2].X+targetAdjustment.X, - route[i+2].Y+targetAdjustment.Y, + route[i+2].X+dstAdj.X, + route[i+2].Y+dstAdj.Y, )) } else { for i := 1; i < len(route)-1; i++ { @@ -387,12 +394,9 @@ func pathData(connection d2target.Connection, idToShape map[string]d2target.Shap } lastPoint := route[len(route)-1] - secondToLastPoint := route[len(route)-2] - - targetAdjustment := arrowheadAdjustment(secondToLastPoint, lastPoint, connection.DstArrow, connection.StrokeWidth, dstShape.StrokeWidth) path = append(path, fmt.Sprintf("L %f %f", - lastPoint.X+targetAdjustment.X, - lastPoint.Y+targetAdjustment.Y, + lastPoint.X+dstAdj.X, + lastPoint.Y+dstAdj.Y, )) } @@ -448,7 +452,8 @@ func drawConnection(writer io.Writer, labelMaskID string, connection d2target.Co } } - path := pathData(connection, idToShape) + srcAdj, dstAdj := getArrowheadAdjustments(connection, idToShape) + path := pathData(connection, srcAdj, dstAdj) mask := fmt.Sprintf(`mask="url(#%s)"`, labelMaskID) if sketchRunner != nil { out, err := d2sketch.Connection(sketchRunner, connection, path, mask) @@ -458,7 +463,7 @@ func drawConnection(writer io.Writer, labelMaskID string, connection d2target.Co fmt.Fprint(writer, out) // render sketch arrowheads separately - arrowPaths, err := d2sketch.Arrowheads(sketchRunner, connection) + arrowPaths, err := d2sketch.Arrowheads(sketchRunner, connection, srcAdj, dstAdj) if err != nil { return "", err }