diff --git a/lib/imgbundler/imgbundler.go b/lib/imgbundler/imgbundler.go
index 207809c2e..72a109697 100644
--- a/lib/imgbundler/imgbundler.go
+++ b/lib/imgbundler/imgbundler.go
@@ -6,9 +6,11 @@ import (
 	"encoding/base64"
 	"fmt"
 	"io/ioutil"
+	"mime"
 	"net/http"
 	"net/url"
 	"os"
+	"path"
 	"regexp"
 	"strings"
 	"sync"
@@ -38,7 +40,11 @@ type repl struct {
 }
 
 func bundle(ctx context.Context, ms *xmain.State, svg []byte, isRemote bool) (_ []byte, err error) {
-	defer xdefer.Errorf(&err, "failed to bundle images")
+	if isRemote {
+		defer xdefer.Errorf(&err, "failed to bundle remote images")
+	} else {
+		defer xdefer.Errorf(&err, "failed to bundle local images")
+	}
 	imgs := imageRegex.FindAllSubmatch(svg, -1)
 	imgs = filterImageElements(imgs, isRemote)
 
@@ -65,9 +71,10 @@ func bundle(ctx context.Context, ms *xmain.State, svg []byte, isRemote bool) (_
 			}()
 
 			var buf []byte
+			var mimeType string
 			var err error
 			if isRemote {
-				buf, err = httpGet(ctx, string(href))
+				buf, mimeType, err = httpGet(ctx, string(href))
 			} else {
 				buf, err = os.ReadFile(string(href))
 			}
@@ -79,7 +86,9 @@ func bundle(ctx context.Context, ms *xmain.State, svg []byte, isRemote bool) (_
 				return
 			}
 
-			mimeType := http.DetectContentType(buf)
+			if mimeType == "" {
+				mimeType = sniffMimeType(href, buf, isRemote)
+			}
 			mimeType = strings.Replace(mimeType, "text/xml", "image/svg+xml", 1)
 			b64 := base64.StdEncoding.EncodeToString(buf)
 
@@ -108,7 +117,7 @@ func bundle(ctx context.Context, ms *xmain.State, svg []byte, isRemote bool) (_
 		case repl, ok := <-replc:
 			if !ok {
 				if len(errhrefs) > 0 {
-					return svg, xerrors.Errorf("failed to bundle the following images: %v", errhrefs)
+					return svg, xerrors.Errorf("%v", errhrefs)
 				}
 				return svg, nil
 			}
@@ -141,23 +150,45 @@ func filterImageElements(imgs [][][]byte, isRemote bool) [][][]byte {
 
 var httpClient = &http.Client{}
 
-func httpGet(ctx context.Context, href string) ([]byte, error) {
+func httpGet(ctx context.Context, href string) ([]byte, string, error) {
 	ctx, cancel := context.WithTimeout(ctx, time.Minute)
 	defer cancel()
 
 	req, err := http.NewRequestWithContext(ctx, "GET", href, nil)
 	if err != nil {
-		return nil, err
+		return nil, "", err
 	}
 
 	resp, err := httpClient.Do(req)
 	if err != nil {
-		return nil, err
+		return nil, "", err
 	}
 	defer resp.Body.Close()
 	if resp.StatusCode != 200 {
-		return nil, fmt.Errorf("expected status 200 but got %d %s", resp.StatusCode, resp.Status)
+		return nil, "", fmt.Errorf("expected status 200 but got %d %s", resp.StatusCode, resp.Status)
 	}
 	r := http.MaxBytesReader(nil, resp.Body, maxImageSize)
-	return ioutil.ReadAll(r)
+	buf, err := ioutil.ReadAll(r)
+	if err != nil {
+		return nil, "", err
+	}
+	return buf, resp.Header.Get("Content-Type"), nil
+}
+
+// sniffMimeType sniffs the mime type of href based on its file extension and contents.
+func sniffMimeType(href, buf []byte, isRemote bool) string {
+	p := string(href)
+	if isRemote {
+		u, err := url.Parse(p)
+		if err != nil {
+			p = ""
+		} else {
+			p = u.Path
+		}
+	}
+	mimeType := mime.TypeByExtension(path.Ext(p))
+	if mimeType == "" {
+		mimeType = http.DetectContentType(buf)
+	}
+	return mimeType
 }