Skip to content

Commit

Permalink
Move to sync subcommand, use contexts to drive runflags
Browse files Browse the repository at this point in the history
  • Loading branch information
ryapric committed May 1, 2023
1 parent 3be57f7 commit 6bbbf2b
Show file tree
Hide file tree
Showing 10 changed files with 89 additions and 59 deletions.
59 changes: 41 additions & 18 deletions main.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"flag"
"fmt"
"log"
Expand All @@ -11,53 +12,75 @@ const (
// ANSI color codes for log messages
colorErr = "\033[31m" // red
colorDebug = "\033[33m" // yellow
colorInfo = "\033[36m" // cyan
colorHappy = "\033[32m" // green
colorReset = "\033[0m"
)

var (
// CLI args
// Subcommands
syncCmd = flag.NewFlagSet("sync", flag.ExitOnError)

subcommands = map[string]*flag.FlagSet{
syncCmd.Name(): syncCmd,
}

// CLI args common to each subcommand
debug bool
specFilePath string

// Loggers, which include embedded ANSI color codes
infoLogger = log.New(os.Stderr, fmt.Sprintf("%s[vdm] ", colorReset), 0)
infoLogger = log.New(os.Stderr, fmt.Sprintf("%s%s[vdm]%s ", colorReset, colorInfo, colorReset), 0)
errLogger = log.New(os.Stderr, fmt.Sprintf("%s%s[vdm]%s ", colorReset, colorErr, colorReset), 0)
debugLogger = log.New(os.Stderr, fmt.Sprintf("%s%s[vdm]%s ", colorReset, colorDebug, colorReset), 0)
happyLogger = log.New(os.Stderr, fmt.Sprintf("%s%s[vdm]%s ", colorReset, colorHappy, colorReset), 0)
)

