Skip to content

Commit

Permalink
refactor: break main into functions
Browse files Browse the repository at this point in the history
  • Loading branch information
abuchanan-airbyte committed Sep 19, 2024
1 parent 9b84233 commit 0b8a47f
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 68 deletions.
22 changes: 0 additions & 22 deletions internal/cmd/cmd.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package cmd

import (
"errors"
"fmt"
"os"

Expand All @@ -16,27 +15,6 @@ import (
"k8s.io/client-go/tools/clientcmd"
)

func HandleErr(err error) {
if err == nil {
return
}

pterm.Error.Println(err)

var errParse *kong.ParseError
if errors.As(err, &errParse) {
_ = kong.DefaultHelpPrinter(kong.HelpOptions{}, errParse.Context)
}

var e *localerr.LocalError
if errors.As(err, &e) {
pterm.Println()
pterm.Info.Println(e.Help())
}

os.Exit(1)
}

type verbose bool

func (v verbose) BeforeApply() error {
Expand Down
10 changes: 9 additions & 1 deletion internal/update/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ import (
"errors"
"fmt"
"net/http"
"time"

"github.com/airbytehq/abctl/internal/build"
"golang.org/x/mod/semver"
)

Expand All @@ -20,7 +22,13 @@ type doer interface {
// This is accomplished by fetching the latest github tag and comparing it to the version provided.
// Returns the latest version, or an empty string if we're already running the latest version.
// Will return ErrDevVersion if the build.Version is currently set to "dev".
func Check(ctx context.Context, doer doer, version string) (string, error) {
func Check() (string, error) {
ctx, updateCancel := context.WithTimeout(context.Background(), 2*time.Second)
defer updateCancel()
return check(ctx, http.DefaultClient, build.Version)
}

func check(ctx context.Context, doer doer, version string) (string, error) {
if version == "dev" {
return "", ErrDevVersion
}
Expand Down
6 changes: 3 additions & 3 deletions internal/update/update_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func TestCheck(t *testing.T) {
},
}

latest, err := Check(ctx, h, tt.local)
latest, err := check(ctx, h, tt.local)
if d := cmp.Diff(tt.wantErr, err, cmpopts.EquateErrors()); d != "" {
t.Errorf("unexpected error: %s", err)
}
Expand All @@ -82,7 +82,7 @@ func TestCheck_HTTPRequest(t *testing.T) {
},
}

if _, err := Check(context.Background(), h, "v0.1.0"); err != nil {
if _, err := check(context.Background(), h, "v0.1.0"); err != nil {
t.Error("unexpected error:", err)
}
// verify method
Expand Down Expand Up @@ -147,7 +147,7 @@ func TestCheck_HTTPErr(t *testing.T) {
},
}

_, err := Check(context.Background(), h, "v0.1.0")
_, err := check(context.Background(), h, "v0.1.0")
if err == nil {
t.Error("unexpected success")
}
Expand Down
108 changes: 66 additions & 42 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,45 +3,29 @@ package main
import (
"context"
"errors"
"net/http"
"os"
"os/signal"
"syscall"
"time"

"github.com/airbytehq/abctl/internal/build"
"github.com/airbytehq/abctl/internal/cmd"
"github.com/airbytehq/abctl/internal/cmd/local/localerr"
"github.com/airbytehq/abctl/internal/update"
"github.com/alecthomas/kong"
"github.com/pterm/pterm"
)

func main() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

// check for update
updateCtx, updateCancel := context.WithTimeout(ctx, 2*time.Second)
defer updateCancel()

updateChan := make(chan updateInfo)
go func() {
info := updateInfo{}
info.version, info.err = update.Check(updateCtx, http.DefaultClient, build.Version)
updateChan <- info
}()

// listen for shutdown signals
go func() {
signalCh := make(chan os.Signal, 1)
signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM)
<-signalCh

cancel()
}()

// ensure the pterm info width matches the other printers
pterm.Info.Prefix.Text = " INFO "
printUpdateMsg := checkForNewerAbctlVersion()
handleErr(run())
printUpdateMsg()
}

func run() error {
ctx, cancel := cliContext()
defer cancel()

var root cmd.Cmd
parser, err := kong.New(
Expand All @@ -51,37 +35,77 @@ func main() {
kong.UsageOnError(),
)
if err != nil {
cmd.HandleErr(err)
return err
}
parsed, err := parser.Parse(os.Args[1:])
if err != nil {
cmd.HandleErr(err)
return err
}
if err := parsed.BindToProvider(bindCtx(ctx)); err != nil {
cmd.HandleErr(err)
parsed.BindToProvider(bindCtx(ctx))
return parsed.Run()
}

func handleErr(err error) {
if err == nil {
return
}

cmd.HandleErr(parsed.Run())
pterm.Error.Println(err)

newRelease := <-updateChan
if newRelease.err != nil {
if errors.Is(newRelease.err, update.ErrDevVersion) {
pterm.Debug.Println("Release checking is disabled for dev builds")
}
} else if newRelease.version != "" {
var errParse *kong.ParseError
if errors.As(err, &errParse) {
_ = kong.DefaultHelpPrinter(kong.HelpOptions{}, errParse.Context)
}

var e *localerr.LocalError
if errors.As(err, &e) {
pterm.Println()
pterm.Info.Printfln("A new release of abctl is available: %s -> %s\nUpdating to the latest version is highly recommended", build.Version, newRelease.version)
pterm.Info.Println(e.Help())
}

os.Exit(1)
}

// checks for a newer version of abctl.
// returns a function that, when called, will print the message about the new version.
func checkForNewerAbctlVersion() func() {
c := make(chan string)
go func() {
defer close(c)
ver, err := update.Check()
if err != nil {
pterm.Debug.Printfln("update check: %s", err)
} else {
c <- ver
}
}()

return func() {
ver := <-c
if ver != "" {
pterm.Info.Printfln("A new release of abctl is available: %s -> %s\nUpdating to the latest version is highly recommended", build.Version, ver)

}
}
}

// get a context that listens for interrupt/shutdown signals.
func cliContext() (context.Context, context.CancelFunc) {
ctx, cancel := context.WithCancel(context.Background())
// listen for shutdown signals
go func() {
signalCh := make(chan os.Signal, 1)
signal.Notify(signalCh, os.Interrupt, syscall.SIGTERM)
<-signalCh

cancel()
}()
return ctx, cancel
}

// bindCtx exists to allow kong to correctly inject a context.Context into the Run methods on the commands.
func bindCtx(ctx context.Context) func() (context.Context, error) {
return func() (context.Context, error) {
return ctx, nil
}
}

type updateInfo struct {
version string
err error
}

0 comments on commit 0b8a47f

Please sign in to comment.