This commit is contained in:
Mayank Mohapatra 2025-03-08 11:20:18 +00:00
parent fd8d01dfe6
commit f37193a2c6

View file

@ -1,337 +1,337 @@
package d2cycle package d2cycle
import ( import (
"context" "context"
"math" "math"
"sort" "sort"
"oss.terrastruct.com/d2/d2graph" "oss.terrastruct.com/d2/d2graph"
"oss.terrastruct.com/d2/lib/geo" "oss.terrastruct.com/d2/lib/geo"
"oss.terrastruct.com/d2/lib/label" "oss.terrastruct.com/d2/lib/label"
"oss.terrastruct.com/util-go/go2" "oss.terrastruct.com/util-go/go2"
) )
const ( const (
MIN_RADIUS = 200 MIN_RADIUS = 200
PADDING = 20 PADDING = 20
MIN_SEGMENT_LEN = 10 MIN_SEGMENT_LEN = 10
ARC_STEPS = 100 ARC_STEPS = 100
) )
// Layout lays out the graph and computes curved edge routes. // Layout lays out the graph and computes curved edge routes.
func Layout(ctx context.Context, g *d2graph.Graph, layout d2graph.LayoutGraph) error { func Layout(ctx context.Context, g *d2graph.Graph, layout d2graph.LayoutGraph) error {
objects := g.Root.ChildrenArray objects := g.Root.ChildrenArray
if len(objects) == 0 { if len(objects) == 0 {
return nil return nil
} }
for _, obj := range g.Objects { for _, obj := range g.Objects {
positionLabelsIcons(obj) positionLabelsIcons(obj)
} }
radius := calculateRadius(objects) radius := calculateRadius(objects)
positionObjects(objects, radius) positionObjects(objects, radius)
for _, edge := range g.Edges { for _, edge := range g.Edges {
createCircularArc(edge) createCircularArc(edge)
} }
return nil return nil
} }
// calculateRadius computes a radius ensuring that the circular layout does not overlap. // calculateRadius computes a radius ensuring that the circular layout does not overlap.
// For each object we compute the half-diagonal (i.e. the radius of the minimal enclosing circle), // For each object we compute the half-diagonal (i.e. the radius of the minimal enclosing circle),
// then ensure the chord between two adjacent centers (2*radius*sin(π/n)) is at least // then ensure the chord between two adjacent centers (2*radius*sin(π/n)) is at least
// 2*(maxHalfDiagonal + PADDING). We also add a safety factor (1.2) to avoid floating-point issues. // 2*(maxHalfDiag + PADDING). We also add a safety factor (1.2) to avoid floating-point issues.
func calculateRadius(objects []*d2graph.Object) float64 { func calculateRadius(objects []*d2graph.Object) float64 {
if len(objects) < 2 { if len(objects) < 2 {
return MIN_RADIUS return MIN_RADIUS
} }
numObjects := float64(len(objects)) numObjects := float64(len(objects))
maxHalfDiag := 0.0 maxHalfDiag := 0.0
for _, obj := range objects { for _, obj := range objects {
halfDiag := math.Hypot(obj.Box.Width/2, obj.Box.Height/2) halfDiag := math.Hypot(obj.Box.Width/2, obj.Box.Height/2)
if halfDiag > maxHalfDiag { if halfDiag > maxHalfDiag {
maxHalfDiag = halfDiag maxHalfDiag = halfDiag
} }
} }
// We need the chord (distance between adjacent centers) to be at least: // We need the chord (distance between adjacent centers) to be at least:
// 2*(maxHalfDiag + PADDING) // 2*(maxHalfDiag + PADDING)
// and since chord = 2*radius*sin(π/n), we require: // and since chord = 2*radius*sin(π/n), we require:
// radius >= (maxHalfDiag + PADDING) / sin(π/n) // radius >= (maxHalfDiag + PADDING) / sin(π/n)
minRadius := (maxHalfDiag + PADDING) / math.Sin(math.Pi/numObjects) minRadius := (maxHalfDiag + PADDING) / math.Sin(math.Pi/numObjects)
// Apply a safety factor of 1.2 and ensure it doesn't fall below MIN_RADIUS. // Apply a safety factor of 1.2 and ensure it doesn't fall below MIN_RADIUS.
return math.Max(minRadius*1.2, MIN_RADIUS) return math.Max(minRadius*1.2, MIN_RADIUS)
} }
func positionObjects(objects []*d2graph.Object, radius float64) { func positionObjects(objects []*d2graph.Object, radius float64) {
numObjects := float64(len(objects)) numObjects := float64(len(objects))
angleOffset := -math.Pi / 2 angleOffset := -math.Pi / 2
for i, obj := range objects { for i, obj := range objects {
angle := angleOffset + (2*math.Pi*float64(i)/numObjects) angle := angleOffset + (2*math.Pi*float64(i)/numObjects)
x := radius * math.Cos(angle) x := radius * math.Cos(angle)
y := radius * math.Sin(angle) y := radius * math.Sin(angle)
obj.TopLeft = geo.NewPoint( obj.TopLeft = geo.NewPoint(
x-obj.Box.Width/2, x-obj.Box.Width/2,
y-obj.Box.Height/2, y-obj.Box.Height/2,
) )
} }
} }
func createCircularArc(edge *d2graph.Edge) { func createCircularArc(edge *d2graph.Edge) {
if edge.Src == nil || edge.Dst == nil { if edge.Src == nil || edge.Dst == nil {
return return
} }
srcCenter := edge.Src.Center() srcCenter := edge.Src.Center()
dstCenter := edge.Dst.Center() dstCenter := edge.Dst.Center()
srcAngle := math.Atan2(srcCenter.Y, srcCenter.X) srcAngle := math.Atan2(srcCenter.Y, srcCenter.X)
dstAngle := math.Atan2(dstCenter.Y, dstCenter.X) dstAngle := math.Atan2(dstCenter.Y, dstCenter.X)
if dstAngle < srcAngle { if dstAngle < srcAngle {
dstAngle += 2 * math.Pi dstAngle += 2 * math.Pi
} }
arcRadius := math.Hypot(srcCenter.X, srcCenter.Y) arcRadius := math.Hypot(srcCenter.X, srcCenter.Y)
path := make([]*geo.Point, 0, ARC_STEPS+1) path := make([]*geo.Point, 0, ARC_STEPS+1)
for i := 0; i <= ARC_STEPS; i++ { for i := 0; i <= ARC_STEPS; i++ {
t := float64(i) / float64(ARC_STEPS) t := float64(i) / float64(ARC_STEPS)
angle := srcAngle + t*(dstAngle-srcAngle) angle := srcAngle + t*(dstAngle-srcAngle)
x := arcRadius * math.Cos(angle) x := arcRadius * math.Cos(angle)
y := arcRadius * math.Sin(angle) y := arcRadius * math.Sin(angle)
path = append(path, geo.NewPoint(x, y)) path = append(path, geo.NewPoint(x, y))
} }
path[0] = srcCenter path[0] = srcCenter
path[len(path)-1] = dstCenter path[len(path)-1] = dstCenter
// Clamp endpoints to the boundaries of the source and destination boxes. // Clamp endpoints to the boundaries of the source and destination boxes.
_, newSrc := clampPointOutsideBox(edge.Src.Box, path, 0) _, newSrc := clampPointOutsideBox(edge.Src.Box, path, 0)
_, newDst := clampPointOutsideBoxReverse(edge.Dst.Box, path, len(path)-1) _, newDst := clampPointOutsideBoxReverse(edge.Dst.Box, path, len(path)-1)
path[0] = newSrc path[0] = newSrc
path[len(path)-1] = newDst path[len(path)-1] = newDst
// Trim redundant path points that fall inside node boundaries. // Trim redundant path points that fall inside node boundaries.
path = trimPathPoints(path, edge.Src.Box) path = trimPathPoints(path, edge.Src.Box)
path = trimPathPoints(path, edge.Dst.Box) path = trimPathPoints(path, edge.Dst.Box)
edge.Route = path edge.Route = path
edge.IsCurve = true edge.IsCurve = true
if len(edge.Route) >= 2 { if len(edge.Route) >= 2 {
lastIndex := len(edge.Route) - 1 lastIndex := len(edge.Route) - 1
lastPoint := edge.Route[lastIndex] lastPoint := edge.Route[lastIndex]
secondLastPoint := edge.Route[lastIndex-1] secondLastPoint := edge.Route[lastIndex-1]
tangentX := -lastPoint.Y tangentX := -lastPoint.Y
tangentY := lastPoint.X tangentY := lastPoint.X
mag := math.Hypot(tangentX, tangentY) mag := math.Hypot(tangentX, tangentY)
if mag > 0 { if mag > 0 {
tangentX /= mag tangentX /= mag
tangentY /= mag tangentY /= mag
} }
const MIN_SEGMENT_LEN = 4.159 const MIN_SEGMENT_LEN = 4.159
dx := lastPoint.X - secondLastPoint.X dx := lastPoint.X - secondLastPoint.X
dy := lastPoint.Y - secondLastPoint.Y dy := lastPoint.Y - secondLastPoint.Y
segLength := math.Hypot(dx, dy) segLength := math.Hypot(dx, dy)
if segLength > 0 { if segLength > 0 {
currentDirX := dx / segLength currentDirX := dx / segLength
currentDirY := dy / segLength currentDirY := dy / segLength
// Check if we need to adjust the direction // Check if we need to adjust the direction
if segLength < MIN_SEGMENT_LEN || (currentDirX*tangentX+currentDirY*tangentY) < 0.999 { if segLength < MIN_SEGMENT_LEN || (currentDirX*tangentX+currentDirY*tangentY) < 0.999 {
adjustLength := MIN_SEGMENT_LEN adjustLength := MIN_SEGMENT_LEN
if segLength >= MIN_SEGMENT_LEN { if segLength >= MIN_SEGMENT_LEN {
adjustLength = segLength adjustLength = segLength
} }
newSecondLastX := lastPoint.X - tangentX*adjustLength newSecondLastX := lastPoint.X - tangentX*adjustLength
newSecondLastY := lastPoint.Y - tangentY*adjustLength newSecondLastY := lastPoint.Y - tangentY*adjustLength
edge.Route[lastIndex-1] = geo.NewPoint(newSecondLastX, newSecondLastY) edge.Route[lastIndex-1] = geo.NewPoint(newSecondLastX, newSecondLastY)
} }
} }
} }
} }
// clampPointOutsideBox walks forward along the path until it finds a point outside the box, // clampPointOutsideBox walks forward along the path until it finds a point outside the box,
// then replaces the point with a precise intersection. // then replaces the point with a precise intersection.
func clampPointOutsideBox(box *geo.Box, path []*geo.Point, startIdx int) (int, *geo.Point) { func clampPointOutsideBox(box *geo.Box, path []*geo.Point, startIdx int) (int, *geo.Point) {
if startIdx >= len(path)-1 { if startIdx >= len(path)-1 {
return startIdx, path[startIdx] return startIdx, path[startIdx]
} }
if !boxContains(box, path[startIdx]) { if !boxContains(box, path[startIdx]) {
return startIdx, path[startIdx] return startIdx, path[startIdx]
} }
for i := startIdx + 1; i < len(path); i++ { for i := startIdx + 1; i < len(path); i++ {
if boxContains(box, path[i]) { if boxContains(box, path[i]) {
continue continue
} }
seg := geo.NewSegment(path[i-1], path[i]) seg := geo.NewSegment(path[i-1], path[i])
inter := findPreciseIntersection(box, *seg) inter := findPreciseIntersection(box, *seg)
if inter != nil { if inter != nil {
return i, inter return i, inter
} }
return i, path[i] return i, path[i]
} }
return len(path)-1, path[len(path)-1] return len(path)-1, path[len(path)-1]
} }
// clampPointOutsideBoxReverse works similarly but in reverse order. // clampPointOutsideBoxReverse works similarly but in reverse order.
func clampPointOutsideBoxReverse(box *geo.Box, path []*geo.Point, endIdx int) (int, *geo.Point) { func clampPointOutsideBoxReverse(box *geo.Box, path []*geo.Point, endIdx int) (int, *geo.Point) {
if endIdx <= 0 { if endIdx <= 0 {
return endIdx, path[endIdx] return endIdx, path[endIdx]
} }
if !boxContains(box, path[endIdx]) { if !boxContains(box, path[endIdx]) {
return endIdx, path[endIdx] return endIdx, path[endIdx]
} }
for j := endIdx - 1; j >= 0; j-- { for j := endIdx - 1; j >= 0; j-- {
if boxContains(box, path[j]) { if boxContains(box, path[j]) {
continue continue
} }
seg := geo.NewSegment(path[j], path[j+1]) seg := geo.NewSegment(path[j], path[j+1])
inter := findPreciseIntersection(box, *seg) inter := findPreciseIntersection(box, *seg)
if inter != nil { if inter != nil {
return j, inter return j, inter
} }
return j, path[j] return j, path[j]
} }
return 0, path[0] return 0, path[0]
} }
// findPreciseIntersection calculates intersection points between seg and all four sides of the box, // findPreciseIntersection calculates intersection points between seg and all four sides of the box,
// then returns the intersection closest to seg.Start. // then returns the intersection closest to seg.Start.
func findPreciseIntersection(box *geo.Box, seg geo.Segment) *geo.Point { func findPreciseIntersection(box *geo.Box, seg geo.Segment) *geo.Point {
intersections := []struct { intersections := []struct {
point *geo.Point point *geo.Point
t float64 t float64
}{} }{}
left := box.TopLeft.X left := box.TopLeft.X
right := box.TopLeft.X + box.Width right := box.TopLeft.X + box.Width
top := box.TopLeft.Y top := box.TopLeft.Y
bottom := box.TopLeft.Y + box.Height bottom := box.TopLeft.Y + box.Height
dx := seg.End.X - seg.Start.X dx := seg.End.X - seg.Start.X
dy := seg.End.Y - seg.Start.Y dy := seg.End.Y - seg.Start.Y
// Check vertical boundaries. // Check vertical boundaries.
if dx != 0 { if dx != 0 {
t := (left - seg.Start.X) / dx t := (left - seg.Start.X) / dx
if t >= 0 && t <= 1 { if t >= 0 && t <= 1 {
y := seg.Start.Y + t*dy y := seg.Start.Y + t*dy
if y >= top && y <= bottom { if y >= top && y <= bottom {
intersections = append(intersections, struct { intersections = append(intersections, struct {
point *geo.Point point *geo.Point
t float64 t float64
}{geo.NewPoint(left, y), t}) }{geo.NewPoint(left, y), t})
} }
} }
t = (right - seg.Start.X) / dx t = (right - seg.Start.X) / dx
if t >= 0 && t <= 1 { if t >= 0 && t <= 1 {
y := seg.Start.Y + t*dy y := seg.Start.Y + t*dy
if y >= top && y <= bottom { if y >= top && y <= bottom {
intersections = append(intersections, struct { intersections = append(intersections, struct {
point *geo.Point point *geo.Point
t float64 t float64
}{geo.NewPoint(right, y), t}) }{geo.NewPoint(right, y), t})
} }
} }
} }
// Check horizontal boundaries. // Check horizontal boundaries.
if dy != 0 { if dy != 0 {
t := (top - seg.Start.Y) / dy t := (top - seg.Start.Y) / dy
if t >= 0 && t <= 1 { if t >= 0 && t <= 1 {
x := seg.Start.X + t*dx x := seg.Start.X + t*dx
if x >= left && x <= right { if x >= left && x <= right {
intersections = append(intersections, struct { intersections = append(intersections, struct {
point *geo.Point point *geo.Point
t float64 t float64
}{geo.NewPoint(x, top), t}) }{geo.NewPoint(x, top), t})
} }
} }
t = (bottom - seg.Start.Y) / dy t = (bottom - seg.Start.Y) / dy
if t >= 0 && t <= 1 { if t >= 0 && t <= 1 {
x := seg.Start.X + t*dx x := seg.Start.X + t*dx
if x >= left && x <= right { if x >= left && x <= right {
intersections = append(intersections, struct { intersections = append(intersections, struct {
point *geo.Point point *geo.Point
t float64 t float64
}{geo.NewPoint(x, bottom), t}) }{geo.NewPoint(x, bottom), t})
} }
} }
} }
if len(intersections) == 0 { if len(intersections) == 0 {
return nil return nil
} }
// Sort intersections by t (distance from seg.Start) and return the closest. // Sort intersections by t (distance from seg.Start) and return the closest.
sort.Slice(intersections, func(i, j int) bool { sort.Slice(intersections, func(i, j int) bool {
return intersections[i].t < intersections[j].t return intersections[i].t < intersections[j].t
}) })
return intersections[0].point return intersections[0].point
} }
// trimPathPoints removes intermediate points that fall inside the given box while preserving endpoints. // trimPathPoints removes intermediate points that fall inside the given box while preserving endpoints.
func trimPathPoints(path []*geo.Point, box *geo.Box) []*geo.Point { func trimPathPoints(path []*geo.Point, box *geo.Box) []*geo.Point {
if len(path) <= 2 { if len(path) <= 2 {
return path return path
} }
trimmed := []*geo.Point{path[0]} trimmed := []*geo.Point{path[0]}
for i := 1; i < len(path)-1; i++ { for i := 1; i < len(path)-1; i++ {
if !boxContains(box, path[i]) { if !boxContains(box, path[i]) {
trimmed = append(trimmed, path[i]) trimmed = append(trimmed, path[i])
} }
} }
trimmed = append(trimmed, path[len(path)-1]) trimmed = append(trimmed, path[len(path)-1])
return trimmed return trimmed
} }
// boxContains uses strict inequalities so that points exactly on the boundary are considered outside. // boxContains uses strict inequalities so that points exactly on the boundary are considered outside.
func boxContains(b *geo.Box, p *geo.Point) bool { func boxContains(b *geo.Box, p *geo.Point) bool {
return p.X > b.TopLeft.X && return p.X > b.TopLeft.X &&
p.X < b.TopLeft.X+b.Width && p.X < b.TopLeft.X+b.Width &&
p.Y > b.TopLeft.Y && p.Y > b.TopLeft.Y &&
p.Y < b.TopLeft.Y+b.Height p.Y < b.TopLeft.Y+b.Height
} }
func positionLabelsIcons(obj *d2graph.Object) { func positionLabelsIcons(obj *d2graph.Object) {
if obj.Icon != nil && obj.IconPosition == nil { if obj.Icon != nil && obj.IconPosition == nil {
if len(obj.ChildrenArray) > 0 { if len(obj.ChildrenArray) > 0 {
obj.IconPosition = go2.Pointer(label.OutsideTopLeft.String()) obj.IconPosition = go2.Pointer(label.OutsideTopLeft.String())
if obj.LabelPosition == nil { if obj.LabelPosition == nil {
obj.LabelPosition = go2.Pointer(label.OutsideTopRight.String()) obj.LabelPosition = go2.Pointer(label.OutsideTopRight.String())
return return
} }
} else if obj.SQLTable != nil || obj.Class != nil || obj.Language != "" { } else if obj.SQLTable != nil || obj.Class != nil || obj.Language != "" {
obj.IconPosition = go2.Pointer(label.OutsideTopLeft.String()) obj.IconPosition = go2.Pointer(label.OutsideTopLeft.String())
} else { } else {
obj.IconPosition = go2.Pointer(label.InsideMiddleCenter.String()) obj.IconPosition = go2.Pointer(label.InsideMiddleCenter.String())
} }
} }
if obj.HasLabel() && obj.LabelPosition == nil { if obj.HasLabel() && obj.LabelPosition == nil {
if len(obj.ChildrenArray) > 0 { if len(obj.ChildrenArray) > 0 {
obj.LabelPosition = go2.Pointer(label.OutsideTopCenter.String()) obj.LabelPosition = go2.Pointer(label.OutsideTopCenter.String())
} else if obj.HasOutsideBottomLabel() { } else if obj.HasOutsideBottomLabel() {
obj.LabelPosition = go2.Pointer(label.OutsideBottomCenter.String()) obj.LabelPosition = go2.Pointer(label.OutsideBottomCenter.String())
} else if obj.Icon != nil { } else if obj.Icon != nil {
obj.LabelPosition = go2.Pointer(label.InsideTopCenter.String()) obj.LabelPosition = go2.Pointer(label.InsideTopCenter.String())
} else { } else {
obj.LabelPosition = go2.Pointer(label.InsideMiddleCenter.String()) obj.LabelPosition = go2.Pointer(label.InsideMiddleCenter.String())
} }
if float64(obj.LabelDimensions.Width) > obj.Width || if float64(obj.LabelDimensions.Width) > obj.Width ||
float64(obj.LabelDimensions.Height) > obj.Height { float64(obj.LabelDimensions.Height) > obj.Height {
if len(obj.ChildrenArray) > 0 { if len(obj.ChildrenArray) > 0 {
obj.LabelPosition = go2.Pointer(label.OutsideTopCenter.String()) obj.LabelPosition = go2.Pointer(label.OutsideTopCenter.String())
} else { } else {
obj.LabelPosition = go2.Pointer(label.OutsideBottomCenter.String()) obj.LabelPosition = go2.Pointer(label.OutsideBottomCenter.String())
} }
} }
} }
} }