cli: watch imported files

This commit is contained in:
Alexander Wang 2023-11-08 18:34:01 -08:00
parent 6324d67516
commit f328eb3b9f
No known key found for this signature in database
GPG key ID: D89FA31966BDBECE
3 changed files with 195 additions and 43 deletions

View file

@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/fs"
"os" "os"
"os/exec" "os/exec"
"os/user" "os/user"
@ -332,7 +333,7 @@ func Run(ctx context.Context, ms *xmain.State) (err error) {
ctx, cancel := timelib.WithTimeout(ctx, time.Minute*2) ctx, cancel := timelib.WithTimeout(ctx, time.Minute*2)
defer cancel() defer cancel()
_, written, err := compile(ctx, ms, plugins, layoutFlag, renderOpts, fontFamily, *animateIntervalFlag, inputPath, outputPath, "", *bundleFlag, *forceAppendixFlag, pw.Page) _, written, err := compile(ctx, ms, plugins, nil, layoutFlag, renderOpts, fontFamily, *animateIntervalFlag, inputPath, outputPath, "", *bundleFlag, *forceAppendixFlag, pw.Page)
if err != nil { if err != nil {
if written { if written {
return fmt.Errorf("failed to fully compile (partial render written) %s: %w", ms.HumanPath(inputPath), err) return fmt.Errorf("failed to fully compile (partial render written) %s: %w", ms.HumanPath(inputPath), err)
@ -367,7 +368,7 @@ func LayoutResolver(ctx context.Context, ms *xmain.State, plugins []d2plugin.Plu
} }
} }
func compile(ctx context.Context, ms *xmain.State, plugins []d2plugin.Plugin, layout *string, renderOpts d2svg.RenderOpts, fontFamily *d2fonts.FontFamily, animateInterval int64, inputPath, outputPath, boardPath string, bundle, forceAppendix bool, page playwright.Page) (_ []byte, written bool, _ error) { func compile(ctx context.Context, ms *xmain.State, plugins []d2plugin.Plugin, fs fs.FS, layout *string, renderOpts d2svg.RenderOpts, fontFamily *d2fonts.FontFamily, animateInterval int64, inputPath, outputPath, boardPath string, bundle, forceAppendix bool, page playwright.Page) (_ []byte, written bool, _ error) {
start := time.Now() start := time.Now()
input, err := ms.ReadPath(inputPath) input, err := ms.ReadPath(inputPath)
if err != nil { if err != nil {
@ -385,6 +386,7 @@ func compile(ctx context.Context, ms *xmain.State, plugins []d2plugin.Plugin, la
InputPath: inputPath, InputPath: inputPath,
LayoutResolver: LayoutResolver(ctx, ms, plugins), LayoutResolver: LayoutResolver(ctx, ms, plugins),
Layout: layout, Layout: layout,
FS: fs,
} }
cancel := background.Repeat(func() { cancel := background.Repeat(func() {

View file

@ -12,6 +12,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"sort"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -218,10 +219,13 @@ func (w *watcher) goFunc(fn func(context.Context) error) {
* TODO: Abstract out file system and fsnotify to test this with 100% coverage. See comment in main_test.go * TODO: Abstract out file system and fsnotify to test this with 100% coverage. See comment in main_test.go
*/ */
func (w *watcher) watchLoop(ctx context.Context) error { func (w *watcher) watchLoop(ctx context.Context) error {
lastModified, err := w.ensureAddWatch(ctx) lastModified := make(map[string]time.Time)
mt, err := w.ensureAddWatch(ctx, w.inputPath)
if err != nil { if err != nil {
return err return err
} }
lastModified[w.inputPath] = mt
w.ms.Log.Info.Printf("compiling %v...", w.ms.HumanPath(w.inputPath)) w.ms.Log.Info.Printf("compiling %v...", w.ms.HumanPath(w.inputPath))
w.requestCompile() w.requestCompile()
@ -230,6 +234,8 @@ func (w *watcher) watchLoop(ctx context.Context) error {
pollTicker := time.NewTicker(time.Second * 10) pollTicker := time.NewTicker(time.Second * 10)
defer pollTicker.Stop() defer pollTicker.Stop()
changed := make(map[string]struct{})
for { for {
select { select {
case <-pollTicker.C: case <-pollTicker.C:
@ -237,13 +243,18 @@ func (w *watcher) watchLoop(ctx context.Context) error {
// getting any more events. // getting any more events.
// File notification APIs are notoriously unreliable. I've personally experienced // File notification APIs are notoriously unreliable. I've personally experienced
// many quirks and so feel this check is justified even if excessive. // many quirks and so feel this check is justified even if excessive.
mt, err := w.ensureAddWatch(ctx) missedChanges := false
for _, watched := range w.fw.WatchList() {
mt, err := w.ensureAddWatch(ctx, watched)
if err != nil { if err != nil {
return err return err
} }
if !mt.Equal(lastModified) { if mt2, ok := lastModified[watched]; !ok || !mt.Equal(mt2) {
// We missed changes. missedChanges = true
lastModified = mt lastModified[watched] = mt
}
}
if missedChanges {
w.requestCompile() w.requestCompile()
} }
case ev, ok := <-w.fw.Events: case ev, ok := <-w.fw.Events:
@ -251,19 +262,20 @@ func (w *watcher) watchLoop(ctx context.Context) error {
return errors.New("fsnotify watcher closed") return errors.New("fsnotify watcher closed")
} }
w.ms.Log.Debug.Printf("received file system event %v", ev) w.ms.Log.Debug.Printf("received file system event %v", ev)
mt, err := w.ensureAddWatch(ctx) mt, err := w.ensureAddWatch(ctx, ev.Name)
if err != nil { if err != nil {
return err return err
} }
if ev.Op == fsnotify.Chmod { if ev.Op == fsnotify.Chmod {
if mt.Equal(lastModified) { if mt.Equal(lastModified[ev.Name]) {
// Benign Chmod. // Benign Chmod.
// See https://github.com/fsnotify/fsnotify/issues/15 // See https://github.com/fsnotify/fsnotify/issues/15
continue continue
} }
// We missed changes. // We missed changes.
lastModified = mt lastModified[ev.Name] = mt
} }
changed[ev.Name] = struct{}{}
// The purpose of eatBurstTimer is to wait at least 16 milliseconds after a sequence of // The purpose of eatBurstTimer is to wait at least 16 milliseconds after a sequence of
// events to ensure that whomever is editing the file is now done. // events to ensure that whomever is editing the file is now done.
// //
@ -276,8 +288,18 @@ func (w *watcher) watchLoop(ctx context.Context) error {
// misleading error. // misleading error.
eatBurstTimer.Reset(time.Millisecond * 16) eatBurstTimer.Reset(time.Millisecond * 16)
case <-eatBurstTimer.C: case <-eatBurstTimer.C:
w.ms.Log.Info.Printf("detected change in %v: recompiling...", w.ms.HumanPath(w.inputPath)) var changedList []string
for k := range changed {
changedList = append(changedList, k)
}
sort.Strings(changedList)
changedStr := w.ms.HumanPath(changedList[0])
for i := 1; i < len(changed); i++ {
changedStr += fmt.Sprintf(", %s", w.ms.HumanPath(changedList[i]))
}
w.ms.Log.Info.Printf("detected change in %s: recompiling...", changedStr)
w.requestCompile() w.requestCompile()
changed = make(map[string]struct{})
case err, ok := <-w.fw.Errors: case err, ok := <-w.fw.Errors:
if !ok { if !ok {
return errors.New("fsnotify watcher closed") return errors.New("fsnotify watcher closed")
@ -296,17 +318,17 @@ func (w *watcher) requestCompile() {
} }
} }
func (w *watcher) ensureAddWatch(ctx context.Context) (time.Time, error) { func (w *watcher) ensureAddWatch(ctx context.Context, path string) (time.Time, error) {
interval := time.Millisecond * 16 interval := time.Millisecond * 16
tc := time.NewTimer(0) tc := time.NewTimer(0)
<-tc.C <-tc.C
for { for {
mt, err := w.addWatch(ctx) mt, err := w.addWatch(ctx, path)
if err == nil { if err == nil {
return mt, nil return mt, nil
} }
if interval >= time.Second { if interval >= time.Second {
w.ms.Log.Error.Printf("failed to watch inputPath %q: %v (retrying in %v)", w.ms.HumanPath(w.inputPath), err, interval) w.ms.Log.Error.Printf("failed to watch %q: %v (retrying in %v)", w.ms.HumanPath(path), err, interval)
} }
tc.Reset(interval) tc.Reset(interval)
@ -324,19 +346,56 @@ func (w *watcher) ensureAddWatch(ctx context.Context) (time.Time, error) {
} }
} }
func (w *watcher) addWatch(ctx context.Context) (time.Time, error) { func (w *watcher) addWatch(ctx context.Context, path string) (time.Time, error) {
err := w.fw.Add(w.inputPath) err := w.fw.Add(path)
if err != nil { if err != nil {
return time.Time{}, err return time.Time{}, err
} }
var d os.FileInfo var d os.FileInfo
d, err = os.Stat(w.inputPath) d, err = os.Stat(path)
if err != nil { if err != nil {
return time.Time{}, err return time.Time{}, err
} }
return d.ModTime(), nil return d.ModTime(), nil
} }
func (w *watcher) replaceWatchList(ctx context.Context, paths []string) error {
// First remove the files no longer being watched
for _, watched := range w.fw.WatchList() {
if watched == w.inputPath {
continue
}
found := false
for _, p := range paths {
if watched == p {
found = true
break
}
}
if !found {
// Don't mind errors here
w.fw.Remove(watched)
}
}
// Then add the files newly being watched
for _, p := range paths {
found := false
for _, watched := range w.fw.WatchList() {
if watched == p {
found = true
break
}
}
if !found {
_, err := w.ensureAddWatch(ctx, p)
if err != nil {
return err
}
}
}
return nil
}
func (w *watcher) compileLoop(ctx context.Context) error { func (w *watcher) compileLoop(ctx context.Context) error {
firstCompile := true firstCompile := true
for { for {
@ -364,7 +423,8 @@ func (w *watcher) compileLoop(ctx context.Context) error {
w.pw = newPW w.pw = newPW
} }
svg, _, err := compile(ctx, w.ms, w.plugins, w.layout, w.renderOpts, w.fontFamily, w.animateInterval, w.inputPath, w.outputPath, w.boardPath, w.bundle, w.forceAppendix, w.pw.Page) fs := trackedFS{}
svg, _, err := compile(ctx, w.ms, w.plugins, &fs, w.layout, w.renderOpts, w.fontFamily, w.animateInterval, w.inputPath, w.outputPath, w.boardPath, w.bundle, w.forceAppendix, w.pw.Page)
errs := "" errs := ""
if err != nil { if err != nil {
if len(svg) > 0 { if len(svg) > 0 {
@ -375,6 +435,11 @@ func (w *watcher) compileLoop(ctx context.Context) error {
errs = err.Error() errs = err.Error()
w.ms.Log.Error.Print(errs) w.ms.Log.Error.Print(errs)
} }
err = w.replaceWatchList(ctx, fs.opened)
if err != nil {
return err
}
w.broadcast(&compileResult{ w.broadcast(&compileResult{
SVG: string(svg), SVG: string(svg),
Scale: w.renderOpts.Scale, Scale: w.renderOpts.Scale,
@ -574,3 +639,13 @@ func wsHeartbeat(ctx context.Context, c *websocket.Conn) {
} }
} }
} }
// trackedFS is OS's FS with the addition that it tracks which files are opened
type trackedFS struct {
opened []string
}
func (tfs *trackedFS) Open(name string) (fs.File, error) {
tfs.opened = append(tfs.opened, name)
return os.Open(name)
}

View file

@ -575,11 +575,8 @@ layers: {
// Wait for watch server to spin up and listen // Wait for watch server to spin up and listen
urlRE := regexp.MustCompile(`127.0.0.1:([0-9]+)`) urlRE := regexp.MustCompile(`127.0.0.1:([0-9]+)`)
watchURL := waitLogs(ctx, stderr, urlRE) watchURL, err := waitLogs(ctx, stderr, urlRE)
assert.Success(t, err)
if watchURL == "" {
t.Error(errors.New(stderr.String()))
}
stderr.Reset() stderr.Reset()
// Start a client // Start a client
@ -599,8 +596,8 @@ layers: {
assert.Success(t, err) assert.Success(t, err)
successRE := regexp.MustCompile(`broadcasting update to 1 client`) successRE := regexp.MustCompile(`broadcasting update to 1 client`)
line := waitLogs(ctx, stderr, successRE) _, err = waitLogs(ctx, stderr, successRE)
assert.NotEqual(t, "", line) assert.Success(t, err)
}, },
}, },
{ {
@ -631,11 +628,9 @@ layers: {
// Wait for watch server to spin up and listen // Wait for watch server to spin up and listen
urlRE := regexp.MustCompile(`127.0.0.1:([0-9]+)`) urlRE := regexp.MustCompile(`127.0.0.1:([0-9]+)`)
watchURL := waitLogs(ctx, stderr, urlRE) watchURL, err := waitLogs(ctx, stderr, urlRE)
assert.Success(t, err)
if watchURL == "" {
t.Error(errors.New(stderr.String()))
}
stderr.Reset() stderr.Reset()
// Start a client // Start a client
@ -655,8 +650,8 @@ layers: {
assert.Success(t, err) assert.Success(t, err)
successRE := regexp.MustCompile(`broadcasting update to 1 client`) successRE := regexp.MustCompile(`broadcasting update to 1 client`)
line := waitLogs(ctx, stderr, successRE) _, err = waitLogs(ctx, stderr, successRE)
assert.NotEqual(t, "", line) assert.Success(t, err)
}, },
}, },
{ {
@ -685,11 +680,8 @@ layers: {
// Wait for watch server to spin up and listen // Wait for watch server to spin up and listen
urlRE := regexp.MustCompile(`127.0.0.1:([0-9]+)`) urlRE := regexp.MustCompile(`127.0.0.1:([0-9]+)`)
watchURL := waitLogs(ctx, stderr, urlRE) watchURL, err := waitLogs(ctx, stderr, urlRE)
assert.Success(t, err)
if watchURL == "" {
t.Error(errors.New(stderr.String()))
}
stderr.Reset() stderr.Reset()
// Start a client // Start a client
@ -709,8 +701,82 @@ layers: {
assert.Success(t, err) assert.Success(t, err)
successRE := regexp.MustCompile(`broadcasting update to 1 client`) successRE := regexp.MustCompile(`broadcasting update to 1 client`)
line := waitLogs(ctx, stderr, successRE) _, err = waitLogs(ctx, stderr, successRE)
assert.NotEqual(t, "", line) assert.Success(t, err)
},
},
{
name: "watch-imported-file",
run: func(t *testing.T, ctx context.Context, dir string, env *xos.Env) {
writeFile(t, dir, "a.d2", `
...@b
`)
writeFile(t, dir, "b.d2", `
x
`)
stderr := &bytes.Buffer{}
tms := testMain(dir, env, "--watch", "--browser=0", "a.d2")
tms.Stderr = stderr
tms.Start(t, ctx)
defer func() {
err := tms.Signal(ctx, os.Interrupt)
assert.Success(t, err)
}()
// Wait for first compilation to finish
doneRE := regexp.MustCompile(`successfully compiled a.d2`)
_, err := waitLogs(ctx, stderr, doneRE)
assert.Success(t, err)
stderr.Reset()
// Test that writing an imported file will cause recompilation
writeFile(t, dir, "b.d2", `
x -> y
`)
bRE := regexp.MustCompile(`detected change in b.d2`)
_, err = waitLogs(ctx, stderr, bRE)
assert.Success(t, err)
stderr.Reset()
// Test burst of both files changing
writeFile(t, dir, "a.d2", `
...@b
hey
`)
writeFile(t, dir, "b.d2", `
x
hi
`)
bothRE := regexp.MustCompile(`detected change in a.d2, b.d2`)
_, err = waitLogs(ctx, stderr, bothRE)
assert.Success(t, err)
// Wait for that compilation to fully finish
_, err = waitLogs(ctx, stderr, doneRE)
assert.Success(t, err)
stderr.Reset()
// Update the main file to no longer have that dependency
writeFile(t, dir, "a.d2", `
a
`)
_, err = waitLogs(ctx, stderr, doneRE)
assert.Success(t, err)
stderr.Reset()
// Change b
writeFile(t, dir, "b.d2", `
y
`)
// Change a to retrigger compilation
// The test works by seeing that the report only says "a" changed, otherwise testing for omission of compilation from "b" would require waiting
writeFile(t, dir, "a.d2", `
c
`)
_, err = waitLogs(ctx, stderr, doneRE)
assert.Success(t, err)
}, },
}, },
} }
@ -810,7 +876,9 @@ func getNumBoards(svg string) int {
return strings.Count(svg, `class="d2`) return strings.Count(svg, `class="d2`)
} }
func waitLogs(ctx context.Context, buf *bytes.Buffer, pattern *regexp.Regexp) string { var errRE = regexp.MustCompile(`err:`)
func waitLogs(ctx context.Context, buf *bytes.Buffer, pattern *regexp.Regexp) (string, error) {
ticker := time.NewTicker(10 * time.Millisecond) ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop() defer ticker.Stop()
var match string var match string
@ -819,13 +887,20 @@ func waitLogs(ctx context.Context, buf *bytes.Buffer, pattern *regexp.Regexp) st
case <-ticker.C: case <-ticker.C:
out := buf.String() out := buf.String()
match = pattern.FindString(out) match = pattern.FindString(out)
errMatch := errRE.FindString(out)
if errMatch != "" {
return "", errors.New(buf.String())
}
case <-ctx.Done(): case <-ctx.Done():
ticker.Stop() ticker.Stop()
return "" return "", fmt.Errorf("could not match pattern in log. logs: %s", buf.String())
} }
} }
if match == "" {
return "", errors.New(buf.String())
}
return match return match, nil
} }
func getWatchPage(ctx context.Context, t *testing.T, page string) error { func getWatchPage(ctx context.Context, t *testing.T, page string) error {