// In case I need to pass these around, so we're not relying on globals
type runFlags struct {
SpecFilePath string
Debug bool
func setCommonFlags() {
for _, cmd := range subcommands {
cmd.StringVar(&specFilePath, "spec-file", "./.vdm", "vdm dependency spec file")
cmd.BoolVar(&debug, "debug", false, "Print debug logs")
}
}

func main() {
flag.StringVar(&specFilePath, "spec-file", "./.vdm", "vdm dependency spec file")
flag.BoolVar(&debug, "debug", false, "Print debug logs")
flag.Parse()
func isDebug(ctx context.Context) bool {
debugVal := ctx.Value("debug")
if debugVal == nil {
panic("somehow the debug context key ended up as <nil>")
}

runFlags := runFlags{
SpecFilePath: specFilePath,
Debug: debug,
return debugVal.(bool)
}

func main() {
if len(os.Args) == 1 {
errLogger.Fatal("You must provide a command to vdm")
}
cmd, ok := subcommands[os.Args[1]]
if !ok {
errLogger.Fatalf("Unrecognized vmd subcommand '%s'", os.Args[1])
}
setCommonFlags()
cmd.Parse(os.Args[2:])

ctx := context.WithValue(context.Background(), "debug", debug)
ctx = context.WithValue(ctx, "specFilePath", specFilePath)

err := checkGitAvailable()
err := checkGitAvailable(ctx)
if err != nil {
os.Exit(1)
}

specs := getSpecsFromFile(specFilePath, runFlags)
specs := getSpecsFromFile(ctx, specFilePath)

for _, spec := range specs {
err := spec.Validate()
err := spec.Validate(ctx)
if err != nil {
errLogger.Fatalf("your vdm spec file is malformed: %v", err)
errLogger.Fatalf("Your vdm spec file is malformed: %v", err)
}
}

sync(specs)
sync(ctx, specs)

happyLogger.Print("All done!")
}
17 changes: 6 additions & 11 deletions spec.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"encoding/json"
"errors"
"os"
Expand All @@ -11,8 +12,6 @@ type vdmSpec struct {
Remote string `json:"remote"`
Version string `json:"version"`
LocalPath string `json:"local_path"`
// so we can pass these around for each spec
runFlags runFlags `json:"-"`
}

func (spec vdmSpec) writeVDMMeta() error {
Expand Down Expand Up @@ -61,13 +60,13 @@ func (spec vdmSpec) getVDMMeta() vdmSpec {
return vdmMeta
}

func getSpecsFromFile(specFilePath string, runFlags runFlags) []vdmSpec {
func getSpecsFromFile(ctx context.Context, specFilePath string) []vdmSpec {
specFile, err := os.ReadFile(specFilePath)
if err != nil {
if debug {
if isDebug(ctx) {
debugLogger.Printf("error reading specFile from disk: %v", err)
}
errLogger.Fatalf("There was a problem reading your vdm file from '%s' -- does it not exist?", specFilePath)
errLogger.Fatalf("There was a problem reading your vdm file from '%s' -- does it not exist? Either pass the -spec-file flag, or create one in the default location (details in the README)", specFilePath)
}
if debug {
debugLogger.Printf("specFile contents read:\n%s", string(specFile))
Expand All @@ -76,18 +75,14 @@ func getSpecsFromFile(specFilePath string, runFlags runFlags) []vdmSpec {
var specs []vdmSpec
err = json.Unmarshal(specFile, &specs)
if err != nil {
if debug {
if isDebug(ctx) {
debugLogger.Printf("error during specFile unmarshal: %v", err)
}
errLogger.Fatal("There was a problem reading the contents of your vdm spec file")
}
if debug {
if isDebug(ctx) {
debugLogger.Printf("vdmSpecs unmarshalled: %+v", specs)
}

for _, spec := range specs {
spec.runFlags = runFlags
}

return specs
}
7 changes: 2 additions & 5 deletions spec_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"context"
"os"
"path/filepath"
"testing"
Expand Down Expand Up @@ -56,16 +57,12 @@ func TestSpecGetVDMMeta(t *testing.T) {

t.Run("getSpecsFromFile", func(t *testing.T) {
specFilePath := "./testdata/.vdm"
runFlags := runFlags{
SpecFilePath: "./testdata/.vdm",
}

specs := getSpecsFromFile(specFilePath, runFlags)
specs := getSpecsFromFile(context.Background(), specFilePath)
assert.Equal(t, 4, len(specs))

t.Cleanup(func() {
os.RemoveAll(testVDMMetaFilePath)
})
})

}
5 changes: 4 additions & 1 deletion staticcheck.conf
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
checks = ["all"]
checks = [
"all",
"-SA1029" # context string keys, for now
]
17 changes: 9 additions & 8 deletions sync.go
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
package main

import (
"context"
"fmt"
"os"
"os/exec"
"path/filepath"
)

// sync ensures that the only local dependencies are ones defined in the specfile
func sync(specs []vdmSpec) {
func sync(ctx context.Context, specs []vdmSpec) {
for _, spec := range specs {
// Common log line prefix
operationMsg := fmt.Sprintf("%s@%s --> %s", spec.Remote, spec.Version, spec.LocalPath)
Expand All @@ -21,19 +22,19 @@ func sync(specs []vdmSpec) {
if vdmMeta.Version != spec.Version {
infoLogger.Printf("Changing '%s' from current local version spec '%s' to '%s'...", spec.Remote, vdmMeta.Version, spec.Version)
} else {
if debug {
if isDebug(ctx) {
debugLogger.Printf("Version unchanged (%s) in specfile for '%s' --> '%s'", spec.Version, spec.Remote, spec.LocalPath)
}
}
}

// TODO: pull this up so that it only runs if the version changed or the user requested a wipe
if debug {
if isDebug(ctx) {
debugLogger.Printf("removing any old data for '%s'", spec.LocalPath)
}
os.RemoveAll(spec.LocalPath)

gitClone(spec, operationMsg)
gitClone(ctx, spec, operationMsg)

if spec.Version != "latest" {
infoLogger.Printf("%s -- Setting specified version...", operationMsg)
Expand All @@ -44,7 +45,7 @@ func sync(specs []vdmSpec) {
}
}

if debug {
if isDebug(ctx) {
debugLogger.Printf("removing .git dir for local path '%s'", spec.LocalPath)
}
os.RemoveAll(filepath.Join(spec.LocalPath, ".git"))
Expand All @@ -58,18 +59,18 @@ func sync(specs []vdmSpec) {
}
}

func gitClone(spec vdmSpec, operationMsg string) {
func gitClone(ctx context.Context, spec vdmSpec, operationMsg string) {
// If users want "latest", then we can just do a depth-one clone and
// skip the checkout operation. But if they want non-latest, we need the
// full history to be able to find a specified revision
var cloneCmdArgs []string
if spec.Version == "latest" {
if debug {
if isDebug(ctx) {
debugLogger.Printf("%s -- version specified as 'latest', so making shallow clone and skipping separate checkout operation", operationMsg)
}
cloneCmdArgs = []string{"clone", "--depth=1", spec.Remote, spec.LocalPath}
} else {
if debug {
if isDebug(ctx) {
debugLogger.Printf("%s -- version specified as NOT latest, so making regular clone and will make separate checkout operation", operationMsg)
}
cloneCmdArgs = []string{"clone", spec.Remote, spec.LocalPath}
Expand Down
7 changes: 5 additions & 2 deletions sync_test.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
package main

import (
"context"
"path/filepath"
"testing"

"github.com/stretchr/testify/assert"
)

func TestSync(t *testing.T) {
ctx := context.Background()

const testVDMRoot = "./testdata"
specFilePath := filepath.Join(testVDMRoot, ".vdm")

specs := getSpecsFromFile(specFilePath, runFlags{SpecFilePath: specFilePath})
specs := getSpecsFromFile(ctx, specFilePath)

sync(specs)
sync(ctx, specs)

t.Run("spec[0] used a tag", func(t *testing.T) {
vdmMeta := specs[0].getVDMMeta()
Expand Down
7 changes: 4 additions & 3 deletions system.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
package main

import (
"context"
"errors"
"os/exec"
)

func checkGitAvailable() error {
func checkGitAvailable(ctx context.Context) error {
cmd := exec.Command("git", "--version")
sysOutput, err := cmd.CombinedOutput()
if err != nil {
if debug {
if isDebug(ctx) {
debugLogger.Printf("%s: %s", err.Error(), string(sysOutput))
}
return errors.New("git does not seem to be available on your PATH, so cannot continue")
}
if debug {
if isDebug(ctx) {
debugLogger.Print("git was found on PATH")
}
return nil
Expand Down
7 changes: 5 additions & 2 deletions system_test.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
package main

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
)

func TestCheckAvailable(t *testing.T) {
ctx := context.Background()

t.Run("git", func(t *testing.T) {
// Host of this test better have git available lol
gitAvailable := checkGitAvailable()
gitAvailable := checkGitAvailable(ctx)
assert.NoError(t, gitAvailable)

t.Setenv("PATH", "")
gitAvailable = checkGitAvailable()
gitAvailable = checkGitAvailable(ctx)
assert.Error(t, gitAvailable)
})
}
9 changes: 5 additions & 4 deletions validate.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
package main

import (
"context"
"errors"
"fmt"
"regexp"
)

func (spec vdmSpec) Validate() error {
func (spec vdmSpec) Validate(ctx context.Context) error {
var allErrors []error

if spec.runFlags.Debug {
if isDebug(ctx) {
debugLogger.Printf("validating field 'Remote' for %+v", spec)
}
if len(spec.Remote) == 0 {
Expand All @@ -23,14 +24,14 @@ func (spec vdmSpec) Validate() error {
)
}

if spec.runFlags.Debug {
if isDebug(ctx) {
debugLogger.Printf("validating field 'Version' for %+v", spec)
}
if len(spec.Version) == 0 {
allErrors = append(allErrors, errors.New("all 'version' fields must be non-zero length. If you don't care about the version (even though you should), then use 'latest'"))
}

if spec.runFlags.Debug {
if isDebug(ctx) {
debugLogger.Printf("validating field 'LocalPath' for %+v", spec)
}
if len(spec.LocalPath) == 0 {
Expand Down
Loading

0 comments on commit 6bbbf2b

Please sign in to comment.