package textmeasure

import (
	"bytes"
	"fmt"
	"math"
	"strings"

	"github.com/PuerkitoBio/goquery"
	"github.com/yuin/goldmark"
	"github.com/yuin/goldmark/extension"
	goldmarkHtml "github.com/yuin/goldmark/renderer/html"
	"golang.org/x/net/html"

	"oss.terrastruct.com/util-go/go2"

	"oss.terrastruct.com/d2/d2renderers/d2fonts"
)

// markdownRenderer is the shared goldmark instance; it is configured in init.
var markdownRenderer goldmark.Markdown

// These are CSS values from github-markdown.css so we can accurately compute
// the rendered dimensions. Names ending in _em are multiples of the current
// font size; the others are used as absolute pixel values.
const (
	MarkdownFontSize   = d2fonts.FONT_SIZE_M
	MarkdownLineHeight = 1.5

	PaddingLeft_ul_ol_em = 2.
	MarginBottom_ul      = 16.

	MarginTop_li_p  = 16.
	MarginTop_li_em = 0.25
	MarginBottom_p  = 16.

	LineHeight_h           = 1.25
	MarginTop_h            = 24
	MarginBottom_h         = 16
	PaddingBottom_h1_h2_em = 0.3
	BorderBottom_h1_h2     = 1

	Height_hr_em       = 0.25
	MarginTopBottom_hr = 24

	Padding_pre              = 16
	MarginBottom_pre         = 16
	LineHeight_pre           = 1.45
	FontSize_pre_code_em     = 0.85
	PaddingTopBottom_code_em = 0.2
	PaddingLeftRight_code_em = 0.4

	PaddingLR_blockquote_em  = 1.
	MarginBottom_blockquote  = 16
	BorderLeft_blockquote_em = 0.25

	h1_em = 2.
	h2_em = 1.5
	h3_em = 1.25
	h4_em = 1.
	h5_em = 0.875
	h6_em = 0.85
)

// HeaderToFontSize returns the font size for the heading tag header ("h1"
// through "h6") given baseFontSize, truncating the scaled value to an int.
// It returns 0 for any other tag name.
func HeaderToFontSize(baseFontSize int, header string) int {
	base := float64(baseFontSize)
	switch header {
	case "h1":
		return int(h1_em * base)
	case "h2":
		return int(h2_em * base)
	case "h3":
		return int(h3_em * base)
	case "h4":
		return int(h4_em * base)
	case "h5":
		return int(h5_em * base)
	case "h6":
		return int(h6_em * base)
	}
	return 0
}

// RenderMarkdown converts the markdown text m to HTML with the package
// renderer and passes the result through sanitizeLinks. On error the returned
// string is empty.
func RenderMarkdown(m string) (string, error) {
	var output bytes.Buffer
	if err := markdownRenderer.Convert([]byte(m), &output); err != nil {
		return "", err
	}
	sanitized, err := sanitizeLinks(output.String())
	if err != nil {
		return "", err
	}
	return sanitized, nil
}

func init() {
	markdownRenderer = goldmark.New(
		goldmark.WithRendererOptions(
			goldmarkHtml.WithUnsafe(),
			goldmarkHtml.WithXHTML(),
		),
		goldmark.WithExtensions(
			extension.Strikethrough,
			extension.Table,
		),
	)
}

