This commit is contained in:
Alexander Wang 2022-11-29 13:19:03 -08:00
parent d7fb9c16b8
commit 10e13c3716
No known key found for this signature in database
GPG key ID: D89FA31966BDBECE
3 changed files with 56 additions and 10 deletions

View file

@ -205,7 +205,7 @@ func compile(ctx context.Context, ms *xmain.State, isWatching bool, plugin d2plu
if err != nil { if err != nil {
return nil, err return nil, err
} }
svg, err = imgbundler.InlineLocal(svg) svg, err = imgbundler.InlineLocal(ms, svg)
if err != nil { if err != nil {
// Missing/broken images are fine during watch mode, as the user is likely building up a diagram. // Missing/broken images are fine during watch mode, as the user is likely building up a diagram.
// Otherwise, the assumption is that this diagram is building for production, and broken images are not okay. // Otherwise, the assumption is that this diagram is building for production, and broken images are not okay.
@ -217,7 +217,7 @@ func compile(ctx context.Context, ms *xmain.State, isWatching bool, plugin d2plu
out := svg out := svg
if filepath.Ext(outputPath) == ".png" { if filepath.Ext(outputPath) == ".png" {
svg, err = imgbundler.InlineRemote(svg) svg, err = imgbundler.InlineRemote(ms, svg)
if err != nil { if err != nil {
if !isWatching { if !isWatching {
return nil, err return nil, err

View file

@ -16,6 +16,8 @@ import (
"go.uber.org/multierr" "go.uber.org/multierr"
"oss.terrastruct.com/xdefer" "oss.terrastruct.com/xdefer"
"oss.terrastruct.com/d2/lib/xmain"
) )
var imageRe = regexp.MustCompile(`<image href="([^"]+)"`) var imageRe = regexp.MustCompile(`<image href="([^"]+)"`)
@ -26,15 +28,15 @@ type resp struct {
err error err error
} }
func InlineLocal(in []byte) ([]byte, error) { func InlineLocal(ms *xmain.State, in []byte) ([]byte, error) {
return inline(in, false) return inline(ms, in, false)
} }
func InlineRemote(in []byte) ([]byte, error) { func InlineRemote(ms *xmain.State, in []byte) ([]byte, error) {
return inline(in, true) return inline(ms, in, true)
} }
func inline(svg []byte, isRemote bool) (_ []byte, err error) { func inline(ms *xmain.State, svg []byte, isRemote bool) (_ []byte, err error) {
defer xdefer.Errorf(&err, "failed to bundle images") defer xdefer.Errorf(&err, "failed to bundle images")
imgs := imageRe.FindAllSubmatch(svg, -1) imgs := imageRe.FindAllSubmatch(svg, -1)
@ -93,10 +95,13 @@ func inline(svg []byte, isRemote bool) (_ []byte, err error) {
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
ms.Log.Debug.Printf("there")
return nil, fmt.Errorf("failed waiting for imgbundler workers: %w", ctx.Err()) return nil, fmt.Errorf("failed waiting for imgbundler workers: %w", ctx.Err())
case <-time.After(time.Second * 5):
ms.Log.Info.Printf("fetching images...")
case resp, ok := <-respChan: case resp, ok := <-respChan:
if !ok { if !ok {
return svg, nil return svg, err
} }
if resp.err != nil { if resp.err != nil {
err = multierr.Combine(err, resp.err) err = multierr.Combine(err, resp.err)
@ -124,6 +129,9 @@ func fetch(ctx context.Context, href string) (string, error) {
return "", err return "", err
} }
defer imgResp.Body.Close() defer imgResp.Body.Close()
if imgResp.StatusCode != 200 {
return "", fmt.Errorf("img %s returned status code %d", href, imgResp.StatusCode)
}
data, err := ioutil.ReadAll(imgResp.Body) data, err := ioutil.ReadAll(imgResp.Body)
if err != nil { if err != nil {
return "", err return "", err

View file

@ -5,9 +5,15 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os"
"path/filepath" "path/filepath"
"strings" "strings"
"testing" "testing"
"oss.terrastruct.com/cmdlog"
"oss.terrastruct.com/xos"
"oss.terrastruct.com/d2/lib/xmain"
) )
//go:embed test_png.png //go:embed test_png.png
@ -70,6 +76,17 @@ width="328" height="587" viewBox="-100 -131 328 587"><style type="text/css">
}]]></style></svg> }]]></style></svg>
`, svgURL, pngURL) `, svgURL, pngURL)
ms := &xmain.State{
Name: "test",
Stdin: os.Stdin,
Stdout: os.Stdout,
Stderr: os.Stderr,
Env: xos.NewEnv(os.Environ()),
}
ms.Log = cmdlog.Log(ms.Env, os.Stderr)
transport = roundTripFunc(func(req *http.Request) *http.Response { transport = roundTripFunc(func(req *http.Request) *http.Response {
respRecorder := httptest.NewRecorder() respRecorder := httptest.NewRecorder()
switch req.URL.String() { switch req.URL.String() {
@ -84,7 +101,7 @@ width="328" height="587" viewBox="-100 -131 328 587"><style type="text/css">
return respRecorder.Result() return respRecorder.Result()
}) })
out, err := InlineRemote([]byte(sampleSVG)) out, err := InlineRemote(ms, []byte(sampleSVG))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -97,6 +114,17 @@ width="328" height="587" viewBox="-100 -131 328 587"><style type="text/css">
if !strings.Contains(string(out), "image/png") { if !strings.Contains(string(out), "image/png") {
t.Fatal("no png image inserted") t.Fatal("no png image inserted")
} }
// Test error response
transport = roundTripFunc(func(req *http.Request) *http.Response {
respRecorder := httptest.NewRecorder()
respRecorder.WriteHeader(500)
return respRecorder.Result()
})
_, err = InlineRemote(ms, []byte(sampleSVG))
if err == nil {
t.Fatal("expected error")
}
} }
func TestInlineLocal(t *testing.T) { func TestInlineLocal(t *testing.T) {
@ -135,7 +163,17 @@ width="328" height="587" viewBox="-100 -131 328 587"><style type="text/css">
}]]></style></svg> }]]></style></svg>
`, svgURL, pngURL) `, svgURL, pngURL)
out, err := InlineLocal([]byte(sampleSVG)) ms := &xmain.State{
Name: "test",
Stdin: os.Stdin,
Stdout: os.Stdout,
Stderr: os.Stderr,
Env: xos.NewEnv(os.Environ()),
}
ms.Log = cmdlog.Log(ms.Env, os.Stderr)
out, err := InlineLocal(ms, []byte(sampleSVG))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }