diff --git a/cmd/d2/help.go b/cmd/d2/help.go index 961c9c275..3e18a56cf 100644 --- a/cmd/d2/help.go +++ b/cmd/d2/help.go @@ -28,13 +28,13 @@ Subcommands: %[1]s layout [layout name] - Display long help for a particular layout engine See more docs and the source code at https://oss.terrastruct.com/d2 -`, ms.Name, ms.FlagHelp()) +`, ms.Name, ms.Opts.Help()) } func layoutHelp(ctx context.Context, ms *xmain.State) error { - if len(ms.FlagSet.Args()) == 1 { + if len(ms.Opts.Args()) == 1 { return shortLayoutHelp(ctx, ms) - } else if len(ms.FlagSet.Args()) == 2 { + } else if len(ms.Opts.Args()) == 2 { return longLayoutHelp(ctx, ms) } else { return pluginSubcommand(ctx, ms) @@ -61,7 +61,7 @@ func shortLayoutHelp(ctx context.Context, ms *xmain.State) error { %s Usage: - To use a particular layout engine, set the environment variable D2_LAYOUT=[layout name]. + To use a particular layout engine, set the environment variable D2_LAYOUT=[name] or flag --layout=[name]. Example: D2_LAYOUT=dagre d2 in.d2 out.svg @@ -75,7 +75,7 @@ See more docs at https://oss.terrastruct.com/d2 } func longLayoutHelp(ctx context.Context, ms *xmain.State) error { - layout := ms.FlagSet.Arg(1) + layout := ms.Opts.Arg(1) plugin, path, err := d2plugin.FindPlugin(ctx, layout) if errors.Is(err, exec.ErrNotFound) { return layoutNotFound(ctx, layout) @@ -119,13 +119,13 @@ For more information on setup, please visit https://github.com/terrastruct/d2.`, } func pluginSubcommand(ctx context.Context, ms *xmain.State) error { - layout := ms.FlagSet.Arg(1) + layout := ms.Opts.Arg(1) plugin, _, err := d2plugin.FindPlugin(ctx, layout) if errors.Is(err, exec.ErrNotFound) { return layoutNotFound(ctx, layout) } - ms.Args = ms.FlagSet.Args()[2:] + ms.Opts.SetArgs(ms.Opts.Args()[2:]) return d2plugin.Serve(plugin)(ctx, ms) } diff --git a/cmd/d2/main.go b/cmd/d2/main.go index 6ce4d5459..d2f1ecdfd 100644 --- a/cmd/d2/main.go +++ b/cmd/d2/main.go @@ -6,7 +6,6 @@ import ( "fmt" "os/exec" "path/filepath" - "strconv" "strings" "time" @@ -32,19 +31,20 @@ func run(ctx context.Context, ms *xmain.State) (err error) { // :( ctx = xmain.DiscardSlog(ctx) - watchFlag := ms.FlagSet.BoolP("watch", "w", false, "watch for changes to input and live reload. Use $PORT and $HOST to specify the listening address.\n$D2_PORT and $D2_HOST are also accepted and take priority. Default is localhost:0") - themeFlag := ms.FlagSet.Int64P("theme", "t", 0, "set the diagram theme. For a list of available options, see https://oss.terrastruct.com/d2") - bundleFlag := ms.FlagSet.BoolP("bundle", "b", true, "bundle all assets and layers into the output svg") - versionFlag := ms.FlagSet.BoolP("version", "v", false, "get the version and check for updates") - debugFlag := ms.FlagSet.BoolP("debug", "d", false, "print debug logs") - err = ms.FlagSet.Parse(ms.Args) + watchFlag := ms.Opts.Bool("D2_WATCH", "watch", "w", false, "watch for changes to input and live reload. Use $HOST and $PORT to specify the listening address.\n$D2_HOST and $D2_PORT are also accepted and take priority (default localhost:0, which is will open on a randomly available local port).") + bundleFlag := ms.Opts.Bool("D2_BUNDLE", "bundle", "b", true, "bundle all assets and layers into the output svg.") + debugFlag := ms.Opts.Bool("DEBUG", "debug", "d", false, "print debug logs.") + layoutFlag := ms.Opts.String("D2_LAYOUT", "layout", "l", "dagre", `the layout engine used.`) + themeFlag := ms.Opts.Int64("D2_THEME", "theme", "t", 0, "the diagram theme ID. For a list of available options, see https://oss.terrastruct.com/d2") + versionFlag := ms.Opts.Bool("", "version", "v", false, "get the version and check for updates") + err = ms.Opts.Parse() if !errors.Is(err, pflag.ErrHelp) && err != nil { return xmain.UsageErrorf("failed to parse flags: %v", err) } - if len(ms.FlagSet.Args()) > 0 { - switch ms.FlagSet.Arg(0) { + if len(ms.Opts.Args()) > 0 { + switch ms.Opts.Arg(0) { case "layout": return layoutHelp(ctx, ms) } @@ -62,25 +62,25 @@ func run(ctx context.Context, ms *xmain.State) (err error) { var inputPath string var outputPath string - if len(ms.FlagSet.Args()) == 0 { + if len(ms.Opts.Args()) == 0 { if versionFlag != nil && *versionFlag { version.CheckVersion(ctx, ms.Log) return nil } help(ms) return nil - } else if len(ms.FlagSet.Args()) >= 3 { + } else if len(ms.Opts.Args()) >= 3 { return xmain.UsageErrorf("too many arguments passed") } - if len(ms.FlagSet.Args()) >= 1 { - if ms.FlagSet.Arg(0) == "version" { + if len(ms.Opts.Args()) >= 1 { + if ms.Opts.Arg(0) == "version" { version.CheckVersion(ctx, ms.Log) return nil } - inputPath = ms.FlagSet.Arg(0) + inputPath = ms.Opts.Arg(0) } - if len(ms.FlagSet.Args()) >= 2 { - outputPath = ms.FlagSet.Arg(1) + if len(ms.Opts.Args()) >= 2 { + outputPath = ms.Opts.Arg(1) } else { if inputPath == "-" { outputPath = "-" @@ -93,16 +93,11 @@ func run(ctx context.Context, ms *xmain.State) (err error) { if match == (d2themes.Theme{}) { return xmain.UsageErrorf("-t[heme] could not be found. The available options are:\n%s\nYou provided: %d", d2themescatalog.CLIString(), *themeFlag) } - ms.Env.Setenv("D2_THEME", fmt.Sprintf("%d", *themeFlag)) + ms.Log.Debug.Printf("using theme %s (ID: %d)", match.Name, *themeFlag) - envD2Layout := ms.Env.Getenv("D2_LAYOUT") - if envD2Layout == "" { - envD2Layout = "dagre" - } - - plugin, path, err := d2plugin.FindPlugin(ctx, envD2Layout) + plugin, path, err := d2plugin.FindPlugin(ctx, *layoutFlag) if errors.Is(err, exec.ErrNotFound) { - return layoutNotFound(ctx, envD2Layout) + return layoutNotFound(ctx, *layoutFlag) } else if err != nil { return err } @@ -111,14 +106,14 @@ func run(ctx context.Context, ms *xmain.State) (err error) { if path != "" { pluginLocation = fmt.Sprintf("executable plugin at %s", humanPath(path)) } - ms.Log.Debug.Printf("using layout plugin %s (%s)", envD2Layout, pluginLocation) + ms.Log.Debug.Printf("using layout plugin %s (%s)", *layoutFlag, pluginLocation) if *watchFlag { if inputPath == "-" { return xmain.UsageErrorf("-w[atch] cannot be combined with reading input from stdin") } ms.Env.Setenv("LOG_TIMESTAMPS", "1") - w, err := newWatcher(ctx, ms, plugin, inputPath, outputPath) + w, err := newWatcher(ctx, ms, plugin, *themeFlag, inputPath, outputPath) if err != nil { return err } @@ -132,7 +127,7 @@ func run(ctx context.Context, ms *xmain.State) (err error) { _ = 343 } - _, err = compile(ctx, ms, plugin, inputPath, outputPath) + _, err = compile(ctx, ms, plugin, *themeFlag, inputPath, outputPath) if err != nil { return err } @@ -140,7 +135,7 @@ func run(ctx context.Context, ms *xmain.State) (err error) { return nil } -func compile(ctx context.Context, ms *xmain.State, plugin d2plugin.Plugin, inputPath, outputPath string) ([]byte, error) { +func compile(ctx context.Context, ms *xmain.State, plugin d2plugin.Plugin, themeID int64, inputPath, outputPath string) ([]byte, error) { input, err := ms.ReadPath(inputPath) if err != nil { return nil, err @@ -151,7 +146,6 @@ func compile(ctx context.Context, ms *xmain.State, plugin d2plugin.Plugin, input return nil, err } - themeID, _ := strconv.ParseInt(ms.Env.Getenv("D2_THEME"), 10, 64) d, err := d2.Compile(ctx, string(input), &d2.CompileOptions{ Layout: plugin.Layout, Ruler: ruler, diff --git a/cmd/d2/watch.go b/cmd/d2/watch.go index f9570da66..18dd9ea3c 100644 --- a/cmd/d2/watch.go +++ b/cmd/d2/watch.go @@ -42,6 +42,7 @@ type watcher struct { ms *xmain.State layoutPlugin d2plugin.Plugin + themeID int64 inputPath string outputPath string @@ -68,7 +69,7 @@ type compileResult struct { SVG string `json:"svg"` } -func newWatcher(ctx context.Context, ms *xmain.State, layoutPlugin d2plugin.Plugin, inputPath, outputPath string) (*watcher, error) { +func newWatcher(ctx context.Context, ms *xmain.State, layoutPlugin d2plugin.Plugin, themeID int64, inputPath, outputPath string) (*watcher, error) { ctx, cancel := context.WithCancel(ctx) w := &watcher{ @@ -78,6 +79,7 @@ func newWatcher(ctx context.Context, ms *xmain.State, layoutPlugin d2plugin.Plug ms: ms, layoutPlugin: layoutPlugin, + themeID: themeID, inputPath: inputPath, outputPath: outputPath, @@ -325,7 +327,7 @@ func (w *watcher) compileLoop(ctx context.Context) error { recompiledPrefix = "re" } - b, err := compile(ctx, w.ms, w.layoutPlugin, w.inputPath, w.outputPath) + b, err := compile(ctx, w.ms, w.layoutPlugin, w.themeID, w.inputPath, w.outputPath) if err != nil { err = fmt.Errorf("failed to %scompile: %w", recompiledPrefix, err) w.ms.Log.Error.Print(err) diff --git a/d2plugin/serve.go b/d2plugin/serve.go index 919eab48b..a51c48614 100644 --- a/d2plugin/serve.go +++ b/d2plugin/serve.go @@ -19,12 +19,12 @@ import ( // Also see execPlugin in exec.go for the d2 binary plugin protocol. func Serve(p Plugin) func(context.Context, *xmain.State) error { return func(ctx context.Context, ms *xmain.State) (err error) { - if len(ms.Args) < 1 { + if len(ms.Opts.Args()) < 1 { return errors.New("expected first argument to plugin binary to be function name") } - reqFunc := ms.Args[0] + reqFunc := ms.Opts.Arg(0) - switch ms.Args[0] { + switch ms.Opts.Arg(0) { case "info": return info(ctx, p, ms) case "layout": diff --git a/lib/xmain/opts.go b/lib/xmain/opts.go new file mode 100644 index 000000000..c095d58cb --- /dev/null +++ b/lib/xmain/opts.go @@ -0,0 +1,132 @@ +package xmain + +import ( + "fmt" + "io" + "strconv" + "strings" + + "github.com/spf13/pflag" + "oss.terrastruct.com/cmdlog" + "oss.terrastruct.com/xos" +) + +type Opts struct { + args []string + flags *pflag.FlagSet + env *xos.Env + log *cmdlog.Logger + + registeredEnvs []string +} + +func NewOpts(env *xos.Env, args []string, log *cmdlog.Logger) *Opts { + flags := pflag.NewFlagSet("", pflag.ContinueOnError) + flags.SortFlags = false + flags.Usage = func() {} + flags.SetOutput(io.Discard) + return &Opts{ + args: args, + flags: flags, + env: env, + log: log, + } +} + +func (o *Opts) Help() string { + b := &strings.Builder{} + o.flags.SetOutput(b) + o.flags.PrintDefaults() + + if len(o.registeredEnvs) > 0 { + b.WriteString("\nYou may persistently set the following as environment variables (flags take precedent):\n") + for i, e := range o.registeredEnvs { + s := fmt.Sprintf("- $%s", e) + if i != len(o.registeredEnvs)-1 { + s += "\n" + } + b.WriteString(s) + } + } + + return b.String() +} + +func (o *Opts) Int64(envKey, flag, shortFlag string, defaultVal int64, usage string) *int64 { + if envKey != "" { + if o.env.Getenv(envKey) != "" { + envVal, err := strconv.ParseInt(o.env.Getenv(envKey), 10, 64) + if err != nil { + o.log.Error.Printf(`ignoring invalid environment variable %s. Expected int64. Found "%v".`, envKey, envVal) + } else if envVal != defaultVal { + defaultVal = envVal + } + } + o.registeredEnvs = append(o.registeredEnvs, envKey) + } + + return o.flags.Int64P(flag, shortFlag, defaultVal, usage) +} + +func (o *Opts) String(envKey, flag, shortFlag string, defaultVal, usage string) *string { + if envKey != "" { + if o.env.Getenv(envKey) != "" { + envVal := o.env.Getenv(envKey) + if envVal != defaultVal { + defaultVal = envVal + } + } + o.registeredEnvs = append(o.registeredEnvs, envKey) + } + + return o.flags.StringP(flag, shortFlag, defaultVal, usage) +} + +func (o *Opts) Bool(envKey, flag, shortFlag string, defaultVal bool, usage string) *bool { + if envKey != "" { + if o.env.Getenv(envKey) != "" { + envVal := o.env.Getenv(envKey) + if !boolyEnv(envVal) { + o.log.Error.Printf(`ignoring invalid environment variable %s. Expected bool. Found "%s".`, envKey, envVal) + } else if (defaultVal && falseyEnv(envVal)) || + (!defaultVal && truthyEnv(envVal)) { + defaultVal = !defaultVal + } + } + o.registeredEnvs = append(o.registeredEnvs, envKey) + } + + return o.flags.BoolP(flag, shortFlag, defaultVal, usage) +} + +func boolyEnv(s string) bool { + return falseyEnv(s) || truthyEnv(s) +} + +func falseyEnv(s string) bool { + return s == "0" || s == "false" || s == "f" +} + +func truthyEnv(s string) bool { + return s == "1" || s == "true" || s == "t" +} + +func (o *Opts) Parse() error { + err := o.flags.Parse(o.args) + if err != nil { + return err + } + return nil +} + +func (o *Opts) SetArgs(args []string) { + o.args = args +} + +func (o *Opts) Args() []string { + return o.flags.Args() +} + +func (o *Opts) Arg(i int) string { + return o.flags.Arg(i) +} diff --git a/lib/xmain/xmain.go b/lib/xmain/xmain.go index d71147776..3ef9b89aa 100644 --- a/lib/xmain/xmain.go +++ b/lib/xmain/xmain.go @@ -9,13 +9,11 @@ import ( "io" "os" "os/signal" - "strings" "syscall" "time" "cdr.dev/slog" "cdr.dev/slog/sloggers/sloghuman" - "github.com/spf13/pflag" "oss.terrastruct.com/xos" @@ -41,14 +39,10 @@ func Main(run RunFunc) { Stdout: os.Stdout, Stderr: os.Stderr, - Env: xos.NewEnv(os.Environ()), - FlagSet: pflag.NewFlagSet("", pflag.ContinueOnError), - Args: args, + Env: xos.NewEnv(os.Environ()), } ms.Log = cmdlog.Log(ms.Env, os.Stderr) - ms.FlagSet.SortFlags = false - ms.FlagSet.Usage = func() {} - ms.FlagSet.SetOutput(io.Discard) + ms.Opts = NewOpts(ms.Env, args, ms.Log) sigs := make(chan os.Signal, 1) signal.Notify(sigs, os.Interrupt, syscall.SIGTERM) @@ -88,10 +82,9 @@ type State struct { Stdout io.WriteCloser Stderr io.WriteCloser - Log *cmdlog.Logger - Env *xos.Env - Args []string - FlagSet *pflag.FlagSet + Log *cmdlog.Logger + Env *xos.Env + Opts *Opts } func (ms *State) Main(ctx context.Context, sigs <-chan os.Signal, run func(context.Context, *State) error) error { @@ -129,13 +122,6 @@ func (ms *State) Main(ctx context.Context, sigs <-chan os.Signal, run func(conte } } -func (ms *State) FlagHelp() string { - b := &strings.Builder{} - ms.FlagSet.SetOutput(b) - ms.FlagSet.PrintDefaults() - return b.String() -} - type ExitError struct { Code int `json:"code"` Message string `json:"message"`