// MeasureMarkdown renders mdText to HTML and returns the width and height of
// the resulting body, rounded up to whole units.
//
// While measuring, the ruler's LineHeightFactor is set to MarkdownLineHeight
// and boundsWithDot is enabled; on return LineHeightFactor is restored to its
// previous value and boundsWithDot is set to false.
func MeasureMarkdown(mdText string, ruler *Ruler, fontFamily *d2fonts.FontFamily, fontSize int) (width, height int, err error) {
	render, err := RenderMarkdown(mdText)
	if err != nil {
		return width, height, err
	}

	doc, err := goquery.NewDocumentFromReader(strings.NewReader(render))
	if err != nil {
		return width, height, err
	}

	{
		originalLineHeight := ruler.LineHeightFactor
		ruler.boundsWithDot = true
		ruler.LineHeightFactor = MarkdownLineHeight
		defer func() {
			ruler.LineHeightFactor = originalLineHeight
			ruler.boundsWithDot = false
		}()
	}

	// TODO consider setting a max width + (manual) text wrapping
	body := doc.Find("body").First()
	if len(body.Nodes) == 0 {
		// Guard the index below so a document without a body reports an
		// error instead of panicking.
		return 0, 0, fmt.Errorf("rendered markdown has no body element")
	}
	bodyAttrs := ruler.measureNode(0, body.Nodes[0], fontFamily, fontSize, d2fonts.FONT_STYLE_REGULAR)

	return int(math.Ceil(bodyAttrs.width)), int(math.Ceil(bodyAttrs.height)), nil
}

// hasPrev reports whether n has a preceding sibling whose Data is not blank.
func hasPrev(n *html.Node) bool {
	for sib := n.PrevSibling; sib != nil; sib = sib.PrevSibling {
		if strings.TrimSpace(sib.Data) != "" {
			return true
		}
	}
	return false
}

// hasNext reports whether n has a following sibling whose Data is not blank.
func hasNext(n *html.Node) bool {
	// skip over empty text nodes
	for sib := n.NextSibling; sib != nil; sib = sib.NextSibling {
		if strings.TrimSpace(sib.Data) != "" {
			return true
		}
	}
	return false
}

// getPrev returns n if its Data is not blank, otherwise the nearest preceding
// sibling whose Data is not blank. If every node from n backwards is blank,
// the earliest of them is returned. It returns nil only when n is nil.
func getPrev(n *html.Node) *html.Node {
	if n == nil {
		return nil
	}
	if strings.TrimSpace(n.Data) == "" {
		// Keep walking backwards; recursing through getNext here would
		// bounce forward again and hand back a blank node.
		if prev := getPrev(n.PrevSibling); prev != nil {
			return prev
		}
	}
	return n
}

// getNext returns n if its Data is not blank, otherwise the nearest following
// sibling whose Data is not blank. If every node from n forwards is blank,
// the last of them is returned. It returns nil only when n is nil.
func getNext(n *html.Node) *html.Node {
	if n == nil {
		return nil
	}
	if strings.TrimSpace(n.Data) == "" {
		if next := getNext(n.NextSibling); next != nil {
			return next
		}
	}
	return n
}

// isBlockElement reports whether elType is a tag that measureNode stacks
// vertically as its own block rather than merging into an inline run.
func isBlockElement(elType string) bool {
	switch elType {
	case "blockquote",
		"div",
		"h1", "h2", "h3", "h4", "h5", "h6",
		"hr",
		"li",
		"ol",
		"p",
		"pre",
		"ul",
		// table elements; tfoot is listed so it is classified the same way
		// as the thead/tbody sections the table measurement handles.
		"table", "thead", "tbody", "tfoot", "tr", "td", "th":
		return true
	default:
		return false
	}
}

// hasAncestorElement reports whether any ancestor of n is an element node
// with tag name elType.
func hasAncestorElement(n *html.Node, elType string) bool {
	for p := n.Parent; p != nil; p = p.Parent {
		if p.Type == html.ElementNode && p.Data == elType {
			return true
		}
	}
	return false
}

// blockAttrs holds the measured box of a node: its content size and the
// margins that its parent collapses against neighbouring blocks.
type blockAttrs struct {
	width, height, marginTop, marginBottom float64
	// extraData carries table column widths upward: []float64 from a row,
	// [][]float64 from a thead/tbody/tfoot section. It is nil otherwise.
	extraData any
}

// isNotEmpty reports whether b is non-nil and differs from the zero value.
func (b *blockAttrs) isNotEmpty() bool {
	return b != nil && *b != blockAttrs{}
}

