Skip to content

Commit

Permalink
Fix/5163 retry download upgrade verifiers (#6276)
Browse files Browse the repository at this point in the history
* fix(5163): added context to the verify interface, updated relevant calls, added backoff roundtripper

* fix(5163): added backoff roundtripper, added tests for asc file download

* fix(5163): remove log

* fix(5163): ran mage fmt

* fix(5163): removed unused roundtripper

* fix(5163): closing response body on unsuccesful responses

* fix(5163): added changelog

* fix(5163): added comment explaining the backoff function, added nolint comment

* fix(5163): added logic to handle resetting request body in the retry roundtripper, added unit tests, updated comments

* fix(5163): fixed linting errors

* fix(5164): ran mage check
  • Loading branch information
kaanyalti authored Dec 18, 2024
1 parent 0e68ce1 commit a5eb77b
Show file tree
Hide file tree
Showing 14 changed files with 289 additions and 47 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Kind can be one of:
# - breaking-change: a change to previously-documented behavior
# - deprecation: functionality that is being removed in a later release
# - bug-fix: fixes a problem in a previous version
# - enhancement: extends functionality but does not break or fix existing behavior
# - feature: new functionality
# - known-issue: problems that we are aware of in a given version
# - security: impacts on the security of a product or a user’s deployment.
# - upgrade: important information for someone upgrading from a prior version
# - other: does not fit into any of the other categories
kind: bug-fix

# Change summary; a 80ish characters long description of the change.
summary: added retries for requesting download verifiers when upgrading the agent

# Long description; in case the summary is not enough to describe the change
# this field accommodate a description without length limits.
# NOTE: This field will be rendered only for breaking-change and known-issue kinds at the moment.
#description:

# Affected component; usually one of "elastic-agent", "fleet-server", "filebeat", "metricbeat", "auditbeat", "all", etc.
component: "elastic-agent"
# PR URL; optional; the PR number that added the changeset.
# If not present is automatically filled by the tooling finding the PR where this changelog fragment has been added.
# NOTE: the tooling supports backports, so it's able to fill the original PR number instead of the backport PR number.
# Please provide it if you are adding a fragment for a different PR.
pr: https://github.com/elastic/elastic-agent/pull/6276
# Issue URL; optional; the GitHub issue related to this changeset (either closes or is part of).
# If not present is automatically filled by the tooling with the issue linked to the PR number.
#issue: https://github.com/owner/repo/1234
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package composed

import (
"context"
goerrors "errors"

"github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/artifact"
Expand Down Expand Up @@ -39,11 +40,11 @@ func NewVerifier(log *logger.Logger, verifiers ...download.Verifier) *Verifier {
}

// Verify checks the package from configured source.
func (v *Verifier) Verify(a artifact.Artifact, version agtversion.ParsedSemVer, skipDefaultPgp bool, pgpBytes ...string) error {
func (v *Verifier) Verify(ctx context.Context, a artifact.Artifact, version agtversion.ParsedSemVer, skipDefaultPgp bool, pgpBytes ...string) error {
var errs []error

for _, verifier := range v.vv {
e := verifier.Verify(a, version, skipDefaultPgp, pgpBytes...)
e := verifier.Verify(ctx, a, version, skipDefaultPgp, pgpBytes...)
if e == nil {
// Success
return nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package composed

import (
"context"
"errors"
"testing"

Expand All @@ -24,7 +25,7 @@ func (d *ErrorVerifier) Name() string {
return "error"
}

func (d *ErrorVerifier) Verify(artifact.Artifact, agtversion.ParsedSemVer, bool, ...string) error {
func (d *ErrorVerifier) Verify(context.Context, artifact.Artifact, agtversion.ParsedSemVer, bool, ...string) error {
d.called = true
return errors.New("failing")
}
Expand All @@ -39,7 +40,7 @@ func (d *FailVerifier) Name() string {
return "fail"
}

func (d *FailVerifier) Verify(artifact.Artifact, agtversion.ParsedSemVer, bool, ...string) error {
func (d *FailVerifier) Verify(context.Context, artifact.Artifact, agtversion.ParsedSemVer, bool, ...string) error {
d.called = true
return &download.InvalidSignatureError{File: "", Err: errors.New("invalid signature")}
}
Expand All @@ -54,7 +55,7 @@ func (d *SuccVerifier) Name() string {
return "succ"
}

func (d *SuccVerifier) Verify(artifact.Artifact, agtversion.ParsedSemVer, bool, ...string) error {
func (d *SuccVerifier) Verify(context.Context, artifact.Artifact, agtversion.ParsedSemVer, bool, ...string) error {
d.called = true
return nil
}
Expand Down Expand Up @@ -90,7 +91,7 @@ func TestVerifier(t *testing.T) {
testVersion := agtversion.NewParsedSemVer(1, 2, 3, "", "")
for _, tc := range testCases {
d := NewVerifier(log, tc.verifiers[0], tc.verifiers[1], tc.verifiers[2])
err := d.Verify(artifact.Artifact{Name: "a", Cmd: "a", Artifact: "a/a"}, *testVersion, false)
err := d.Verify(context.Background(), artifact.Artifact{Name: "a", Cmd: "a", Artifact: "a/a"}, *testVersion, false)

assert.Equal(t, tc.expectedResult, err == nil)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package fs

import (
"context"
"fmt"
"net/http"
"os"
Expand Down Expand Up @@ -65,7 +66,7 @@ func NewVerifier(log *logger.Logger, config *artifact.Config, pgp []byte) (*Veri

// Verify checks downloaded package on preconfigured
// location against a key stored on elastic.co website.
func (v *Verifier) Verify(a artifact.Artifact, version agtversion.ParsedSemVer, skipDefaultPgp bool, pgpBytes ...string) error {
func (v *Verifier) Verify(ctx context.Context, a artifact.Artifact, version agtversion.ParsedSemVer, skipDefaultPgp bool, pgpBytes ...string) error {
filename, err := artifact.GetArtifactName(a, version, v.config.OS(), v.config.Arch())
if err != nil {
return fmt.Errorf("could not get artifact name: %w", err)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,11 @@ import (

var testVersion = agtversion.NewParsedSemVer(7, 5, 1, "", "")

var (
agentSpec = artifact.Artifact{
Name: "Elastic Agent",
Cmd: "elastic-agent",
Artifact: "beat/elastic-agent"}
)
var agentSpec = artifact.Artifact{
Name: "Elastic Agent",
Cmd: "elastic-agent",
Artifact: "beat/elastic-agent",
}

func TestFetchVerify(t *testing.T) {
// See docs/pgp-sign-verify-artifact.md for how to generate a key, export
Expand All @@ -47,7 +46,8 @@ func TestFetchVerify(t *testing.T) {
targetPath := filepath.Join("testdata", "download")
ctx := context.Background()
a := artifact.Artifact{
Name: "elastic-agent", Cmd: "elastic-agent", Artifact: "beats/elastic-agent"}
Name: "elastic-agent", Cmd: "elastic-agent", Artifact: "beats/elastic-agent",
}
version := agtversion.NewParsedSemVer(8, 0, 0, "", "")

filename := "elastic-agent-8.0.0-darwin-x86_64.tar.gz"
Expand Down Expand Up @@ -80,7 +80,7 @@ func TestFetchVerify(t *testing.T) {
// first download verify should fail:
// download skipped, as invalid package is prepared upfront
// verify fails and cleans download
err = verifier.Verify(a, *version, false)
err = verifier.Verify(ctx, a, *version, false)
var checksumErr *download.ChecksumMismatchError
require.ErrorAs(t, err, &checksumErr)

Expand Down Expand Up @@ -109,7 +109,7 @@ func TestFetchVerify(t *testing.T) {
_, err = os.Stat(ascTargetFilePath)
require.NoError(t, err)

err = verifier.Verify(a, *version, false)
err = verifier.Verify(ctx, a, *version, false)
require.NoError(t, err)

// Bad GPG public key.
Expand All @@ -126,7 +126,7 @@ func TestFetchVerify(t *testing.T) {

// Missing .asc file.
{
err = verifier.Verify(a, *version, false)
err = verifier.Verify(ctx, a, *version, false)
require.Error(t, err)

// Don't delete these files when GPG validation failure.
Expand All @@ -139,7 +139,7 @@ func TestFetchVerify(t *testing.T) {
err = os.WriteFile(targetFilePath+".asc", []byte("bad sig"), 0o600)
require.NoError(t, err)

err = verifier.Verify(a, *version, false)
err = verifier.Verify(ctx, a, *version, false)
var invalidSigErr *download.InvalidSignatureError
assert.ErrorAs(t, err, &invalidSigErr)

Expand All @@ -157,7 +157,8 @@ func prepareFetchVerifyTests(
targetDir,
filename,
targetFilePath,
hashTargetFilePath string) error {
hashTargetFilePath string,
) error {
sourceFilePath := filepath.Join(dropPath, filename)
hashSourceFilePath := filepath.Join(dropPath, filename+".sha512")

Expand Down Expand Up @@ -202,6 +203,7 @@ func TestVerify(t *testing.T) {

for _, tc := range tt {
t.Run(tc.Name, func(t *testing.T) {
ctx := context.Background()
log, obs := loggertest.New("TestVerify")
targetDir := t.TempDir()

Expand All @@ -220,7 +222,7 @@ func TestVerify(t *testing.T) {
pgpKey := prepareTestCase(t, agentSpec, testVersion, config)

testClient := NewDownloader(config)
artifactPath, err := testClient.Download(context.Background(), agentSpec, testVersion)
artifactPath, err := testClient.Download(ctx, agentSpec, testVersion)
require.NoError(t, err, "fs.Downloader could not download artifacts")
_, err = testClient.DownloadAsc(context.Background(), agentSpec, *testVersion)
require.NoError(t, err, "fs.Downloader could not download artifacts .asc file")
Expand All @@ -231,7 +233,7 @@ func TestVerify(t *testing.T) {
testVerifier, err := NewVerifier(log, config, pgpKey)
require.NoError(t, err)

err = testVerifier.Verify(agentSpec, *testVersion, false, tc.RemotePGPUris...)
err = testVerifier.Verify(ctx, agentSpec, *testVersion, false, tc.RemotePGPUris...)
require.NoError(t, err)

// log message informing remote PGP was skipped
Expand All @@ -246,7 +248,6 @@ func TestVerify(t *testing.T) {
// It creates the necessary key to sing the artifact and returns the public key
// to verify the signature.
func prepareTestCase(t *testing.T, a artifact.Artifact, version *agtversion.ParsedSemVer, cfg *artifact.Config) []byte {

filename, err := artifact.GetArtifactName(a, *version, cfg.OperatingSystem, cfg.Architecture)
require.NoErrorf(t, err, "could not get artifact name")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,28 @@ func getTestCases() []testCase {
}
}

func getElasticCoServer(t *testing.T) (*httptest.Server, []byte) {
type extResCode map[string]struct {
resCode int
count int
}

type testDials struct {
extResCode
}

func (td *testDials) withExtResCode(k string, statusCode int, count int) {
td.extResCode[k] = struct {
resCode int
count int
}{statusCode, count}
}

func (td *testDials) reset() {
*td = testDials{extResCode: make(extResCode)}
}

func getElasticCoServer(t *testing.T) (*httptest.Server, []byte, *testDials) {
td := testDials{extResCode: make(extResCode)}
correctValues := map[string]struct{}{
fmt.Sprintf("%s-%s-%s", beatSpec.Cmd, version, "i386.deb"): {},
fmt.Sprintf("%s-%s-%s", beatSpec.Cmd, version, "amd64.deb"): {},
Expand All @@ -81,7 +102,6 @@ func getElasticCoServer(t *testing.T) (*httptest.Server, []byte) {
ext = ".tar.gz"
}
packageName = strings.TrimSuffix(packageName, ext)

switch ext {
case ".sha512":
resp = []byte(fmt.Sprintf("%x %s", hash, packageName))
Expand All @@ -103,11 +123,17 @@ func getElasticCoServer(t *testing.T) (*httptest.Server, []byte) {
return
}

if v, ok := td.extResCode[ext]; ok && v.count != 0 {
w.WriteHeader(v.resCode)
v.count--
td.extResCode[ext] = v
}

_, err := w.Write(resp)
assert.NoErrorf(t, err, "mock elastic.co server: failes writing response")
})

return httptest.NewServer(handler), pub
return httptest.NewServer(handler), pub, &td
}

func getElasticCoClient(server *httptest.Server) http.Client {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func TestDownload(t *testing.T) {
log, _ := logger.New("", false)
timeout := 30 * time.Second
testCases := getTestCases()
server, _ := getElasticCoServer(t)
server, _, _ := getElasticCoServer(t)
elasticClient := getElasticCoClient(server)

config := &artifact.Config{
Expand Down Expand Up @@ -359,7 +359,6 @@ type downloadHttpResponse struct {
}

func TestDownloadVersion(t *testing.T) {

type fields struct {
config *artifact.Config
}
Expand Down Expand Up @@ -485,7 +484,6 @@ func TestDownloadVersion(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

targetDirPath := t.TempDir()

handleDownload := func(rw http.ResponseWriter, req *http.Request) {
Expand Down Expand Up @@ -527,5 +525,4 @@ func TestDownloadVersion(t *testing.T) {
assert.Equalf(t, filepath.Join(targetDirPath, tt.want), got, "Download(%v, %v)", tt.args.a, tt.args.version)
})
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ func NewVerifier(log *logger.Logger, config *artifact.Config, pgp []byte) (*Veri
httpcommon.WithModRoundtripper(func(rt http.RoundTripper) http.RoundTripper {
return download.WithHeaders(rt, download.Headers)
}),
httpcommon.WithModRoundtripper(func(rt http.RoundTripper) http.RoundTripper {
return WithBackoff(rt, log)
}),
)
if err != nil {
return nil, err
Expand Down Expand Up @@ -88,7 +91,7 @@ func (v *Verifier) Reload(c *artifact.Config) error {

// Verify checks downloaded package on preconfigured
// location against a key stored on elastic.co website.
func (v *Verifier) Verify(a artifact.Artifact, version agtversion.ParsedSemVer, skipDefaultPgp bool, pgpBytes ...string) error {
func (v *Verifier) Verify(ctx context.Context, a artifact.Artifact, version agtversion.ParsedSemVer, skipDefaultPgp bool, pgpBytes ...string) error {
artifactPath, err := artifact.GetArtifactPath(a, version, v.config.OS(), v.config.Arch(), v.config.TargetDirectory)
if err != nil {
return errors.New(err, "retrieving package path")
Expand All @@ -98,7 +101,7 @@ func (v *Verifier) Verify(a artifact.Artifact, version agtversion.ParsedSemVer,
return fmt.Errorf("failed to verify SHA512 hash: %w", err)
}

if err = v.verifyAsc(a, version, skipDefaultPgp, pgpBytes...); err != nil {
if err = v.verifyAsc(ctx, a, version, skipDefaultPgp, pgpBytes...); err != nil {
var invalidSignatureErr *download.InvalidSignatureError
if errors.As(err, &invalidSignatureErr) {
if err := os.Remove(artifactPath); err != nil {
Expand All @@ -116,7 +119,7 @@ func (v *Verifier) Verify(a artifact.Artifact, version agtversion.ParsedSemVer,
return nil
}

func (v *Verifier) verifyAsc(a artifact.Artifact, version agtversion.ParsedSemVer, skipDefaultKey bool, pgpSources ...string) error {
func (v *Verifier) verifyAsc(ctx context.Context, a artifact.Artifact, version agtversion.ParsedSemVer, skipDefaultKey bool, pgpSources ...string) error {
filename, err := artifact.GetArtifactName(a, version, v.config.OS(), v.config.Arch())
if err != nil {
return errors.New(err, "retrieving package name")
Expand All @@ -132,7 +135,7 @@ func (v *Verifier) verifyAsc(a artifact.Artifact, version agtversion.ParsedSemVe
return errors.New(err, "composing URI for fetching asc file", errors.TypeNetwork)
}

ascBytes, err := v.getPublicAsc(ascURI)
ascBytes, err := v.getPublicAsc(ctx, ascURI)
if err != nil {
return errors.New(err, fmt.Sprintf("fetching asc file from %s", ascURI), errors.TypeNetwork, errors.M(errors.MetaKeyURI, ascURI))
}
Expand Down Expand Up @@ -163,8 +166,8 @@ func (v *Verifier) composeURI(filename, artifactName string) (string, error) {
return uri.String(), nil
}

func (v *Verifier) getPublicAsc(sourceURI string) ([]byte, error) {
ctx, cancelFn := context.WithTimeout(context.Background(), 30*time.Second)
func (v *Verifier) getPublicAsc(ctx context.Context, sourceURI string) ([]byte, error) {
ctx, cancelFn := context.WithTimeout(ctx, 30*time.Second)
defer cancelFn()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, sourceURI, nil)
if err != nil {
Expand Down
Loading

0 comments on commit a5eb77b

Please sign in to comment.