diff --git a/maps/mapsutil.go b/maps/mapsutil.go index c719b46..2ece625 100644 --- a/maps/mapsutil.go +++ b/maps/mapsutil.go @@ -10,7 +10,6 @@ import ( "time" "github.com/miekg/dns" - "golang.org/x/exp/maps" extmaps "golang.org/x/exp/maps" ) @@ -229,7 +228,7 @@ func Walk(m map[string]any, callback func(k string, v any)) { // Clear the map passed as parameter func Clear[K comparable, V any](mm ...map[K]V) { for _, m := range mm { - maps.Clear(m) + extmaps.Clear(m) } } diff --git a/update/gh.go b/update/gh.go index cf00228..8b6dbdd 100644 --- a/update/gh.go +++ b/update/gh.go @@ -6,8 +6,10 @@ import ( "bytes" "compress/gzip" "context" - "fmt" + "crypto/sha256" + "encoding/hex" "io" + "io/fs" "net/http" "os" "runtime" @@ -15,61 +17,252 @@ import ( "github.com/cheggaaa/pb/v3" "github.com/google/go-github/v30/github" + "github.com/projectdiscovery/gologger" errorutil "github.com/projectdiscovery/utils/errors" "golang.org/x/oauth2" ) var ( - extIfFound = ".exe" - ErrNoAssetFound = errorutil.NewWithFmt("update: could not find release asset for your platform (%s/%s)") - GHAssetName = "" + extIfFound = ".exe" + ErrNoAssetFound = errorutil.NewWithFmt("update: could not find release asset for your platform (%s/%s)") + SkipCheckSumValidation = false // by default checksum of gh assets is verified with checksums file present in release ) +// AssetFileCallback function is executed on every file in unpacked asset . if returned error +// is not nil furthur processing of asset file is stopped +type AssetFileCallback func(path string, fileInfo fs.FileInfo, data io.Reader) error + // GHReleaseDownloader fetches and reads release of a gh repo type GHReleaseDownloader struct { - ToolName string // we assume toolname and ToolName are always same - Format AssetFormat - AssetID int - AssetName string - client *github.Client - httpClient *http.Client + assetName string // required assetName given as input + repoName string // we assume toolname and repoName are always same + fullAssetName string // full asset name of asset that contains tool for this platform + organization string // organization name of repo + Format AssetFormat + AssetID int + Latest *github.RepositoryRelease + client *github.Client + httpClient *http.Client } -// NewghReleaseDownloader instance -func NewghReleaseDownloader(toolName string) *GHReleaseDownloader { +// NewghReleaseDownloader returns GHRD instance +func NewghReleaseDownloader(RepoName string) (*GHReleaseDownloader, error) { + var orgName, repoName string + if strings.Contains(RepoName, "/") { + arr := strings.Split(RepoName, "/") + if len(arr) != 2 { + return nil, errorutil.NewWithTag("update", "invalid repo name %v", RepoName) + } + orgName = arr[0] + repoName = arr[1] + } else { + orgName = Organization + repoName = RepoName + } httpClient := &http.Client{ Timeout: DownloadUpdateTimeout, } + if orgName == "" { + return nil, errorutil.NewWithTag("update", "organization name cannot be empty") + } if token := os.Getenv("GITHUB_TOKEN"); token != "" { httpClient = oauth2.NewClient(context.Background(), oauth2.StaticTokenSource(&oauth2.Token{AccessToken: token})) } - ghrd := GHReleaseDownloader{client: github.NewClient(httpClient), ToolName: toolName, httpClient: httpClient} + ghrd := GHReleaseDownloader{client: github.NewClient(httpClient), repoName: repoName, assetName: repoName, httpClient: httpClient, organization: orgName} + + err := ghrd.getLatestRelease() + return &ghrd, err +} + +// SetAssetName: By default RepoName is assumed as ToolName which maynot be the case always setToolName corrects that +func (d *GHReleaseDownloader) SetToolName(toolName string) { + if toolName != "" { + d.assetName = toolName + } +} + +// DownloadTool downloads tool and returns bin data +func (d *GHReleaseDownloader) DownloadTool() (*bytes.Buffer, error) { + if err := d.getToolAssetID(d.Latest); err != nil { + return nil, err + } + resp, err := d.downloadAssetwithID(int64(d.AssetID)) + if err != nil { + return nil, err + } + defer resp.Body.Close() - if ghrd.AssetName == "" && GHAssetName != "" { - ghrd.AssetName = GHAssetName + if !HideProgressBar { + bar := pb.New64(resp.ContentLength).SetMaxWidth(100) + bar.Start() + resp.Body = bar.NewProxyReader(resp.Body) + defer bar.Finish() } - if ghrd.AssetName == "" { - ghrd.AssetName = ghrd.ToolName + + bin, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errorutil.NewWithErr(err).Msgf("failed to read response body") } - return &ghrd + return bytes.NewBuffer(bin), nil } -// getLatestRelease returns latest release of error -func (d *GHReleaseDownloader) GetLatestRelease() (*github.RepositoryRelease, error) { - release, resp, err := d.client.Repositories.GetLatestRelease(context.Background(), Organization, d.ToolName) +// GetReleaseChecksums tries to download tool checksum if release contains any in map[asset_name]checksum_data format +func (d *GHReleaseDownloader) GetReleaseChecksums() (map[string]string, error) { + builder := &strings.Builder{} + builder.WriteString(d.assetName) + builder.WriteString("_") + builder.WriteString(strings.TrimPrefix(d.Latest.GetTagName(), "v")) + builder.WriteString("_") + builder.WriteString("checksums.txt") + checksumFileName := builder.String() + + checksumFileAssetID := 0 + for _, v := range d.Latest.Assets { + if v.GetName() == checksumFileName { + checksumFileAssetID = int(v.GetID()) + } + } + if checksumFileAssetID == 0 { + return nil, errorutil.NewWithTag("update", "checksum file not in release assets") + } + + resp, err := d.downloadAssetwithID(int64(checksumFileAssetID)) if err != nil { - if resp != nil && resp.StatusCode == 404 { - return nil, fmt.Errorf("repo %v/%v not found got %v", Organization, d.ToolName, err) + return nil, errorutil.NewWithErr(err).Msgf("failed to download checksum file") + } + defer resp.Body.Close() + bin, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errorutil.NewWithErr(err).Msgf("failed to read checksum file") + } + data := strings.TrimSpace(string(bin)) + if data == "" { + return nil, errorutil.NewWithTag("checksum", "something went wrong checksum file is emtpy") + } + m := map[string]string{} + for _, v := range strings.Split(data, "\n") { + arr := strings.Fields(v) + if len(arr) != 2 { + continue } + m[arr[1]] = arr[0] + } + return m, nil +} + +// GetExecutableFromAsset downloads , validates checksum and only returns tool Binary +func (d *GHReleaseDownloader) GetExecutableFromAsset() ([]byte, error) { + var bin []byte + var err error + getToolCallback := func(path string, fileInfo fs.FileInfo, data io.Reader) error { + if !strings.EqualFold(strings.TrimSuffix(fileInfo.Name(), extIfFound), d.assetName) { + return nil + } + bin, err = io.ReadAll(data) + return err + } + + buff, err := d.DownloadTool() + if err != nil { return nil, err } - return release, nil + + var expectedChecksum string + checksums, err := d.GetReleaseChecksums() + if checksums != nil { + expectedChecksum = checksums[d.fullAssetName] + } + // verify integrity using checksum + if expectedChecksum != "" { + gotChecksumbytes := sha256.Sum256(buff.Bytes()) + gotchecksum := hex.EncodeToString(gotChecksumbytes[:]) + if expectedChecksum != gotchecksum { + return nil, errorutil.NewWithTag("checksum", "asset file corrupted: checksum mismatch expected %v but got %v", expectedChecksum, gotchecksum) + } else { + gologger.Info().Msgf("Verified Integrity of %v", d.fullAssetName) + } + } + + _ = UnpackAssetWithCallback(d.Format, bytes.NewReader(buff.Bytes()), getToolCallback) + return bin, errorutil.WrapfWithNil(err, "executable not found in archive") // Note: WrapfWithNil wraps msg if err != nil } -// getAssetIDFromRelease finds AssetID from release or returns a descriptive error -func (d *GHReleaseDownloader) GetAssetIDFromRelease(latest *github.RepositoryRelease) error { +// DownloadAssetWithName downloads asset with given name +func (d *GHReleaseDownloader) DownloadAssetWithName(assetname string, showProgressBar bool) (*bytes.Buffer, error) { + assetID := 0 + for _, v := range d.Latest.Assets { + if v.GetName() == assetname { + assetID = int(v.GetID()) + } + } + if assetID == 0 { + return nil, errorutil.New("release asset %v not found", assetname) + } + resp, err := d.downloadAssetwithID(int64(assetID)) + if err != nil { + return nil, errorutil.NewWithErr(err).Msgf("failed to download asset %v", assetname) + } + defer resp.Body.Close() + + if showProgressBar { + bar := pb.New64(resp.ContentLength).SetMaxWidth(100) + bar.Start() + resp.Body = bar.NewProxyReader(resp.Body) + defer bar.Finish() + } + + bin, err := io.ReadAll(resp.Body) + if err != nil { + return nil, errorutil.NewWithErr(err).Msgf("failed to read resp body") + } + return bytes.NewBuffer(bin), nil +} + +// DownloadSourceWithCallback downloads source code of latest release and calls callback for each file in archive +func (d *GHReleaseDownloader) DownloadSourceWithCallback(showProgressBar bool, callback AssetFileCallback) error { + downloadURL := d.Latest.GetZipballURL() + + resp, err := d.httpClient.Get(downloadURL) + if err != nil { + return errorutil.NewWithErr(err).Msgf("failed to source of %v", d.repoName) + } + defer resp.Body.Close() + if showProgressBar { + bar := pb.New64(resp.ContentLength).SetMaxWidth(100) + bar.Start() + resp.Body = bar.NewProxyReader(resp.Body) + defer bar.Finish() + } + + bin, err := io.ReadAll(resp.Body) + if err != nil { + return errorutil.NewWithErr(err).Msgf("failed to read resp body") + } + return UnpackAssetWithCallback(Zip, bytes.NewReader(bin), callback) +} + +// getLatestRelease returns latest release of error +func (d *GHReleaseDownloader) getLatestRelease() error { + release, resp, err := d.client.Repositories.GetLatestRelease(context.Background(), d.organization, d.repoName) + if err != nil { + errx := errorutil.NewWithErr(err) + if resp != nil && resp.StatusCode == http.StatusNotFound { + errx = errx.Msgf("repo %v/%v not found got %v", d.organization, d.repoName) + } else if _, ok := err.(*github.RateLimitError); ok { + errx = errx.Msgf("hit github ratelimit while downloading latest release") + } else if resp != nil && (resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusUnauthorized) { + errx = errx.Msgf("gh auth failed try unsetting GITHUB_TOKEN env variable") + } + return errx + } + d.Latest = release + return nil +} + +// getToolAssetID tries to find assetId of tool required for this platform +func (d *GHReleaseDownloader) getToolAssetID(latest *github.RepositoryRelease) error { builder := &strings.Builder{} - builder.WriteString(d.AssetName) + builder.WriteString(d.assetName) builder.WriteString("_") builder.WriteString(strings.TrimPrefix(latest.GetTagName(), "v")) builder.WriteString("_") @@ -83,18 +276,20 @@ func (d *GHReleaseDownloader) GetAssetIDFromRelease(latest *github.RepositoryRel loop: for _, v := range latest.Assets { - asset := *v.Name + asset := v.GetName() switch { - case strings.Contains(asset, ".zip"): - if strings.EqualFold(asset, builder.String()+".zip") { - d.AssetID = int(*v.ID) + case strings.Contains(asset, Zip.FileExtension()): + if strings.EqualFold(asset, builder.String()+Zip.FileExtension()) { + d.AssetID = int(v.GetID()) d.Format = Zip + d.fullAssetName = asset break loop } - case strings.Contains(asset, ".tar.gz"): - if strings.EqualFold(asset, builder.String()+".tar.gz") { - d.AssetID = int(*v.ID) + case strings.Contains(asset, Tar.FileExtension()): + if strings.EqualFold(asset, builder.String()+Tar.FileExtension()) { + d.AssetID = int(v.GetID()) d.Format = Tar + d.fullAssetName = asset break loop } } @@ -108,9 +303,9 @@ loop: return nil } -// DownloadAssetFromID downloads and returns a buffer or a descriptive error -func (d *GHReleaseDownloader) DownloadAssetFromID() (*bytes.Buffer, error) { - _, rdurl, err := d.client.Repositories.DownloadReleaseAsset(context.Background(), Organization, d.ToolName, int64(d.AssetID), nil) +// downloadAssetwithID +func (d *GHReleaseDownloader) downloadAssetwithID(id int64) (*http.Response, error) { + _, rdurl, err := d.client.Repositories.DownloadReleaseAsset(context.Background(), d.organization, d.repoName, id, nil) if err != nil { return nil, err } @@ -124,52 +319,33 @@ func (d *GHReleaseDownloader) DownloadAssetFromID() (*bytes.Buffer, error) { if resp.Body == nil { return nil, errorutil.New("something went wrong got response without body") } - defer resp.Body.Close() - - if !HideProgressBar { - bar := pb.New64(resp.ContentLength).SetMaxWidth(100) - bar.Start() - resp.Body = bar.NewProxyReader(resp.Body) - defer bar.Finish() - } - - bin, err := io.ReadAll(resp.Body) - if err != nil { - return nil, errorutil.NewWithErr(err).Msgf("failed to read response body") - } - return bytes.NewBuffer(bin), nil + return resp, nil } -// GetExecutableFromAsset downloads and only returns tool Binary -func (d *GHReleaseDownloader) GetExecutableFromAsset() ([]byte, error) { - buff, err := d.DownloadAssetFromID() - if err != nil { - return nil, err +// UnpackAssetWithCallback unpacks asset and executes callback function on every file in data +func UnpackAssetWithCallback(format AssetFormat, data *bytes.Reader, callback AssetFileCallback) error { + if format != Zip && format != Tar { + return errorutil.NewWithTag("unpack", "github asset format not supported. only zip and tar are supported") } - if d.Format == Zip { - zipReader, err := zip.NewReader(bytes.NewReader(buff.Bytes()), int64(buff.Len())) + if format == Zip { + zipReader, err := zip.NewReader(data, data.Size()) if err != nil { - return nil, err + return err } for _, f := range zipReader.File { - if !strings.EqualFold(strings.TrimSuffix(f.Name, extIfFound), d.AssetName) { - continue - } - fileInArchive, err := f.Open() + data, err := f.Open() if err != nil { - return nil, err + return err } - bin, err := io.ReadAll(fileInArchive) - if err != nil { - return nil, err + if err := callback(f.Name, f.FileInfo(), data); err != nil { + return err } - _ = fileInArchive.Close() - return bin, nil + _ = data.Close() } - } else if d.Format == Tar { - gzipReader, err := gzip.NewReader(buff) + } else if format == Tar { + gzipReader, err := gzip.NewReader(data) if err != nil { - return nil, err + return err } tarReader := tar.NewReader(gzipReader) // iterate through the files in the archive @@ -179,20 +355,12 @@ func (d *GHReleaseDownloader) GetExecutableFromAsset() ([]byte, error) { break } if err != nil { - return nil, err - } - if !strings.EqualFold(strings.TrimSuffix(header.FileInfo().Name(), extIfFound), d.AssetName) { - continue + return err } - // if the file is not a directory, extract it - if !header.FileInfo().IsDir() { - bin, err := io.ReadAll(tarReader) - if err != nil { - return nil, err - } - return bin, nil + if err := callback(header.Name, header.FileInfo(), tarReader); err != nil { + return err } } } - return nil, fmt.Errorf("executable not found in archive") + return nil } diff --git a/update/gh_test.go b/update/gh_test.go new file mode 100644 index 0000000..01968de --- /dev/null +++ b/update/gh_test.go @@ -0,0 +1,47 @@ +//go:build update + +// update related tests are only executed when update tag is provided (ex: go test -tags update ./...) to avoid failures due to rate limiting +package updateutils + +import ( + "io" + "io/fs" + "testing" + + "github.com/stretchr/testify/require" +) + +// TestDownloadNucleiRelease tests downloading nuclei release +func TestDownloadNucleiRelease(t *testing.T) { + HideProgressBar = true + gh, err := NewghReleaseDownloader("nuclei") + require.Nil(t, err) + _, err = gh.GetExecutableFromAsset() + require.Nil(t, err) +} + +// TestDownloadNucleiTemplatesFromSource tests downloading nuclei-templates from source +func TestDownloadNucleiTemplatesFromSource(t *testing.T) { + gh, err := NewghReleaseDownloader("nuclei-templates") + require.Nil(t, err) + counter := 0 + callback := func(path string, fileInfo fs.FileInfo, data io.Reader) error { + _ = fileInfo.Name() + counter++ + return nil + } + err = gh.DownloadSourceWithCallback(false, callback) + require.Nil(t, err) + // actual content is lot more than 100 files + require.Greater(t, counter, 100) +} + +// TestDownloadToolWithDifferentName tests downloading a tool with different name than repo name +// by default repo name is considered as executable name +func TestDownloadToolWithDifferentName(t *testing.T) { + gh, err := NewghReleaseDownloader("interactsh") + require.Nil(t, err) + gh.SetToolName("interactsh-client") + _, err = gh.GetExecutableFromAsset() + require.Nil(t, err) +} diff --git a/update/types.go b/update/types.go index 54d4e43..ae44aac 100644 --- a/update/types.go +++ b/update/types.go @@ -13,8 +13,30 @@ type AssetFormat uint const ( Zip AssetFormat = iota Tar + Unknown ) +// FileExtension of this asset format +func (a AssetFormat) FileExtension() string { + if a == Zip { + return ".zip" + } else if a == Tar { + return ".tar.gz" + } + return "" +} + +func IdentifyAssetFormat(assetName string) AssetFormat { + switch { + case strings.HasSuffix(assetName, Zip.FileExtension()): + return Zip + case strings.HasSuffix(assetName, Tar.FileExtension()): + return Tar + default: + return Unknown + } +} + // Tool type Tool struct { Name string `json:"name"` @@ -23,24 +45,57 @@ type Tool struct { Assets map[string]string `json:"assets"` } +// Aurora instance +var Aurora aurora.Aurora = aurora.NewAurora(true) + // GetVersionDescription returns tags like (latest) or (outdated) or (dev) func GetVersionDescription(current string, latest string) string { + if strings.HasSuffix(current, "-dev") { + if IsDevReleaseOutdated(current, latest) { + return fmt.Sprintf("(%v)", Aurora.BrightRed("outdated")) + } else { + return fmt.Sprintf("(%v)", Aurora.Blue("development")) + } + } + if IsOutdated(current, latest) { + return fmt.Sprintf("(%v)", Aurora.BrightRed("outdated")) + } else { + return fmt.Sprintf("(%v)", Aurora.BrightGreen("latest")) + } +} + +// IsOutdated returns true if current version is outdated +func IsOutdated(current, latest string) bool { + if strings.HasSuffix(current, "-dev") { + return IsDevReleaseOutdated(current, latest) + } currentVer, _ := semver.NewVersion(current) latestVer, _ := semver.NewVersion(latest) - if strings.Contains(current, "dev") { - return fmt.Sprintf("(%v)", aurora.BrightBlue("dev")) - } if currentVer == nil || latestVer == nil { // fallback to naive comparison + return current != latest + } + return latestVer.GreaterThan(currentVer) +} + +// IsDevReleaseOutdated returns true if installed tool (dev version) is outdated +// ex: if installed tools is v2.9.1-dev and latest release is v2.9.1 then it is outdated +// since v2.9.1-dev is released and merged into main/master branch +func IsDevReleaseOutdated(current string, latest string) bool { + // remove -dev suffix + current = strings.TrimSuffix(current, "-dev") + currentVer, _ := semver.NewVersion(current) + latestVer, _ := semver.NewVersion(latest) + if currentVer == nil || latestVer == nil { if current == latest { - return fmt.Sprintf("(%v)", aurora.BrightGreen("latest")) + return true } else { - return fmt.Sprintf("(%v)", aurora.BrightRed("outdated")) + // can't compare, so consider it latest + return false } } - if latestVer.GreaterThan(currentVer) { - return fmt.Sprintf("(%v)", aurora.BrightRed("outdated")) - } else { - return fmt.Sprintf("(%v)", aurora.BrightGreen("latest")) + if latestVer.GreaterThan(currentVer) || latestVer.Equal(currentVer) { + return true } + return false } diff --git a/update/update.go b/update/update.go index a23de09..bb06ba4 100644 --- a/update/update.go +++ b/update/update.go @@ -6,7 +6,6 @@ import ( "encoding/json" "fmt" "io" - "log" "net/http" "net/url" "os" @@ -31,28 +30,38 @@ var ( HideProgressBar = false VersionCheckTimeout = time.Duration(5) * time.Second DownloadUpdateTimeout = time.Duration(30) * time.Second - // Note: DefaultHttpClient is only used in VersionCheck Callback + // Note: DefaultHttpClient is only used in GetToolVersionCallback DefaultHttpClient *http.Client ) // GetUpdateToolCallback returns a callback function // that updates given tool if given version is older than latest gh release and exits func GetUpdateToolCallback(toolName, version string) func() { + return GetUpdateToolFromRepoCallback(toolName, version, "") +} + +// GetUpdateToolWithRepoCallback returns a callback function that is similar to GetUpdateToolCallback +// but it takes repoName as an argument (repoName can be either just repoName ex: `nuclei` or full repo Addr ex: `projectdiscovery/nuclei`) +func GetUpdateToolFromRepoCallback(toolName, version, repoName string) func() { return func() { - gh := NewghReleaseDownloader(toolName) - latest, err := gh.GetLatestRelease() + if repoName == "" { + repoName = toolName + } + gh, err := NewghReleaseDownloader(repoName) if err != nil { - gologger.Fatal().Label("updater").Msgf("failed to fetch latest release of %v", toolName) + gologger.Fatal().Label("updater").Msgf("failed to download latest release got %v", err) } - latestVersion, err := semver.NewVersion(latest.GetTagName()) + gh.SetToolName(toolName) + latestVersion, err := semver.NewVersion(gh.Latest.GetTagName()) if err != nil { - gologger.Fatal().Label("updater").Msgf("failed to parse semversion from tagname `%v` got %v", latest.GetTagName(), err) + gologger.Fatal().Label("updater").Msgf("failed to parse semversion from tagname `%v` got %v", gh.Latest.GetTagName(), err) } currentVersion, err := semver.NewVersion(version) if err != nil { gologger.Fatal().Label("updater").Msgf("failed to parse semversion from current version %v got %v", version, err) } - if !latestVersion.GreaterThan(currentVersion) { + // check if current version is outdated + if !IsOutdated(currentVersion.String(), latestVersion.String()) { gologger.Info().Msgf("%v is already updated to latest version", toolName) os.Exit(0) } @@ -62,17 +71,13 @@ func GetUpdateToolCallback(toolName, version string) func() { if err := updateOpts.CheckPermissions(); err != nil { gologger.Fatal().Label("updater").Msgf("update of %v %v -> %v failed , insufficient permission detected got: %v", toolName, currentVersion.String(), latestVersion.String(), err) } - - if err := gh.GetAssetIDFromRelease(latest); err != nil { - gologger.Fatal().Label("updater").Msgf("failed to find release of %v for platform %v %v got : %v", toolName, runtime.GOOS, runtime.GOARCH, err) - } bin, err := gh.GetExecutableFromAsset() if err != nil { gologger.Fatal().Label("updater").Msgf("executable %v not found in release asset `%v` got: %v", toolName, gh.AssetID, err) } if err = selfupdate.Apply(bytes.NewBuffer(bin), updateOpts); err != nil { - log.Printf("Error] update of %v %v -> %v failed, rolling back update", toolName, currentVersion.String(), latestVersion.String()) + gologger.Error().Msgf("update of %v %v -> %v failed, rolling back update", toolName, currentVersion.String(), latestVersion.String()) if err := selfupdate.RollbackError(err); err != nil { gologger.Fatal().Label("updater").Msgf("rollback of update of %v failed got %v,pls reinstall %v", toolName, err, toolName) } @@ -83,8 +88,13 @@ func GetUpdateToolCallback(toolName, version string) func() { gologger.Info().Msgf("%v sucessfully updated %v -> %v (latest)", toolName, currentVersion.String(), latestVersion.String()) if !HideReleaseNotes { - output := latest.GetBody() - if rendered, err := glamour.Render(output, "dark"); err == nil { + output := gh.Latest.GetBody() + // adjust colors for both dark / light terminal themes + r, err := glamour.NewTermRenderer(glamour.WithAutoStyle()) + if err != nil { + gologger.Error().Msgf("markdown rendering not supported: %v", err) + } + if rendered, err := r.Render(output); err == nil { output = rendered } else { gologger.Error().Msg(err.Error()) @@ -95,11 +105,12 @@ func GetUpdateToolCallback(toolName, version string) func() { } } -// GetVersionCheckCallback retuns a callback function -// that returns latest version of tool -func GetVersionCheckCallback(toolName string) func() (string, error) { +// GetToolVersionCallback returns a callback function that checks for updates of tool +// by sending a request to update check endpoint and returns latest version +// if repoName is empty then tool name is considered as repoName +func GetToolVersionCallback(toolName, version string) func() (string, error) { return func() (string, error) { - updateURL := fmt.Sprintf(UpdateCheckEndpoint, toolName) + "?" + getMetaParams() + updateURL := fmt.Sprintf(UpdateCheckEndpoint, toolName) + "?" + getpdtmParams(version) if DefaultHttpClient == nil { // not needed but as a precaution to avoid nil panics DefaultHttpClient = http.DefaultClient @@ -130,15 +141,21 @@ func GetVersionCheckCallback(toolName string) func() (string, error) { } } -// getMetaParams returns encoded query parameters sent to update check endpoint -func getMetaParams() string { +// getpdtmParams returns encoded query parameters sent to update check endpoint +func getpdtmParams(version string) string { params := &url.Values{} params.Add("os", runtime.GOOS) params.Add("arch", runtime.GOARCH) params.Add("go_version", runtime.Version()) + params.Add("v", version) return params.Encode() } +// Deprecated: use GetToolVersionCheckCallback instead +func GetVersionCheckCallback(toolName string) func() (string, error) { + return GetToolVersionCallback(toolName, "") +} + func init() { DefaultHttpClient = &http.Client{ Timeout: VersionCheckTimeout, diff --git a/update/utils_test.go b/update/utils_test.go new file mode 100644 index 0000000..5da1c0f --- /dev/null +++ b/update/utils_test.go @@ -0,0 +1,47 @@ +package updateutils + +import ( + "testing" + + "github.com/logrusorgru/aurora" +) + +func TestGetVersionDescription(t *testing.T) { + Aurora = aurora.NewAurora(false) + tests := []struct { + current string + latest string + want string + }{ + { + current: "v2.9.1-dev", + latest: "v2.9.1", + want: "(outdated)", + }, + { + current: "v2.9.1-dev", + latest: "v2.9.2", + want: "(outdated)", + }, + { + current: "v2.9.1-dev", + latest: "v2.9.0", + want: "(development)", + }, + { + current: "v2.9.1", + latest: "v2.9.1", + want: "(latest)", + }, + { + current: "v2.9.1", + latest: "v2.9.2", + want: "(outdated)", + }, + } + for _, test := range tests { + if GetVersionDescription(test.current, test.latest) != test.want { + t.Errorf("GetVersionDescription(%v, %v) = %v, want %v", test.current, test.latest, GetVersionDescription(test.current, test.latest), test.want) + } + } +}