// measures node dimensions to match rendering with styles in github-markdown.css
//
// Text nodes are measured directly with the ruler. Element nodes measure
// their children recursively: runs of inline children are summed horizontally
// into one block, block-level children are stacked vertically with collapsed
// margins, and element-specific padding, borders and margins are then applied.
// depth is only used to indent the debug trace.
func (ruler *Ruler) measureNode(depth int, n *html.Node, fontFamily *d2fonts.FontFamily, fontSize int, fontStyle d2fonts.FontStyle) blockAttrs {
	if fontFamily == nil {
		fontFamily = go2.Pointer(d2fonts.SourceSansPro)
	}
	font := fontFamily.Font(fontSize, fontStyle)

	var parentElementType string
	if n.Parent != nil && n.Parent.Type == html.ElementNode {
		parentElementType = n.Parent.Data
	}

	debugMeasure := false
	var depthStr string
	if debugMeasure {
		if depth == 0 {
			fmt.Println()
		}
		depthStr = "┌" + strings.Repeat("-", depth)
	}

	switch n.Type {
	case html.TextNode:
		if strings.Trim(n.Data, "\n\t\b") == "" {
			return blockAttrs{}
		}
		str := n.Data
		isCode := parentElementType == "pre" || parentElementType == "code"
		spaceWidths := 0.

		if !isCode {
			spaceWidth := ruler.spaceWidth(font)
			// MeasurePrecise will not include leading or trailing whitespace, so we account for it here
			str = strings.ReplaceAll(str, "\n", " ")
			str = strings.ReplaceAll(str, "\t", " ")
			if strings.HasPrefix(str, " ") {
				// consecutive leading/trailing spaces end up rendered as a single space
				str = strings.TrimPrefix(str, " ")
				if hasPrev(n) {
					spaceWidths += spaceWidth
				}
			}
			if strings.HasSuffix(str, " ") {
				str = strings.TrimSuffix(str, " ")
				if hasNext(n) {
					spaceWidths += spaceWidth
				}
			}
		}

		if parentElementType == "pre" {
			originalLineHeight := ruler.LineHeightFactor
			ruler.LineHeightFactor = LineHeight_pre
			defer func() {
				ruler.LineHeightFactor = originalLineHeight
			}()
		}
		w, h := ruler.MeasurePrecise(font, str)
		if isCode {
			w *= FontSize_pre_code_em
			h *= FontSize_pre_code_em
		} else {
			w = ruler.scaleUnicode(w, font, str)
		}
		if debugMeasure {
			fmt.Printf("%stext(%v,%v)\n", depthStr, w, h)
		}
		// Keyed literal so extraData stays nil: a positional 0 would be
		// stored as a non-nil interface value and make every text block
		// count as non-empty, even one with no size.
		return blockAttrs{width: w + spaceWidths, height: h}

	case html.ElementNode:
		isCode := false
		switch n.Data {
		case "h1", "h2", "h3", "h4", "h5", "h6":
			fontSize = HeaderToFontSize(fontSize, n.Data)
			fontStyle = d2fonts.FONT_STYLE_SEMIBOLD
			originalLineHeight := ruler.LineHeightFactor
			ruler.LineHeightFactor = LineHeight_h
			defer func() {
				ruler.LineHeightFactor = originalLineHeight
			}()
		case "em":
			fontStyle = d2fonts.FONT_STYLE_ITALIC
		case "b", "strong":
			fontStyle = d2fonts.FONT_STYLE_BOLD
		case "pre", "code":
			fontFamily = go2.Pointer(d2fonts.SourceCodePro)
			fontStyle = d2fonts.FONT_STYLE_REGULAR
			isCode = true
		}

		block := blockAttrs{}
		lineHeightPx := float64(fontSize) * ruler.LineHeightFactor

		if n.FirstChild != nil {
			first := getNext(n.FirstChild)
			last := getPrev(n.LastChild)

			var blocks []blockAttrs
			var inlineBlock *blockAttrs
			// first create blocks from combined inline elements, then combine all blocks
			// inlineBlock will be non-nil while inline elements are being combined into a block
			endInlineBlock := func() {
				if !isCode && inlineBlock.height > 0 && inlineBlock.height < lineHeightPx {
					inlineBlock.height = lineHeightPx
				}
				blocks = append(blocks, *inlineBlock)
				inlineBlock = nil
			}

			for child := n.FirstChild; child != nil; child = child.NextSibling {
				childBlock := ruler.measureNode(depth+1, child, fontFamily, fontSize, fontStyle)

				if child.Type == html.ElementNode && isBlockElement(child.Data) {
					if inlineBlock != nil {
						endInlineBlock()
					}
					newBlock := &blockAttrs{}
					newBlock.width = childBlock.width
					newBlock.height = childBlock.height
					// the outer margins of a blockquote's first and last
					// children are dropped
					if child == first && n.Data == "blockquote" {
						newBlock.marginTop = 0.
					} else {
						newBlock.marginTop = childBlock.marginTop
					}
					if child == last && n.Data == "blockquote" {
						newBlock.marginBottom = 0.
					} else {
						newBlock.marginBottom = childBlock.marginBottom
					}

					blocks = append(blocks, *newBlock)
				} else if child.Type == html.ElementNode && child.Data == "br" {
					if inlineBlock != nil {
						endInlineBlock()
					} else {
						block.height += lineHeightPx
					}
				} else if childBlock.isNotEmpty() {
					if inlineBlock == nil {
						// start inline block with child
						inlineBlock = &childBlock
					} else {
						// stack inline element dimensions horizontally
						inlineBlock.width += childBlock.width
						inlineBlock.height = go2.Max(inlineBlock.height, childBlock.height)

						inlineBlock.marginTop = go2.Max(inlineBlock.marginTop, childBlock.marginTop)
						inlineBlock.marginBottom = go2.Max(inlineBlock.marginBottom, childBlock.marginBottom)
					}
				}
			}
			if inlineBlock != nil {
				endInlineBlock()
			}

			// stack the blocks vertically, collapsing adjacent margins: only
			// the part of a top margin exceeding the previous bottom margin
			// adds height
			var prevMarginBottom float64
			for i, b := range blocks {
				if i == 0 {
					block.marginTop = go2.Max(block.marginTop, b.marginTop)
				} else {
					marginDiff := b.marginTop - prevMarginBottom
					if marginDiff > 0 {
						block.height += marginDiff
					}
				}
				if i == len(blocks)-1 {
					block.marginBottom = go2.Max(block.marginBottom, b.marginBottom)
				} else {
					block.height += b.marginBottom
					prevMarginBottom = b.marginBottom
				}

				block.height += b.height
				block.width = go2.Max(block.width, b.width)
			}
		}

		switch n.Data {
		case "blockquote":
			block.width += (2*PaddingLR_blockquote_em + BorderLeft_blockquote_em) * float64(fontSize)
			block.marginBottom = go2.Max(block.marginBottom, MarginBottom_blockquote)
		case "p":
			if parentElementType == "li" {
				block.marginTop = go2.Max(block.marginTop, MarginTop_li_p)
			}
			block.marginBottom = go2.Max(block.marginBottom, MarginBottom_p)
		case "h1", "h2", "h3", "h4", "h5", "h6":
			block.marginTop = go2.Max(block.marginTop, MarginTop_h)
			block.marginBottom = go2.Max(block.marginBottom, MarginBottom_h)
			switch n.Data {
			case "h1", "h2":
				block.height += PaddingBottom_h1_h2_em*float64(fontSize) + BorderBottom_h1_h2
			}
		case "li":
			block.width += PaddingLeft_ul_ol_em * float64(fontSize)
			if hasPrev(n) {
				block.marginTop = go2.Max(block.marginTop, MarginTop_li_em*float64(fontSize))
			}
		case "ol", "ul":
			if hasAncestorElement(n, "ul") || hasAncestorElement(n, "ol") {
				block.marginTop = 0
				block.marginBottom = 0
			} else {
				block.marginBottom = go2.Max(block.marginBottom, MarginBottom_ul)
			}
		case "pre":
			block.width += 2 * Padding_pre
			block.height += 2 * Padding_pre
			block.marginBottom = go2.Max(block.marginBottom, MarginBottom_pre)
		case "code":
			if parentElementType != "pre" {
				block.width += 2 * PaddingLeftRight_code_em * float64(fontSize)
				block.height += 2 * PaddingTopBottom_code_em * float64(fontSize)
			}
		case "hr":
			block.height += Height_hr_em * float64(fontSize)
			block.marginTop = go2.Max(block.marginTop, MarginTopBottom_hr)
			block.marginBottom = go2.Max(block.marginBottom, MarginTopBottom_hr)
		// For table elements the size computed from the generic child pass
		// above is replaced; margins gathered there are kept.
		case "table":
			block.width, block.height = ruler.measureTable(depth, n, fontFamily, fontSize, fontStyle)
		case "thead", "tbody", "tfoot":
			var rowsColumnWidths [][]float64
			block.width, block.height, rowsColumnWidths = ruler.measureTableSection(depth, n, fontFamily, fontSize, fontStyle)
			block.extraData = rowsColumnWidths // Pass column widths back to table
		case "td", "th":
			block.width, block.height = ruler.measureTableCell(depth, n, fontFamily, fontSize, fontStyle)
		case "tr":
			var cellWidths []float64
			block.width, block.height, cellWidths = ruler.measureTableRow(depth, n, fontFamily, fontSize, fontStyle)
			block.extraData = cellWidths
		}

		if block.height > 0 && block.height < lineHeightPx {
			block.height = lineHeightPx
		}
		if debugMeasure {
			fmt.Printf("%s%s(%v,%v) mt:%v mb:%v\n", depthStr, n.Data, block.width, block.height, block.marginTop, block.marginBottom)
		}
		return block
	}
	return blockAttrs{}
}

