Skip to content

Commit

Permalink
fix: Progressbar handling (#336)
Browse files Browse the repository at this point in the history
Co-authored-by: Kemal Hadimli <disq@users.noreply.github.com>
  • Loading branch information
disq and disq authored Jun 6, 2024
1 parent 709cff1 commit 538487d
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 20 deletions.
10 changes: 9 additions & 1 deletion managedplugin/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func isDockerImageAvailable(ctx context.Context, imageName string) (bool, error)
return len(images) > 0, nil
}

func pullDockerImage(ctx context.Context, imageName string, authToken string, teamName string, dockerHubAuth string) error {
func pullDockerImage(ctx context.Context, imageName string, authToken string, teamName string, dockerHubAuth string, dops DownloaderOptions) error {
// Pull the image
additionalHeaders := make(map[string]string)
opts := image.PullOptions{}
Expand Down Expand Up @@ -143,6 +143,14 @@ func pullDockerImage(ctx context.Context, imageName string, authToken string, te
}
defer out.Close()

if dops.NoProgress {
_, err = io.Copy(io.Discard, out)
if err != nil {
return fmt.Errorf("failed to copy image pull output: %v", err)
}
return nil
}

// Create a progress reader to display the download progress
pr := &dockerProgressReader{
decoder: json.NewDecoder(out),
Expand Down
31 changes: 18 additions & 13 deletions managedplugin/download.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,11 @@ type HubDownloadOptions struct {
PluginName string
PluginVersion string
}
type DownloaderOptions struct {
NoProgress bool
}

func DownloadPluginFromHub(ctx context.Context, c *cloudquery_api.ClientWithResponses, ops HubDownloadOptions) error {
func DownloadPluginFromHub(ctx context.Context, c *cloudquery_api.ClientWithResponses, ops HubDownloadOptions, dops DownloaderOptions) error {
downloadDir := filepath.Dir(ops.LocalPath)
if _, err := os.Stat(ops.LocalPath); err == nil {
return nil
Expand Down Expand Up @@ -162,7 +165,7 @@ func DownloadPluginFromHub(ctx context.Context, c *cloudquery_api.ClientWithResp
return fmt.Errorf("failed to get plugin metadata from hub: empty location from response")
}
pluginZipPath := ops.LocalPath + ".zip"
writtenChecksum, err := downloadFile(ctx, pluginZipPath, location)
writtenChecksum, err := downloadFile(ctx, pluginZipPath, location, dops)
if err != nil {
return fmt.Errorf("failed to download plugin: %w", err)
}
Expand Down Expand Up @@ -236,7 +239,7 @@ func downloadPluginAssetFromHub(ctx context.Context, c *cloudquery_api.ClientWit
}
}

func DownloadPluginFromGithub(ctx context.Context, logger zerolog.Logger, localPath string, org string, name string, version string, typ PluginType) error {
func DownloadPluginFromGithub(ctx context.Context, logger zerolog.Logger, localPath string, org string, name string, version string, typ PluginType, dops DownloaderOptions) error {
downloadDir := filepath.Dir(localPath)
pluginZipPath := localPath + ".zip"

Expand All @@ -253,7 +256,7 @@ func DownloadPluginFromGithub(ctx context.Context, logger zerolog.Logger, localP
return fmt.Errorf("failed to get plugin url: %w", err)
}
logger.Debug().Msg(fmt.Sprintf("Downloading %s", downloadURL))
if _, err := downloadFile(ctx, pluginZipPath, downloadURL); err != nil {
if _, err := downloadFile(ctx, pluginZipPath, downloadURL, dops); err != nil {
return fmt.Errorf("failed to download plugin: %w", err)
}

Expand Down Expand Up @@ -301,20 +304,16 @@ func DownloadPluginFromGithub(ctx context.Context, logger zerolog.Logger, localP
return nil
}

func downloadFile(ctx context.Context, localPath string, downloadURL string) (string, error) {
func downloadFile(ctx context.Context, localPath string, downloadURL string, dops DownloaderOptions) (string, error) {
// Create the file
out, err := os.Create(localPath)
if err != nil {
return "", fmt.Errorf("failed to create file %s: %w", localPath, err)
}
defer out.Close()

return downloadFileFromURL(ctx, out, downloadURL)
}

func downloadFileFromURL(ctx context.Context, out *os.File, downloadURL string) (string, error) {
checksum := ""
err := retry.Do(func() error {
err = retry.Do(func() error {
checksum = ""
// Get the data
req, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadURL, nil)
Expand Down Expand Up @@ -344,11 +343,17 @@ func downloadFileFromURL(ctx context.Context, out *os.File, downloadURL string)
urlForLog = parsedURL.String()
}
fmt.Printf("Downloading %s\n", urlForLog)
bar := downloadProgressBar(resp.ContentLength, "Downloading")

s := sha256.New()
// Writer the body to file
_, err = io.Copy(io.MultiWriter(out, bar, s), resp.Body)
writers := []io.Writer{out, s}

if !dops.NoProgress {
bar := downloadProgressBar(resp.ContentLength, "Downloading")
writers = append(writers, bar)
}

// Write the body to file
_, err = io.Copy(io.MultiWriter(writers...), resp.Body)
if err != nil {
return fmt.Errorf("failed to copy body to file %s: %w", out.Name(), err)
}
Expand Down
6 changes: 4 additions & 2 deletions managedplugin/download_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func TestDownloadPluginFromGithubIntegration(t *testing.T) {
logger := zerolog.Logger{}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
err := DownloadPluginFromGithub(context.Background(), logger, path.Join(tmp, tc.name), tc.org, tc.plugin, tc.version, tc.typ)
err := DownloadPluginFromGithub(context.Background(), logger, path.Join(tmp, tc.name), tc.org, tc.plugin, tc.version, tc.typ, DownloaderOptions{})
if (err != nil) != tc.wantErr {
t.Errorf("DownloadPluginFromGithub() error = %v, wantErr %v", err, tc.wantErr)
return
Expand Down Expand Up @@ -64,7 +64,9 @@ func TestDownloadPluginFromCloudQueryHub(t *testing.T) {
PluginKind: tc.typ.String(),
PluginName: tc.plugin,
PluginVersion: tc.version,
})
},
DownloaderOptions{},
)
if (err != nil) != tc.wantErr {
t.Errorf("TestDownloadPluginFromCloudQueryIntegration() error = %v, wantErr %v", err, tc.wantErr)
return
Expand Down
6 changes: 6 additions & 0 deletions managedplugin/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ func WithNoExec() Option {
}
}

func WithNoProgress() Option {
return func(c *Client) {
c.noProgress = true
}
}

func WithOtelEndpoint(endpoint string) Option {
return func(c *Client) {
c.otelEndpoint = endpoint
Expand Down
12 changes: 8 additions & 4 deletions managedplugin/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ type Client struct {
config Config
noSentry bool
noExec bool
noProgress bool
cqDockerHost string
otelEndpoint string
otelEndpointInsecure bool
Expand Down Expand Up @@ -163,6 +164,9 @@ func NewClient(ctx context.Context, typ PluginType, config Config, opts ...Optio
}

func (c *Client) downloadPlugin(ctx context.Context, typ PluginType) error {
dops := DownloaderOptions{
NoProgress: c.noProgress,
}
switch c.config.Registry {
case RegistryGrpc:
return nil // GRPC plugins are not downloaded
Expand All @@ -176,12 +180,12 @@ func (c *Client) downloadPlugin(ctx context.Context, typ PluginType) error {
org, name := pathSplit[0], pathSplit[1]
c.LocalPath = filepath.Join(c.directory, "plugins", typ.String(), org, name, c.config.Version, "plugin")
c.LocalPath = WithBinarySuffix(c.LocalPath)
return DownloadPluginFromGithub(ctx, c.logger, c.LocalPath, org, name, c.config.Version, typ)
return DownloadPluginFromGithub(ctx, c.logger, c.LocalPath, org, name, c.config.Version, typ, dops)
case RegistryDocker:
if imageAvailable, err := isDockerImageAvailable(ctx, c.config.Path); err != nil {
return err
} else if !imageAvailable {
return pullDockerImage(ctx, c.config.Path, c.authToken, c.teamName, c.dockerAuth)
return pullDockerImage(ctx, c.config.Path, c.authToken, c.teamName, c.dockerAuth, dops)
}
return nil
case RegistryCloudQuery:
Expand Down Expand Up @@ -217,11 +221,11 @@ func (c *Client) downloadPlugin(ctx context.Context, typ PluginType) error {
if imageAvailable, err := isDockerImageAvailable(ctx, path); err != nil {
return err
} else if !imageAvailable {
return pullDockerImage(ctx, path, c.authToken, c.teamName, "")
return pullDockerImage(ctx, path, c.authToken, c.teamName, "", dops)
}
return nil
}
return DownloadPluginFromHub(ctx, hubClient, ops)
return DownloadPluginFromHub(ctx, hubClient, ops, dops)
default:
return fmt.Errorf("unknown registry %s", c.config.Registry.String())
}
Expand Down

0 comments on commit 538487d

Please sign in to comment.