diff --git a/changelog/fragments/1695685534-standalone-prevent-quick-upgrades.yaml b/changelog/fragments/1695685534-standalone-prevent-quick-upgrades.yaml new file mode 100644 index 00000000000..fffa59953be --- /dev/null +++ b/changelog/fragments/1695685534-standalone-prevent-quick-upgrades.yaml @@ -0,0 +1,32 @@ +# Kind can be one of: +# - breaking-change: a change to previously-documented behavior +# - deprecation: functionality that is being removed in a later release +# - bug-fix: fixes a problem in a previous version +# - enhancement: extends functionality but does not break or fix existing behavior +# - feature: new functionality +# - known-issue: problems that we are aware of in a given version +# - security: impacts on the security of a product or a user’s deployment. +# - upgrade: important information for someone upgrading from a prior version +# - other: does not fit into any of the other categories +kind: bug-fix + +# Change summary; an 80ish characters long description of the change. +summary: Prevent a standalone Elastic Agent from being upgraded if an upgrade is already in progress. + +# Long description; in case the summary is not enough to describe the change +# this field accommodates a description without length limits. +# NOTE: This field will be rendered only for breaking-change and known-issue kinds at the moment. +#description: + +# Affected component; usually one of "elastic-agent", "fleet-server", "filebeat", "metricbeat", "auditbeat", "all", etc. +component: elastic-agent + +# PR URL; optional; the PR number that added the changeset. +# If not present is automatically filled by the tooling finding the PR where this changelog fragment has been added. +# NOTE: the tooling supports backports, so it's able to fill the original PR number instead of the backport PR number. +# Please provide it if you are adding a fragment for a different PR. 
+pr: https://github.com/elastic/elastic-agent/pull/3473 + +# Issue URL; optional; the GitHub issue related to this changeset (either closes or is part of). +# If not present is automatically filled by the tooling with the issue linked to the PR number. +issue: https://github.com/elastic/elastic-agent/issues/2706 diff --git a/internal/pkg/agent/application/upgrade/upgrade.go b/internal/pkg/agent/application/upgrade/upgrade.go index e653cf54525..d03a55fb336 100644 --- a/internal/pkg/agent/application/upgrade/upgrade.go +++ b/internal/pkg/agent/application/upgrade/upgrade.go @@ -11,6 +11,9 @@ import ( "path/filepath" "strings" + "github.com/elastic/elastic-agent/pkg/control/v2/client" + "github.com/elastic/elastic-agent/pkg/control/v2/cproto" + "github.com/otiai10/copy" "go.elastic.co/apm" @@ -366,3 +369,28 @@ func copyDir(l *logger.Logger, from, to string, ignoreErrs bool) error { OnError: onErr, }) } + +// IsInProgress checks if an Elastic Agent upgrade is already in progress. It +// returns true if so and false if not. +// `c client.Client` is expected to be a connected client. +func IsInProgress(c client.Client, watcherPIDsFetcher func() ([]int, error)) (bool, error) { + // First we check if any Upgrade Watcher processes are running. If they are, + // it means an upgrade is in progress. We check this before checking the Elastic + // Agent's status because the Elastic Agent GRPC server may briefly be + // unavailable during an upgrade and so the client connection might fail. + watcherPIDs, err := watcherPIDsFetcher() + if err != nil { + return false, fmt.Errorf("failed to determine if upgrade watcher is running: %w", err) + } + if len(watcherPIDs) > 0 { + return true, nil + } + + // Next we check the Elastic Agent's status using the GRPC client. 
+ state, err := c.State(context.Background()) + if err != nil { + return false, fmt.Errorf("failed to get agent state: %w", err) + } + + return state.State == cproto.State_UPGRADING, nil +} diff --git a/internal/pkg/agent/application/upgrade/upgrade_test.go b/internal/pkg/agent/application/upgrade/upgrade_test.go index 84a2977e84b..edbf1dd9b6a 100644 --- a/internal/pkg/agent/application/upgrade/upgrade_test.go +++ b/internal/pkg/agent/application/upgrade/upgrade_test.go @@ -5,6 +5,7 @@ package upgrade import ( + "context" "fmt" "io/ioutil" "os" @@ -13,9 +14,14 @@ import ( "strings" "testing" + "github.com/elastic/elastic-agent/pkg/control/v2/client" + "github.com/elastic/elastic-agent/pkg/control/v2/client/mocks" + "github.com/elastic/elastic-agent/pkg/control/v2/cproto" + "github.com/gofrs/flock" "github.com/stretchr/testify/require" + "github.com/elastic/elastic-agent/internal/pkg/agent/errors" "github.com/elastic/elastic-agent/internal/pkg/release" "github.com/elastic/elastic-agent/pkg/core/logger" ) @@ -134,3 +140,72 @@ func TestShutdownCallback(t *testing.T) { require.NoError(t, err, "reading file failed") require.Equal(t, content, newContent, "contents are not equal") } + +func TestIsInProgress(t *testing.T) { + tests := map[string]struct { + state cproto.State + stateErr string + watcherPIDsFetcher func() ([]int, error) + + expected bool + expectedErr string + }{ + "state_error": { + state: cproto.State_STARTING, + stateErr: "some error", + watcherPIDsFetcher: func() ([]int, error) { return nil, nil }, + + expected: false, + expectedErr: "failed to get agent state: some error", + }, + "state_upgrading": { + state: cproto.State_UPGRADING, + stateErr: "", + watcherPIDsFetcher: func() ([]int, error) { return nil, nil }, + + expected: true, + expectedErr: "", + }, + "state_healthy_no_watcher": { + state: cproto.State_HEALTHY, + stateErr: "", + watcherPIDsFetcher: func() ([]int, error) { return []int{}, nil }, + + expected: false, + expectedErr: "", + }, + 
"state_healthy_with_watcher": { + state: cproto.State_HEALTHY, + stateErr: "", + watcherPIDsFetcher: func() ([]int, error) { return []int{9999}, nil }, + + expected: true, + expectedErr: "", + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + // Expect client.State() call to be made only if no Upgrade Watcher PIDs + // are returned (i.e. no Upgrade Watcher is found to be running). + mc := mocks.NewClient(t) + if test.watcherPIDsFetcher != nil { + pids, _ := test.watcherPIDsFetcher() + if len(pids) == 0 { + if test.stateErr != "" { + mc.EXPECT().State(context.Background()).Return(nil, errors.New(test.stateErr)).Once() + } else { + mc.EXPECT().State(context.Background()).Return(&client.AgentState{State: test.state}, nil).Once() + } + } + } + + inProgress, err := IsInProgress(mc, test.watcherPIDsFetcher) + if test.expectedErr != "" { + require.Equal(t, test.expectedErr, err.Error()) + } else { + require.Equal(t, test.expected, inProgress) + } + }) + } +} diff --git a/internal/pkg/agent/cmd/upgrade.go b/internal/pkg/agent/cmd/upgrade.go index 63088d6df17..767560f873f 100644 --- a/internal/pkg/agent/cmd/upgrade.go +++ b/internal/pkg/agent/cmd/upgrade.go @@ -10,11 +10,13 @@ import ( "io/ioutil" "os" + "github.com/spf13/cobra" + "github.com/elastic/elastic-agent/pkg/control" "github.com/elastic/elastic-agent/pkg/control/v2/client" + "github.com/elastic/elastic-agent/pkg/utils" - "github.com/spf13/cobra" - + "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade" "github.com/elastic/elastic-agent/internal/pkg/agent/application/upgrade/artifact/download" "github.com/elastic/elastic-agent/internal/pkg/agent/errors" "github.com/elastic/elastic-agent/internal/pkg/cli" @@ -64,6 +66,14 @@ func upgradeCmd(streams *cli.IOStreams, cmd *cobra.Command, args []string) error } defer c.Disconnect() + isBeingUpgraded, err := upgrade.IsInProgress(c, utils.GetWatcherPIDs) + if err != nil { + return fmt.Errorf("failed to check if upgrade is 
already in progress: %w", err) + } + if isBeingUpgraded { + return errors.New("an upgrade is already in progress; please try again later.") + } + skipVerification, _ := cmd.Flags().GetBool(flagSkipVerify) var pgpChecks []string if !skipVerification { diff --git a/pkg/control/v2/client/client.go b/pkg/control/v2/client/client.go index 2aea8f39d6b..c2f593440fc 100644 --- a/pkg/control/v2/client/client.go +++ b/pkg/control/v2/client/client.go @@ -156,6 +156,7 @@ type DiagnosticComponentResult struct { } // Client communicates to Elastic Agent through the control protocol. +//go:generate mockery --name Client type Client interface { // Connect connects to the running Elastic Agent. Connect(ctx context.Context) error diff --git a/pkg/control/v2/client/mocks/client.go b/pkg/control/v2/client/mocks/client.go new file mode 100644 index 00000000000..20d4a6a9b81 --- /dev/null +++ b/pkg/control/v2/client/mocks/client.go @@ -0,0 +1,629 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +// Code generated by mockery v2.24.0. DO NOT EDIT. 
+ +package mocks + +import ( + context "context" + + client "github.com/elastic/elastic-agent/pkg/control/v2/client" + + cproto "github.com/elastic/elastic-agent/pkg/control/v2/cproto" + + mock "github.com/stretchr/testify/mock" +) + +// Client is an autogenerated mock type for the Client type +type Client struct { + mock.Mock +} + +type Client_Expecter struct { + mock *mock.Mock +} + +func (_m *Client) EXPECT() *Client_Expecter { + return &Client_Expecter{mock: &_m.Mock} +} + +// Configure provides a mock function with given fields: ctx, config +func (_m *Client) Configure(ctx context.Context, config string) error { + ret := _m.Called(ctx, config) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, config) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Client_Configure_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Configure' +type Client_Configure_Call struct { + *mock.Call +} + +// Configure is a helper method to define mock.On call +// - ctx context.Context +// - config string +func (_e *Client_Expecter) Configure(ctx interface{}, config interface{}) *Client_Configure_Call { + return &Client_Configure_Call{Call: _e.mock.On("Configure", ctx, config)} +} + +func (_c *Client_Configure_Call) Run(run func(ctx context.Context, config string)) *Client_Configure_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string)) + }) + return _c +} + +func (_c *Client_Configure_Call) Return(_a0 error) *Client_Configure_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Client_Configure_Call) RunAndReturn(run func(context.Context, string) error) *Client_Configure_Call { + _c.Call.Return(run) + return _c +} + +// Connect provides a mock function with given fields: ctx +func (_m *Client) Connect(ctx context.Context) error { + ret := _m.Called(ctx) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 
= rf(ctx) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Client_Connect_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Connect' +type Client_Connect_Call struct { + *mock.Call +} + +// Connect is a helper method to define mock.On call +// - ctx context.Context +func (_e *Client_Expecter) Connect(ctx interface{}) *Client_Connect_Call { + return &Client_Connect_Call{Call: _e.mock.On("Connect", ctx)} +} + +func (_c *Client_Connect_Call) Run(run func(ctx context.Context)) *Client_Connect_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *Client_Connect_Call) Return(_a0 error) *Client_Connect_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Client_Connect_Call) RunAndReturn(run func(context.Context) error) *Client_Connect_Call { + _c.Call.Return(run) + return _c +} + +// DiagnosticAgent provides a mock function with given fields: ctx, additionalDiags +func (_m *Client) DiagnosticAgent(ctx context.Context, additionalDiags []cproto.AdditionalDiagnosticRequest) ([]client.DiagnosticFileResult, error) { + ret := _m.Called(ctx, additionalDiags) + + var r0 []client.DiagnosticFileResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, []cproto.AdditionalDiagnosticRequest) ([]client.DiagnosticFileResult, error)); ok { + return rf(ctx, additionalDiags) + } + if rf, ok := ret.Get(0).(func(context.Context, []cproto.AdditionalDiagnosticRequest) []client.DiagnosticFileResult); ok { + r0 = rf(ctx, additionalDiags) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]client.DiagnosticFileResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, []cproto.AdditionalDiagnosticRequest) error); ok { + r1 = rf(ctx, additionalDiags) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Client_DiagnosticAgent_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DiagnosticAgent' +type 
Client_DiagnosticAgent_Call struct { + *mock.Call +} + +// DiagnosticAgent is a helper method to define mock.On call +// - ctx context.Context +// - additionalDiags []cproto.AdditionalDiagnosticRequest +func (_e *Client_Expecter) DiagnosticAgent(ctx interface{}, additionalDiags interface{}) *Client_DiagnosticAgent_Call { + return &Client_DiagnosticAgent_Call{Call: _e.mock.On("DiagnosticAgent", ctx, additionalDiags)} +} + +func (_c *Client_DiagnosticAgent_Call) Run(run func(ctx context.Context, additionalDiags []cproto.AdditionalDiagnosticRequest)) *Client_DiagnosticAgent_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].([]cproto.AdditionalDiagnosticRequest)) + }) + return _c +} + +func (_c *Client_DiagnosticAgent_Call) Return(_a0 []client.DiagnosticFileResult, _a1 error) *Client_DiagnosticAgent_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Client_DiagnosticAgent_Call) RunAndReturn(run func(context.Context, []cproto.AdditionalDiagnosticRequest) ([]client.DiagnosticFileResult, error)) *Client_DiagnosticAgent_Call { + _c.Call.Return(run) + return _c +} + +// DiagnosticComponents provides a mock function with given fields: ctx, additionalDiags, components +func (_m *Client) DiagnosticComponents(ctx context.Context, additionalDiags []cproto.AdditionalDiagnosticRequest, components ...client.DiagnosticComponentRequest) ([]client.DiagnosticComponentResult, error) { + _va := make([]interface{}, len(components)) + for _i := range components { + _va[_i] = components[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, additionalDiags) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 []client.DiagnosticComponentResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, []cproto.AdditionalDiagnosticRequest, ...client.DiagnosticComponentRequest) ([]client.DiagnosticComponentResult, error)); ok { + return rf(ctx, additionalDiags, components...) 
+ } + if rf, ok := ret.Get(0).(func(context.Context, []cproto.AdditionalDiagnosticRequest, ...client.DiagnosticComponentRequest) []client.DiagnosticComponentResult); ok { + r0 = rf(ctx, additionalDiags, components...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]client.DiagnosticComponentResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, []cproto.AdditionalDiagnosticRequest, ...client.DiagnosticComponentRequest) error); ok { + r1 = rf(ctx, additionalDiags, components...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Client_DiagnosticComponents_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DiagnosticComponents' +type Client_DiagnosticComponents_Call struct { + *mock.Call +} + +// DiagnosticComponents is a helper method to define mock.On call +// - ctx context.Context +// - additionalDiags []cproto.AdditionalDiagnosticRequest +// - components ...client.DiagnosticComponentRequest +func (_e *Client_Expecter) DiagnosticComponents(ctx interface{}, additionalDiags interface{}, components ...interface{}) *Client_DiagnosticComponents_Call { + return &Client_DiagnosticComponents_Call{Call: _e.mock.On("DiagnosticComponents", + append([]interface{}{ctx, additionalDiags}, components...)...)} +} + +func (_c *Client_DiagnosticComponents_Call) Run(run func(ctx context.Context, additionalDiags []cproto.AdditionalDiagnosticRequest, components ...client.DiagnosticComponentRequest)) *Client_DiagnosticComponents_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]client.DiagnosticComponentRequest, len(args)-2) + for i, a := range args[2:] { + if a != nil { + variadicArgs[i] = a.(client.DiagnosticComponentRequest) + } + } + run(args[0].(context.Context), args[1].([]cproto.AdditionalDiagnosticRequest), variadicArgs...) 
+ }) + return _c +} + +func (_c *Client_DiagnosticComponents_Call) Return(_a0 []client.DiagnosticComponentResult, _a1 error) *Client_DiagnosticComponents_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Client_DiagnosticComponents_Call) RunAndReturn(run func(context.Context, []cproto.AdditionalDiagnosticRequest, ...client.DiagnosticComponentRequest) ([]client.DiagnosticComponentResult, error)) *Client_DiagnosticComponents_Call { + _c.Call.Return(run) + return _c +} + +// DiagnosticUnits provides a mock function with given fields: ctx, units +func (_m *Client) DiagnosticUnits(ctx context.Context, units ...client.DiagnosticUnitRequest) ([]client.DiagnosticUnitResult, error) { + _va := make([]interface{}, len(units)) + for _i := range units { + _va[_i] = units[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 []client.DiagnosticUnitResult + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, ...client.DiagnosticUnitRequest) ([]client.DiagnosticUnitResult, error)); ok { + return rf(ctx, units...) + } + if rf, ok := ret.Get(0).(func(context.Context, ...client.DiagnosticUnitRequest) []client.DiagnosticUnitResult); ok { + r0 = rf(ctx, units...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]client.DiagnosticUnitResult) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, ...client.DiagnosticUnitRequest) error); ok { + r1 = rf(ctx, units...) 
+ } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Client_DiagnosticUnits_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DiagnosticUnits' +type Client_DiagnosticUnits_Call struct { + *mock.Call +} + +// DiagnosticUnits is a helper method to define mock.On call +// - ctx context.Context +// - units ...client.DiagnosticUnitRequest +func (_e *Client_Expecter) DiagnosticUnits(ctx interface{}, units ...interface{}) *Client_DiagnosticUnits_Call { + return &Client_DiagnosticUnits_Call{Call: _e.mock.On("DiagnosticUnits", + append([]interface{}{ctx}, units...)...)} +} + +func (_c *Client_DiagnosticUnits_Call) Run(run func(ctx context.Context, units ...client.DiagnosticUnitRequest)) *Client_DiagnosticUnits_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]client.DiagnosticUnitRequest, len(args)-1) + for i, a := range args[1:] { + if a != nil { + variadicArgs[i] = a.(client.DiagnosticUnitRequest) + } + } + run(args[0].(context.Context), variadicArgs...) 
+ }) + return _c +} + +func (_c *Client_DiagnosticUnits_Call) Return(_a0 []client.DiagnosticUnitResult, _a1 error) *Client_DiagnosticUnits_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Client_DiagnosticUnits_Call) RunAndReturn(run func(context.Context, ...client.DiagnosticUnitRequest) ([]client.DiagnosticUnitResult, error)) *Client_DiagnosticUnits_Call { + _c.Call.Return(run) + return _c +} + +// Disconnect provides a mock function with given fields: +func (_m *Client) Disconnect() { + _m.Called() +} + +// Client_Disconnect_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Disconnect' +type Client_Disconnect_Call struct { + *mock.Call +} + +// Disconnect is a helper method to define mock.On call +func (_e *Client_Expecter) Disconnect() *Client_Disconnect_Call { + return &Client_Disconnect_Call{Call: _e.mock.On("Disconnect")} +} + +func (_c *Client_Disconnect_Call) Run(run func()) *Client_Disconnect_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *Client_Disconnect_Call) Return() *Client_Disconnect_Call { + _c.Call.Return() + return _c +} + +func (_c *Client_Disconnect_Call) RunAndReturn(run func()) *Client_Disconnect_Call { + _c.Call.Return(run) + return _c +} + +// Restart provides a mock function with given fields: ctx +func (_m *Client) Restart(ctx context.Context) error { + ret := _m.Called(ctx) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context) error); ok { + r0 = rf(ctx) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Client_Restart_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Restart' +type Client_Restart_Call struct { + *mock.Call +} + +// Restart is a helper method to define mock.On call +// - ctx context.Context +func (_e *Client_Expecter) Restart(ctx interface{}) *Client_Restart_Call { + return &Client_Restart_Call{Call: _e.mock.On("Restart", ctx)} +} + +func (_c *Client_Restart_Call) 
Run(run func(ctx context.Context)) *Client_Restart_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *Client_Restart_Call) Return(_a0 error) *Client_Restart_Call { + _c.Call.Return(_a0) + return _c +} + +func (_c *Client_Restart_Call) RunAndReturn(run func(context.Context) error) *Client_Restart_Call { + _c.Call.Return(run) + return _c +} + +// State provides a mock function with given fields: ctx +func (_m *Client) State(ctx context.Context) (*client.AgentState, error) { + ret := _m.Called(ctx) + + var r0 *client.AgentState + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (*client.AgentState, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) *client.AgentState); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*client.AgentState) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Client_State_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'State' +type Client_State_Call struct { + *mock.Call +} + +// State is a helper method to define mock.On call +// - ctx context.Context +func (_e *Client_Expecter) State(ctx interface{}) *Client_State_Call { + return &Client_State_Call{Call: _e.mock.On("State", ctx)} +} + +func (_c *Client_State_Call) Run(run func(ctx context.Context)) *Client_State_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *Client_State_Call) Return(_a0 *client.AgentState, _a1 error) *Client_State_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Client_State_Call) RunAndReturn(run func(context.Context) (*client.AgentState, error)) *Client_State_Call { + _c.Call.Return(run) + return _c +} + +// StateWatch provides a mock function with given fields: ctx +func (_m *Client) StateWatch(ctx context.Context) 
(client.ClientStateWatch, error) { + ret := _m.Called(ctx) + + var r0 client.ClientStateWatch + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (client.ClientStateWatch, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) client.ClientStateWatch); ok { + r0 = rf(ctx) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(client.ClientStateWatch) + } + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Client_StateWatch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'StateWatch' +type Client_StateWatch_Call struct { + *mock.Call +} + +// StateWatch is a helper method to define mock.On call +// - ctx context.Context +func (_e *Client_Expecter) StateWatch(ctx interface{}) *Client_StateWatch_Call { + return &Client_StateWatch_Call{Call: _e.mock.On("StateWatch", ctx)} +} + +func (_c *Client_StateWatch_Call) Run(run func(ctx context.Context)) *Client_StateWatch_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *Client_StateWatch_Call) Return(_a0 client.ClientStateWatch, _a1 error) *Client_StateWatch_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Client_StateWatch_Call) RunAndReturn(run func(context.Context) (client.ClientStateWatch, error)) *Client_StateWatch_Call { + _c.Call.Return(run) + return _c +} + +// Upgrade provides a mock function with given fields: ctx, version, sourceURI, skipVerify, skipDefaultPgp, pgpBytes +func (_m *Client) Upgrade(ctx context.Context, version string, sourceURI string, skipVerify bool, skipDefaultPgp bool, pgpBytes ...string) (string, error) { + _va := make([]interface{}, len(pgpBytes)) + for _i := range pgpBytes { + _va[_i] = pgpBytes[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, version, sourceURI, skipVerify, skipDefaultPgp) + _ca = append(_ca, _va...) 
+ ret := _m.Called(_ca...) + + var r0 string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, bool, bool, ...string) (string, error)); ok { + return rf(ctx, version, sourceURI, skipVerify, skipDefaultPgp, pgpBytes...) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, bool, bool, ...string) string); ok { + r0 = rf(ctx, version, sourceURI, skipVerify, skipDefaultPgp, pgpBytes...) + } else { + r0 = ret.Get(0).(string) + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, bool, bool, ...string) error); ok { + r1 = rf(ctx, version, sourceURI, skipVerify, skipDefaultPgp, pgpBytes...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Client_Upgrade_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Upgrade' +type Client_Upgrade_Call struct { + *mock.Call +} + +// Upgrade is a helper method to define mock.On call +// - ctx context.Context +// - version string +// - sourceURI string +// - skipVerify bool +// - skipDefaultPgp bool +// - pgpBytes ...string +func (_e *Client_Expecter) Upgrade(ctx interface{}, version interface{}, sourceURI interface{}, skipVerify interface{}, skipDefaultPgp interface{}, pgpBytes ...interface{}) *Client_Upgrade_Call { + return &Client_Upgrade_Call{Call: _e.mock.On("Upgrade", + append([]interface{}{ctx, version, sourceURI, skipVerify, skipDefaultPgp}, pgpBytes...)...)} +} + +func (_c *Client_Upgrade_Call) Run(run func(ctx context.Context, version string, sourceURI string, skipVerify bool, skipDefaultPgp bool, pgpBytes ...string)) *Client_Upgrade_Call { + _c.Call.Run(func(args mock.Arguments) { + variadicArgs := make([]string, len(args)-5) + for i, a := range args[5:] { + if a != nil { + variadicArgs[i] = a.(string) + } + } + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(bool), args[4].(bool), variadicArgs...) 
+ }) + return _c +} + +func (_c *Client_Upgrade_Call) Return(_a0 string, _a1 error) *Client_Upgrade_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Client_Upgrade_Call) RunAndReturn(run func(context.Context, string, string, bool, bool, ...string) (string, error)) *Client_Upgrade_Call { + _c.Call.Return(run) + return _c +} + +// Version provides a mock function with given fields: ctx +func (_m *Client) Version(ctx context.Context) (client.Version, error) { + ret := _m.Called(ctx) + + var r0 client.Version + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (client.Version, error)); ok { + return rf(ctx) + } + if rf, ok := ret.Get(0).(func(context.Context) client.Version); ok { + r0 = rf(ctx) + } else { + r0 = ret.Get(0).(client.Version) + } + + if rf, ok := ret.Get(1).(func(context.Context) error); ok { + r1 = rf(ctx) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Client_Version_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Version' +type Client_Version_Call struct { + *mock.Call +} + +// Version is a helper method to define mock.On call +// - ctx context.Context +func (_e *Client_Expecter) Version(ctx interface{}) *Client_Version_Call { + return &Client_Version_Call{Call: _e.mock.On("Version", ctx)} +} + +func (_c *Client_Version_Call) Run(run func(ctx context.Context)) *Client_Version_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context)) + }) + return _c +} + +func (_c *Client_Version_Call) Return(_a0 client.Version, _a1 error) *Client_Version_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *Client_Version_Call) RunAndReturn(run func(context.Context) (client.Version, error)) *Client_Version_Call { + _c.Call.Return(run) + return _c +} + +type mockConstructorTestingTNewClient interface { + mock.TestingT + Cleanup(func()) +} + +// NewClient creates a new instance of Client. 
It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewClient(t mockConstructorTestingTNewClient) *Client { + mock := &Client{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/sonar-project.properties b/sonar-project.properties index d6b0f7bec1d..44136f7b9da 100644 --- a/sonar-project.properties +++ b/sonar-project.properties @@ -2,7 +2,7 @@ sonar.projectKey=elastic_elastic-agent_AYluowg0xMq8P7b4moiZ sonar.host.url=https://sonar.elastic.dev sonar.sources=. -sonar.exclusions=**/*_test.go, .git/**, dev-tools/**, /magefile.go, changelog/**, _meta/**, deploy/**, docs/**, img/**, specs/**, pkg/testing/**, pkg/component/fake/**m **/mocks/*.go +sonar.exclusions=**/*_test.go, .git/**, dev-tools/**, /magefile.go, changelog/**, _meta/**, deploy/**, docs/**, img/**, specs/**, pkg/testing/**, pkg/component/fake/**, **/mocks/*.go sonar.tests=. sonar.test.inclusions=**/*_test.go diff --git a/testing/integration/upgrade_test.go b/testing/integration/upgrade_test.go index d624a5f1dcb..a4e27f9313a 100644 --- a/testing/integration/upgrade_test.go +++ b/testing/integration/upgrade_test.go @@ -475,55 +475,11 @@ func testStandaloneUpgrade( t.Logf("Agent installation output: %q", string(output)) require.NoError(t, err) - c := f.Client() - - err = c.Connect(ctx) - require.NoError(t, err, "error connecting client to agent") - defer c.Disconnect() - require.Eventually(t, func() bool { return checkAgentHealthAndVersion(t, ctx, f, parsedFromVersion.CoreVersion(), parsedFromVersion.IsSnapshot(), "") }, 2*time.Minute, 10*time.Second, "Agent never became healthy") - t.Logf("Upgrading from version %q to version %q", parsedFromVersion, parsedUpgradeVersion) - - upgradeCmdArgs := []string{"upgrade", parsedUpgradeVersion.String()} - - useLocalPackage := allowLocalPackage && version_8_7_0.Less(*parsedFromVersion) - if useLocalPackage { - // if we are upgrading from a version > 8.7.0 
(min version to skip signature verification) we pass : - // - a file:// sourceURI pointing the agent package under test - // - flag --skip-verify to bypass pgp signature verification (we don't produce signatures for PR/main builds) - tof, err := define.NewFixture(t, parsedUpgradeVersion.String()) - require.NoError(t, err) - - srcPkg, err := tof.SrcPackage(ctx) - require.NoError(t, err) - sourceURI := "file://" + filepath.Dir(srcPkg) - t.Logf("setting sourceURI to : %q", sourceURI) - upgradeCmdArgs = append(upgradeCmdArgs, "--source-uri", sourceURI) - } - if useLocalPackage || skipVerify { - upgradeCmdArgs = append(upgradeCmdArgs, "--skip-verify") - } - - if skipDefaultPgp { - upgradeCmdArgs = append(upgradeCmdArgs, "--skip-default-pgp") - } - - if len(customPgp.PGP) > 0 { - upgradeCmdArgs = append(upgradeCmdArgs, "--pgp", customPgp.PGP) - } - - if len(customPgp.PGPUri) > 0 { - upgradeCmdArgs = append(upgradeCmdArgs, "--pgp-uri", customPgp.PGPUri) - } - - if len(customPgp.PGPPath) > 0 { - upgradeCmdArgs = append(upgradeCmdArgs, "--pgp-path", customPgp.PGPPath) - } - - upgradeTriggerOutput, err := f.Exec(ctx, upgradeCmdArgs) + upgradeTriggerOutput, err := upgradeAgent(ctx, t, f, parsedFromVersion, parsedUpgradeVersion, allowLocalPackage, skipVerify, skipDefaultPgp, customPgp) require.NoErrorf(t, err, "error triggering agent upgrade to version %q, output:\n%s", parsedUpgradeVersion, upgradeTriggerOutput) @@ -534,6 +490,12 @@ func testStandaloneUpgrade( checkUpgradeWatcherRan(t, f, parsedFromVersion) if expectedAgentHashAfterUpgrade != "" { + c := f.Client() + + err = c.Connect(ctx) + require.NoError(t, err, "error connecting client to agent") + defer c.Disconnect() + aVersion, err := c.Version(ctx) assert.NoError(t, err, "error checking version after upgrade") assert.Equal(t, expectedAgentHashAfterUpgrade, aVersion.Commit, "agent commit hash changed after upgrade") @@ -690,10 +652,9 @@ func TestStandaloneUpgradeRetryDownload(t *testing.T) { // We go back TWO minors 
because sometimes we are in a situation where // the current version has been advanced to the next release (e.g. 8.10.0) // but the version before that (e.g. 8.9.0) hasn't been released yet. - previousVersion, err := upgradeFromVersion.GetPreviousMinor() - require.NoError(t, err) - previousVersion, err = previousVersion.GetPreviousMinor() - require.NoError(t, err) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + previousVersion := twoMinorsPrevious(t, ctx) // For testing the upgrade we actually perform a downgrade upgradeToVersion := previousVersion @@ -703,9 +664,6 @@ func TestStandaloneUpgradeRetryDownload(t *testing.T) { agentFixture, err := define.NewFixture(t, define.Version()) require.NoError(t, err) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - err = agentFixture.Prepare(ctx) require.NoError(t, err, "error preparing agent fixture") @@ -760,7 +718,7 @@ func TestStandaloneUpgradeRetryDownload(t *testing.T) { go func() { wg.Add(1) - err := upgradeAgent(ctx, toVersion, agentFixture, t.Log) + _, err := upgradeAgent(ctx, t, agentFixture, upgradeFromVersion, upgradeToVersion, false, false, false, CustomPGP{}) wg.Done() require.NoError(t, err) @@ -835,15 +793,63 @@ func restoreEtcHosts() error { return cmd.Run() } -func upgradeAgent(ctx context.Context, version string, agentFixture *atesting.Fixture, log func(args ...any)) error { - args := []string{"upgrade", version} - output, err := agentFixture.Exec(ctx, args) - if err != nil { - log("Upgrade command output after error: ", string(output)) - return err +func upgradeAgent( + ctx context.Context, + t *testing.T, + agentFixture *atesting.Fixture, + parsedFromVersion *version.ParsedSemVer, + parsedToVersion *version.ParsedSemVer, + allowLocalPackage bool, + skipVerify bool, + skipDefaultPgp bool, + customPgp CustomPGP, +) ([]byte, error) { + t.Helper() + + c := agentFixture.Client() + + err := c.Connect(ctx) + require.NoError(t, err, "error connecting client 
to agent") + defer c.Disconnect() + + t.Logf("Upgrading from version %q to version %q", parsedFromVersion, parsedToVersion) + upgradeCmdArgs := []string{"upgrade", parsedToVersion.String()} + + useLocalPackage := allowLocalPackage && version_8_7_0.Less(*parsedFromVersion) + if useLocalPackage { + // if we are upgrading from a version > 8.7.0 (min version to skip signature verification) we pass : + // - a file:// sourceURI pointing the agent package under test + // - flag --skip-verify to bypass pgp signature verification (we don't produce signatures for PR/main builds) + tof, err := define.NewFixture(t, parsedToVersion.String()) + require.NoError(t, err) + + srcPkg, err := tof.SrcPackage(ctx) + require.NoError(t, err) + sourceURI := "file://" + filepath.Dir(srcPkg) + t.Logf("setting sourceURI to : %q", sourceURI) + upgradeCmdArgs = append(upgradeCmdArgs, "--source-uri", sourceURI) + } + if useLocalPackage || skipVerify { + upgradeCmdArgs = append(upgradeCmdArgs, "--skip-verify") + } + + if skipDefaultPgp { + upgradeCmdArgs = append(upgradeCmdArgs, "--skip-default-pgp") + } + + if len(customPgp.PGP) > 0 { + upgradeCmdArgs = append(upgradeCmdArgs, "--pgp", customPgp.PGP) } - return nil + if len(customPgp.PGPUri) > 0 { + upgradeCmdArgs = append(upgradeCmdArgs, "--pgp-uri", customPgp.PGPUri) + } + + if len(customPgp.PGPPath) > 0 { + upgradeCmdArgs = append(upgradeCmdArgs, "--pgp-path", customPgp.PGPPath) + } + + return agentFixture.Exec(ctx, upgradeCmdArgs) } func TestUpgradeBrokenPackageVersion(t *testing.T) { @@ -980,26 +986,7 @@ func TestStandaloneUpgradeFailsStatus(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - // Get available versions from Artifacts API - aac := tools.NewArtifactAPIClient() - versionList, err := aac.GetVersions(ctx) - require.NoError(t, err) - require.NotEmpty(t, versionList.Versions, "Artifact API returned no versions") - - // Determine the version that's TWO versions behind the latest. 
This is necessary for two reasons: - // 1. We don't want to necessarily use the latest version as it might be the same as the - // local one, which will then cause the invalid input in the Agent test policy (defined further - // below in this test) to come into play with the Agent version we're upgrading from, thus preventing - // it from ever becoming healthy. - // 2. We don't want to necessarily use the version that's one before the latest because sometimes we - // are in a situation where the latest version has been advanced to the next release (e.g. 8.10.0) - // but the version before that (e.g. 8.9.0) hasn't been released yet. - require.GreaterOrEqual(t, len(versionList.Versions), 3) - upgradeToVersionStr := versionList.Versions[len(versionList.Versions)-3] - - upgradeToVersion, err := version.ParseVersion(upgradeToVersionStr) - require.NoError(t, err) - + upgradeToVersion := twoMinorsPrevious(t, ctx) t.Logf("Testing Elastic Agent upgrade from %s to %s...", upgradeFromVersion, upgradeToVersion) agentFixture, err := define.NewFixture(t, define.Version()) @@ -1159,3 +1146,95 @@ func TestStandaloneUpgradeFailsRestart(t *testing.T) { return checkAgentHealthAndVersion(t, ctx, fromF, fromVersionParsed.CoreVersion(), false, "") }, 2*time.Minute, 10*time.Second, "Installed Agent never became healthy") } + +// TestStandaloneUpgradeFailsWhenUpgradeIsInProgress initiates an upgrade for a +// standalone Elastic Agent and, while that upgrade is still in progress, attempts +// to initiate a second upgrade. The test expects Elastic Agent to not allow +// the second upgrade. +func TestStandaloneUpgradeFailsWhenUpgradeIsInProgress(t *testing.T) { + define.Require(t, define.Requirements{ + Local: false, // requires Agent installation + Isolate: false, + Sudo: true, // requires Agent installation + }) + + // For this test we start with a version of Agent that's two minors older + // than the current version and upgrade to the current version. 
Then we attempt + // upgrading to the current version again, expecting Elastic Agent to disallow + // this second upgrade. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + upgradeFromVersion := twoMinorsPrevious(t, ctx) + + upgradeToVersion, err := version.ParseVersion(define.Version()) + require.NoError(t, err) + + t.Logf("Testing Elastic Agent upgrade from %s to %s...", upgradeFromVersion, upgradeToVersion) + + agentFixture, err := atesting.NewFixture(t, upgradeFromVersion.String()) + require.NoError(t, err) + + err = agentFixture.Prepare(ctx) + require.NoError(t, err, "error preparing agent fixture") + + // Configure Agent with fast watcher configuration. + err = agentFixture.Configure(ctx, []byte(fastWatcherCfg)) + require.NoError(t, err, "error configuring agent fixture") + + t.Log("Install the built Agent") + output, err := tools.InstallStandaloneAgent(agentFixture) + t.Log(string(output)) + require.NoError(t, err) + + require.Eventually(t, func() bool { + return checkAgentHealthAndVersion(t, ctx, agentFixture, upgradeFromVersion.CoreVersion(), upgradeFromVersion.IsSnapshot(), "") + }, 2*time.Minute, 10*time.Second, "Agent never became healthy") + + // Upgrade Elastic Agent via commandline + toVersion := upgradeToVersion.String() + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + + t.Logf("Upgrading Agent to %s for the first time", toVersion) + _, err := upgradeAgent(ctx, t, agentFixture, upgradeFromVersion, upgradeToVersion, true, true, false, CustomPGP{}) + require.NoError(t, err) + }() + + wg.Wait() + + // Attempt to upgrade Elastic Agent again, while upgrade is still in progress. The + // Upgrade Watcher from the previous upgrade attempt should still be running at this + // point, so Elastic Agent should prevent this second upgrade attempt. 
+ t.Logf("Attempting to upgrade Agent again to %s", toVersion) + output, err = upgradeAgent(ctx, t, agentFixture, upgradeToVersion, upgradeToVersion, true, true, false, CustomPGP{}) + require.NotNil(t, err) + require.Contains(t, string(output), "an upgrade is already in progress; please try again later.") +} + +func twoMinorsPrevious(t *testing.T, ctx context.Context) *version.ParsedSemVer { + t.Helper() + + // Get available versions from Artifacts API + aac := tools.NewArtifactAPIClient() + versionList, err := aac.GetVersions(ctx) + require.NoError(t, err) + require.NotEmpty(t, versionList.Versions, "Artifact API returned no versions") + + // Determine the version that's TWO versions behind the latest. This is necessary for two reasons: + // 1. We don't want to necessarily use the latest version as it might be the same as the + // local one. In tests that later apply an Agent policy with invalid input (for example + // TestStandaloneUpgradeFailsStatus), that input would then come into play with the Agent version + // we're upgrading from, thus preventing it from ever becoming healthy. + // 2. We don't want to necessarily use the version that's one before the latest because sometimes we + // are in a situation where the latest version has been advanced to the next release (e.g. 8.10.0) + // but the version before that (e.g. 8.9.0) hasn't been released yet. + require.GreaterOrEqual(t, len(versionList.Versions), 3) + vStr := versionList.Versions[len(versionList.Versions)-3] + + v, err := version.ParseVersion(vStr) + require.NoError(t, err) + + return v +}