// measureTable returns the size of the table element n: the sum of the widest
// cell in each column plus one border per column division, and the summed
// height of its sections and direct rows plus the top and bottom borders.
func (ruler *Ruler) measureTable(depth int, n *html.Node, fontFamily *d2fonts.FontFamily, fontSize int, fontStyle d2fonts.FontStyle) (width, height float64) {
	// Border width for table (outer border)
	const tableBorder = 1.0

	var columnWidths []float64
	for child := n.FirstChild; child != nil; child = child.NextSibling {
		if child.Type != html.ElementNode {
			continue
		}
		switch child.Data {
		case "tbody", "thead", "tfoot":
			sectionAttrs := ruler.measureNode(depth+1, child, fontFamily, fontSize, fontStyle)
			height += sectionAttrs.height
			if sectionColumnWidths, ok := sectionAttrs.extraData.([][]float64); ok {
				columnWidths = mergeColumnWidths(columnWidths, sectionColumnWidths)
			}
		case "tr":
			rowAttrs := ruler.measureNode(depth+1, child, fontFamily, fontSize, fontStyle)
			height += rowAttrs.height
			if rowCellWidths, ok := rowAttrs.extraData.([]float64); ok {
				columnWidths = mergeColumnWidths(columnWidths, [][]float64{rowCellWidths})
			}
		}
	}

	// Calculate total table width including ALL borders
	if len(columnWidths) > 0 {
		for _, colWidth := range columnWidths {
			width += colWidth
		}
		// Add border for every column division (including outer borders)
		width += float64(len(columnWidths)+1) * tableBorder
	}
	// Add outer borders to height
	height += 2 * tableBorder
	return width, height
}

