From 6bbbf2ba0cebb0fa396f24a837b699ae909ef265 Mon Sep 17 00:00:00 2001 From: "Ryan J. Price" Date: Mon, 1 May 2023 00:25:58 -0500 Subject: [PATCH] Move to sync subcommand, use contexts to drive runflags --- main.go | 59 +++++++++++++++++++++++++++++++++--------------- spec.go | 17 +++++--------- spec_test.go | 7 ++---- staticcheck.conf | 5 +++- sync.go | 17 +++++++------- sync_test.go | 7 ++++-- system.go | 7 +++--- system_test.go | 7 ++++-- validate.go | 9 ++++---- validate_test.go | 13 +++++++---- 10 files changed, 89 insertions(+), 59 deletions(-) diff --git a/main.go b/main.go index e3d4d49..e43a08b 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "flag" "fmt" "log" @@ -11,53 +12,75 @@ const ( // ANSI color codes for log messages colorErr = "\033[31m" // red colorDebug = "\033[33m" // yellow + colorInfo = "\033[36m" // cyan colorHappy = "\033[32m" // green colorReset = "\033[0m" ) var ( - // CLI args + // Subcommands + syncCmd = flag.NewFlagSet("sync", flag.ExitOnError) + + subcommands = map[string]*flag.FlagSet{ + syncCmd.Name(): syncCmd, + } + + // CLI args common to each subcommand debug bool specFilePath string // Loggers, which include embedded ANSI color codes - infoLogger = log.New(os.Stderr, fmt.Sprintf("%s[vdm] ", colorReset), 0) + infoLogger = log.New(os.Stderr, fmt.Sprintf("%s%s[vdm]%s ", colorReset, colorInfo, colorReset), 0) errLogger = log.New(os.Stderr, fmt.Sprintf("%s%s[vdm]%s ", colorReset, colorErr, colorReset), 0) debugLogger = log.New(os.Stderr, fmt.Sprintf("%s%s[vdm]%s ", colorReset, colorDebug, colorReset), 0) happyLogger = log.New(os.Stderr, fmt.Sprintf("%s%s[vdm]%s ", colorReset, colorHappy, colorReset), 0) ) -// In case I need to pass these around, so we're not relying on globals -type runFlags struct { - SpecFilePath string - Debug bool +func setCommonFlags() { + for _, cmd := range subcommands { + cmd.StringVar(&specFilePath, "spec-file", "./.vdm", "vdm dependency spec file") + cmd.BoolVar(&debug, "debug", false, "Print debug logs") + } } -func main() { - flag.StringVar(&specFilePath, "spec-file", "./.vdm", "vdm dependency spec file") - flag.BoolVar(&debug, "debug", false, "Print debug logs") - flag.Parse() +func isDebug(ctx context.Context) bool { + debugVal := ctx.Value("debug") + if debugVal == nil { + panic("somehow the debug context key ended up as ") + } - runFlags := runFlags{ - SpecFilePath: specFilePath, - Debug: debug, + return debugVal.(bool) +} + +func main() { + if len(os.Args) == 1 { + errLogger.Fatal("You must provide a command to vdm") } + cmd, ok := subcommands[os.Args[1]] + if !ok { + errLogger.Fatalf("Unrecognized vmd subcommand '%s'", os.Args[1]) + } + setCommonFlags() + cmd.Parse(os.Args[2:]) + + ctx := context.WithValue(context.Background(), "debug", debug) + ctx = context.WithValue(ctx, "specFilePath", specFilePath) - err := checkGitAvailable() + err := checkGitAvailable(ctx) if err != nil { os.Exit(1) } - specs := getSpecsFromFile(specFilePath, runFlags) + specs := getSpecsFromFile(ctx, specFilePath) for _, spec := range specs { - err := spec.Validate() + err := spec.Validate(ctx) if err != nil { - errLogger.Fatalf("your vdm spec file is malformed: %v", err) + errLogger.Fatalf("Your vdm spec file is malformed: %v", err) } } - sync(specs) + sync(ctx, specs) happyLogger.Print("All done!") } diff --git a/spec.go b/spec.go index 7122aa4..e562fa6 100644 --- a/spec.go +++ b/spec.go @@ -1,6 +1,7 @@ package main import ( + "context" "encoding/json" "errors" "os" @@ -11,8 +12,6 @@ type vdmSpec struct { Remote string `json:"remote"` Version string `json:"version"` LocalPath string `json:"local_path"` - // so we can pass these around for each spec - runFlags runFlags `json:"-"` } func (spec vdmSpec) writeVDMMeta() error { @@ -61,13 +60,13 @@ func (spec vdmSpec) getVDMMeta() vdmSpec { return vdmMeta } -func getSpecsFromFile(specFilePath string, runFlags runFlags) []vdmSpec { +func getSpecsFromFile(ctx context.Context, specFilePath string) []vdmSpec { specFile, err := os.ReadFile(specFilePath) if err != nil { - if debug { + if isDebug(ctx) { debugLogger.Printf("error reading specFile from disk: %v", err) } - errLogger.Fatalf("There was a problem reading your vdm file from '%s' -- does it not exist?", specFilePath) + errLogger.Fatalf("There was a problem reading your vdm file from '%s' -- does it not exist? Either pass the -spec-file flag, or create one in the default location (details in the README)", specFilePath) } if debug { debugLogger.Printf("specFile contents read:\n%s", string(specFile)) @@ -76,18 +75,14 @@ func getSpecsFromFile(specFilePath string, runFlags runFlags) []vdmSpec { var specs []vdmSpec err = json.Unmarshal(specFile, &specs) if err != nil { - if debug { + if isDebug(ctx) { debugLogger.Printf("error during specFile unmarshal: %v", err) } errLogger.Fatal("There was a problem reading the contents of your vdm spec file") } - if debug { + if isDebug(ctx) { debugLogger.Printf("vdmSpecs unmarshalled: %+v", specs) } - for _, spec := range specs { - spec.runFlags = runFlags - } - return specs } diff --git a/spec_test.go b/spec_test.go index 6e8ca3e..c77a77a 100644 --- a/spec_test.go +++ b/spec_test.go @@ -1,6 +1,7 @@ package main import ( + "context" "os" "path/filepath" "testing" @@ -56,16 +57,12 @@ func TestSpecGetVDMMeta(t *testing.T) { t.Run("getSpecsFromFile", func(t *testing.T) { specFilePath := "./testdata/.vdm" - runFlags := runFlags{ - SpecFilePath: "./testdata/.vdm", - } - specs := getSpecsFromFile(specFilePath, runFlags) + specs := getSpecsFromFile(context.Background(), specFilePath) assert.Equal(t, 4, len(specs)) t.Cleanup(func() { os.RemoveAll(testVDMMetaFilePath) }) }) - } diff --git a/staticcheck.conf b/staticcheck.conf index 528438b..225deba 100644 --- a/staticcheck.conf +++ b/staticcheck.conf @@ -1 +1,4 @@ -checks = ["all"] +checks = [ + "all", + "-SA1029" # context string keys, for now +] diff --git a/sync.go b/sync.go index 5b7ccd4..5b9f0cc 100644 --- a/sync.go +++ b/sync.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "os" "os/exec" @@ -8,7 +9,7 @@ import ( ) // sync ensures that the only local dependencies are ones defined in the specfile -func sync(specs []vdmSpec) { +func sync(ctx context.Context, specs []vdmSpec) { for _, spec := range specs { // Common log line prefix operationMsg := fmt.Sprintf("%s@%s --> %s", spec.Remote, spec.Version, spec.LocalPath) @@ -21,19 +22,19 @@ func sync(specs []vdmSpec) { if vdmMeta.Version != spec.Version { infoLogger.Printf("Changing '%s' from current local version spec '%s' to '%s'...", spec.Remote, vdmMeta.Version, spec.Version) } else { - if debug { + if isDebug(ctx) { debugLogger.Printf("Version unchanged (%s) in specfile for '%s' --> '%s'", spec.Version, spec.Remote, spec.LocalPath) } } } // TODO: pull this up so that it only runs if the version changed or the user requested a wipe - if debug { + if isDebug(ctx) { debugLogger.Printf("removing any old data for '%s'", spec.LocalPath) } os.RemoveAll(spec.LocalPath) - gitClone(spec, operationMsg) + gitClone(ctx, spec, operationMsg) if spec.Version != "latest" { infoLogger.Printf("%s -- Setting specified version...", operationMsg) @@ -44,7 +45,7 @@ func sync(specs []vdmSpec) { } } - if debug { + if isDebug(ctx) { debugLogger.Printf("removing .git dir for local path '%s'", spec.LocalPath) } os.RemoveAll(filepath.Join(spec.LocalPath, ".git")) @@ -58,18 +59,18 @@ func sync(specs []vdmSpec) { } } -func gitClone(spec vdmSpec, operationMsg string) { +func gitClone(ctx context.Context, spec vdmSpec, operationMsg string) { // If users want "latest", then we can just do a depth-one clone and // skip the checkout operation. But if they want non-latest, we need the // full history to be able to find a specified revision var cloneCmdArgs []string if spec.Version == "latest" { - if debug { + if isDebug(ctx) { debugLogger.Printf("%s -- version specified as 'latest', so making shallow clone and skipping separate checkout operation", operationMsg) } cloneCmdArgs = []string{"clone", "--depth=1", spec.Remote, spec.LocalPath} } else { - if debug { + if isDebug(ctx) { debugLogger.Printf("%s -- version specified as NOT latest, so making regular clone and will make separate checkout operation", operationMsg) } cloneCmdArgs = []string{"clone", spec.Remote, spec.LocalPath} diff --git a/sync_test.go b/sync_test.go index 3c852f4..b94be33 100644 --- a/sync_test.go +++ b/sync_test.go @@ -1,6 +1,7 @@ package main import ( + "context" "path/filepath" "testing" @@ -8,12 +9,14 @@ import ( ) func TestSync(t *testing.T) { + ctx := context.Background() + const testVDMRoot = "./testdata" specFilePath := filepath.Join(testVDMRoot, ".vdm") - specs := getSpecsFromFile(specFilePath, runFlags{SpecFilePath: specFilePath}) + specs := getSpecsFromFile(ctx, specFilePath) - sync(specs) + sync(ctx, specs) t.Run("spec[0] used a tag", func(t *testing.T) { vdmMeta := specs[0].getVDMMeta() diff --git a/system.go b/system.go index b8f75a4..684972a 100644 --- a/system.go +++ b/system.go @@ -1,20 +1,21 @@ package main import ( + "context" "errors" "os/exec" ) -func checkGitAvailable() error { +func checkGitAvailable(ctx context.Context) error { cmd := exec.Command("git", "--version") sysOutput, err := cmd.CombinedOutput() if err != nil { - if debug { + if isDebug(ctx) { debugLogger.Printf("%s: %s", err.Error(), string(sysOutput)) } return errors.New("git does not seem to be available on your PATH, so cannot continue") } - if debug { + if isDebug(ctx) { debugLogger.Print("git was found on PATH") } return nil diff --git a/system_test.go b/system_test.go index 87ebd98..47d3edb 100644 --- a/system_test.go +++ b/system_test.go @@ -1,19 +1,22 @@ package main import ( + "context" "testing" "github.com/stretchr/testify/assert" ) func TestCheckAvailable(t *testing.T) { + ctx := context.Background() + t.Run("git", func(t *testing.T) { // Host of this test better have git available lol - gitAvailable := checkGitAvailable() + gitAvailable := checkGitAvailable(ctx) assert.NoError(t, gitAvailable) t.Setenv("PATH", "") - gitAvailable = checkGitAvailable() + gitAvailable = checkGitAvailable(ctx) assert.Error(t, gitAvailable) }) } diff --git a/validate.go b/validate.go index 0ebff50..6accf84 100644 --- a/validate.go +++ b/validate.go @@ -1,15 +1,16 @@ package main import ( + "context" "errors" "fmt" "regexp" ) -func (spec vdmSpec) Validate() error { +func (spec vdmSpec) Validate(ctx context.Context) error { var allErrors []error - if spec.runFlags.Debug { + if isDebug(ctx) { debugLogger.Printf("validating field 'Remote' for %+v", spec) } if len(spec.Remote) == 0 { @@ -23,14 +24,14 @@ func (spec vdmSpec) Validate() error { ) } - if spec.runFlags.Debug { + if isDebug(ctx) { debugLogger.Printf("validating field 'Version' for %+v", spec) } if len(spec.Version) == 0 { allErrors = append(allErrors, errors.New("all 'version' fields must be non-zero length. If you don't care about the version (even though you should), then use 'latest'")) } - if spec.runFlags.Debug { + if isDebug(ctx) { debugLogger.Printf("validating field 'LocalPath' for %+v", spec) } if len(spec.LocalPath) == 0 { diff --git a/validate_test.go b/validate_test.go index b2f2433..17ba1f3 100644 --- a/validate_test.go +++ b/validate_test.go @@ -1,19 +1,22 @@ package main import ( + "context" "testing" "github.com/stretchr/testify/assert" ) func TestValidate(t *testing.T) { + ctx := context.Background() + t.Run("passes", func(t *testing.T) { spec := vdmSpec{ Remote: "https://some-remote", Version: "v1.0.0", LocalPath: "./deps/some-remote", } - err := spec.Validate() + err := spec.Validate(ctx) assert.NoError(t, err) }) @@ -23,7 +26,7 @@ func TestValidate(t *testing.T) { Version: "v1.0.0", LocalPath: "./deps/some-remote", } - err := spec.Validate() + err := spec.Validate(ctx) assert.Error(t, err) }) @@ -33,7 +36,7 @@ func TestValidate(t *testing.T) { Version: "v1.0.0", LocalPath: "./deps/some-remote", } - err := spec.Validate() + err := spec.Validate(ctx) assert.Error(t, err) }) @@ -43,7 +46,7 @@ func TestValidate(t *testing.T) { Version: "", LocalPath: "./deps/some-remote", } - err := spec.Validate() + err := spec.Validate(ctx) assert.Error(t, err) }) @@ -53,7 +56,7 @@ func TestValidate(t *testing.T) { Version: "v1.0.0", LocalPath: "", } - err := spec.Validate() + err := spec.Validate(ctx) assert.Error(t, err) }) }