// measureTableSection returns the size of the thead/tbody/tfoot element n
// (widest row, summed row heights) together with the per-row cell widths of
// its tr children.
func (ruler *Ruler) measureTableSection(depth int, n *html.Node, fontFamily *d2fonts.FontFamily, fontSize int, fontStyle d2fonts.FontStyle) (width, height float64, rowsColumnWidths [][]float64) {
	for child := n.FirstChild; child != nil; child = child.NextSibling {
		if child.Type != html.ElementNode || child.Data != "tr" {
			continue
		}
		rowAttrs := ruler.measureNode(depth+1, child, fontFamily, fontSize, fontStyle)
		height += rowAttrs.height
		width = go2.Max(width, rowAttrs.width)
		if rowCellWidths, ok := rowAttrs.extraData.([]float64); ok {
			rowsColumnWidths = append(rowsColumnWidths, rowCellWidths)
		}
	}
	return width, height, rowsColumnWidths
}

// measureTableCell returns the content size of the td/th element n: the
// widest child and the summed child heights. Children of a th are measured
// in semibold.
func (ruler *Ruler) measureTableCell(depth int, n *html.Node, fontFamily *d2fonts.FontFamily, fontSize int, fontStyle d2fonts.FontStyle) (width, height float64) {
	// Apply semibold style to header cells
	cellFontStyle := fontStyle
	if n.Data == "th" {
		cellFontStyle = d2fonts.FONT_STYLE_SEMIBOLD
	}
	for child := n.FirstChild; child != nil; child = child.NextSibling {
		childAttrs := ruler.measureNode(depth+1, child, fontFamily, fontSize, cellFontStyle)
		width = go2.Max(width, childAttrs.width)
		height += childAttrs.height
	}
	return width, height
}

// measureTableRow returns the size of the tr element n and the padded width
// of each of its td/th cells. The row width is the sum of the cell widths
// plus one border per cell division; the row height is the tallest padded
// cell plus one border.
func (ruler *Ruler) measureTableRow(depth int, n *html.Node, fontFamily *d2fonts.FontFamily, fontSize int, fontStyle d2fonts.FontStyle) (width, height float64, cellWidths []float64) {
	const (
		cellBorder   = 1.0
		rowBorder    = 1.0
		cellPaddingH = 13.0 * 2
		cellPaddingV = 6.0 * 2
	)

	// Check if this row is in a thead to determine default font style for cells
	rowFontStyle := fontStyle
	if hasAncestorElement(n, "thead") {
		rowFontStyle = d2fonts.FONT_STYLE_SEMIBOLD
	}

	maxCellHeight := 0.0
	for child := n.FirstChild; child != nil; child = child.NextSibling {
		if child.Type != html.ElementNode || (child.Data != "td" && child.Data != "th") {
			continue
		}
		// Use semibold for th elements regardless of location
		cellFontStyle := rowFontStyle
		if child.Data == "th" {
			cellFontStyle = d2fonts.FONT_STYLE_SEMIBOLD
		}
		cellAttrs := ruler.measureNode(depth+1, child, fontFamily, fontSize, cellFontStyle)

		cellWidths = append(cellWidths, cellAttrs.width+cellPaddingH)
		maxCellHeight = go2.Max(maxCellHeight, cellAttrs.height+cellPaddingV)
	}

	if len(cellWidths) > 0 {
		for _, w := range cellWidths {
			width += w
		}
		width += float64(len(cellWidths)+1) * cellBorder
	}
	height = maxCellHeight + rowBorder
	return width, height, cellWidths
}

// mergeColumnWidths folds the per-row cell widths in rows into existing,
// keeping the maximum width seen at each column index and growing the slice
// when a row has more cells. existing is updated in place where it is long
// enough; the merged slice is returned.
func mergeColumnWidths(existing []float64, rows [][]float64) []float64 {
	for _, rowWidths := range rows {
		for i, width := range rowWidths {
			if i >= len(existing) {
				existing = append(existing, width)
			} else {
				existing[i] = go2.Max(existing[i], width)
			}
		}
	}
	return existing
}