diff --git a/agent/ssm/authregister/authregister_client.go b/agent/ssm/authregister/authregister_client.go index 25fab6eee..e304ab67a 100644 --- a/agent/ssm/authregister/authregister_client.go +++ b/agent/ssm/authregister/authregister_client.go @@ -57,10 +57,9 @@ type RegistrationInfo struct { func NewClientWithConfig(log logger.T, appConfig appconfig.SsmagentConfig, imdsClient iirprovider.IEC2MdsSdkClient, awsConfig aws.Config) IClient { if imdsClient != nil { awsConfig.Credentials = credentials.NewCredentials(&iirprovider.IIRRoleProvider{ - ExpiryWindow: iirprovider.EarlyExpiryTimeWindow, - Config: &appConfig, - Log: log, - IMDSClient: imdsClient, + Config: &appConfig, + Log: log, + IMDSClient: imdsClient, }) } else { awsConfig.Credentials = credentialproviders.GetRemoteCreds() diff --git a/agent/ssm/rsaauth/rsa_client.go b/agent/ssm/rsaauth/rsa_client.go index a1bb19e32..994627db3 100644 --- a/agent/ssm/rsaauth/rsa_client.go +++ b/agent/ssm/rsaauth/rsa_client.go @@ -49,10 +49,9 @@ func NewRsaClient(log log.T, appConfig *appconfig.SsmagentConfig, serverId, regi func NewIirRsaClient(log log.T, appConfig *appconfig.SsmagentConfig, imdsClient iirprovider.IEC2MdsSdkClient, region, encodedPrivateKey string) authtokenrequest.IClient { awsConfig := deps.AwsConfig(log, *appConfig, "ssm", region) awsConfig.Credentials = deps.NewCredentials(&iirprovider.IIRRoleProvider{ - ExpiryWindow: iirprovider.EarlyExpiryTimeWindow, // Triggers credential refresh, updated on Retrieve() - Config: appConfig, - Log: log, - IMDSClient: imdsClient, + Config: appConfig, + Log: log, + IMDSClient: imdsClient, }) if appConfig.Ssm.Endpoint != "" { diff --git a/common/identity/availableidentities/ec2/ec2_identity.go b/common/identity/availableidentities/ec2/ec2_identity.go index ec4bca1a8..77ba9e9ee 100644 --- a/common/identity/availableidentities/ec2/ec2_identity.go +++ b/common/identity/availableidentities/ec2/ec2_identity.go @@ -147,8 +147,8 @@ func (i *Identity) CredentialProvider() credentialproviders.IRemoteProvider { return i.credentialsProvider } -// RegisterWithContext registers the EC2 identity with Systems Manager -func (i *Identity) RegisterWithContext(ctx context.Context) error { +// Register registers the EC2 identity with Systems Manager +func (i *Identity) Register(ctx context.Context) error { region, err := i.RegionWithContext(ctx) if err != nil { return fmt.Errorf("unable to get region for identity %w", err) diff --git a/common/identity/availableidentities/ec2/ec2_identity_integ_test.go b/common/identity/availableidentities/ec2/ec2_identity_integ_test.go index a09b16c64..1bde7af28 100644 --- a/common/identity/availableidentities/ec2/ec2_identity_integ_test.go +++ b/common/identity/availableidentities/ec2/ec2_identity_integ_test.go @@ -71,7 +71,7 @@ func TestEC2Identity_Register_CancelTest(t *testing.T) { complete := make(chan struct{}) go func() { - err := identity.RegisterWithContext(ctx) + err := identity.Register(ctx) assert.Error(t, err) complete <- struct{}{} close(complete) diff --git a/common/identity/availableidentities/ec2/ec2_identity_test.go b/common/identity/availableidentities/ec2/ec2_identity_test.go index be1d45b73..c88ce85ae 100644 --- a/common/identity/availableidentities/ec2/ec2_identity_test.go +++ b/common/identity/availableidentities/ec2/ec2_identity_test.go @@ -339,7 +339,7 @@ func TestEC2Identity_Register_RegistersEC2InstanceWithSSM_WhenNotRegistered(t *t } // Act - err := identity.RegisterWithContext(context.Background()) + err := identity.Register(context.Background()) //Assert assert.NoError(t, err) @@ -387,7 +387,7 @@ func TestEC2Identity_Register_New_WhenAlreadyRegisteredWithOldInstanceId(t *test } // Act - err := identity.RegisterWithContext(context.Background()) + err := identity.Register(context.Background()) // Assert assert.NoError(t, err) @@ -442,7 +442,7 @@ func TestEC2Identity_ReRegister_InfoPublicKey_NotBlank(t *testing.T) { } // Act - err := identity.RegisterWithContext(context.Background()) + err := identity.Register(context.Background()) // Assert assert.NoError(t, err) @@ -493,7 +493,7 @@ func TestEC2Identity_ReRegister_InfoPublicKey_Blank(t *testing.T) { } // Act - err := identity.RegisterWithContext(context.Background()) + err := identity.Register(context.Background()) // Assert assert.NoError(t, err) @@ -529,7 +529,7 @@ func TestEC2Identity_Register_ReturnsRegistrationInfo_WhenAlreadyRegistered(t *t } // Act - err := identity.RegisterWithContext(context.Background()) + err := identity.Register(context.Background()) // Assert assert.NoError(t, err) @@ -585,7 +585,7 @@ func TestEC2Identity_Register_ReturnsNil_WhenInstanceAlreadyRegistered(t *testin } // Act - err := identity.RegisterWithContext(context.Background()) + err := identity.Register(context.Background()) // Assert assert.NoError(t, err) diff --git a/common/identity/availableidentities/ec2/stubs/provider_stub.go b/common/identity/availableidentities/ec2/stubs/provider_stub.go index 13b10854d..f474cf66d 100644 --- a/common/identity/availableidentities/ec2/stubs/provider_stub.go +++ b/common/identity/availableidentities/ec2/stubs/provider_stub.go @@ -8,11 +8,6 @@ import ( "github.com/aws/aws-sdk-go/aws/credentials" ) -const ( - SharedProviderName = "SharedProvider" - NonSharedProviderName = "NonSharedProvider" -) - type ProviderStub struct { ProviderName string Profile string @@ -36,11 +31,7 @@ func (p *ProviderStub) Retrieve() (credentials.Value, error) { return p.RetrieveWithContext(context.Background()) } -func (p *ProviderStub) RemoteRetrieve() (credentials.Value, error) { - return p.RemoteRetrieveWithContext(context.Background()) -} - -func (p *ProviderStub) RemoteRetrieveWithContext(ctx context.Context) (credentials.Value, error) { +func (p *ProviderStub) RemoteRetrieve(ctx context.Context) (credentials.Value, error) { return credentials.Value{ ProviderName: p.ProviderName, }, nil diff --git a/common/identity/credentialproviders/credential_provider_deps.go b/common/identity/credentialproviders/credential_provider_deps.go index 1fe54b5e7..628ab8310 100644 --- a/common/identity/credentialproviders/credential_provider_deps.go +++ b/common/identity/credentialproviders/credential_provider_deps.go @@ -28,6 +28,6 @@ type IRemoteProvider interface { ShareFile() string SharesCredentials() bool CredentialSource() string - RemoteRetrieveWithContext(ctx context.Context) (credentials.Value, error) + RemoteRetrieve(ctx context.Context) (credentials.Value, error) RemoteExpiresAt() time.Time } diff --git a/common/identity/credentialproviders/ec2roleprovider/ec2_role_provider.go b/common/identity/credentialproviders/ec2roleprovider/ec2_role_provider.go index 05d69865e..dda27b685 100644 --- a/common/identity/credentialproviders/ec2roleprovider/ec2_role_provider.go +++ b/common/identity/credentialproviders/ec2roleprovider/ec2_role_provider.go @@ -21,19 +21,18 @@ import ( "sync" "time" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds" - - "github.com/aws/amazon-ssm-agent/common/runtimeconfig" + "github.com/aws/aws-sdk-go/service/ssm" "github.com/aws/amazon-ssm-agent/agent/appconfig" "github.com/aws/amazon-ssm-agent/agent/log" "github.com/aws/amazon-ssm-agent/agent/sdkutil" "github.com/aws/amazon-ssm-agent/agent/version" "github.com/aws/amazon-ssm-agent/common/identity/credentialproviders/ssmec2roleprovider" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/service/ssm" + "github.com/aws/amazon-ssm-agent/common/runtimeconfig" ) // EC2RoleProvider provides credentials for the agent when on an EC2 instance @@ -115,8 +114,8 @@ func (p *EC2RoleProvider) RetrieveWithContext(ctx context.Context) (credentials. return iprCredentials, nil } -// RemoteRetrieveWithContext uses network calls to retrieve credentials for EC2 instances -func (p *EC2RoleProvider) RemoteRetrieveWithContext(ctx context.Context) (credentials.Value, error) { +// RemoteRetrieve uses network calls to retrieve credentials for EC2 instances +func (p *EC2RoleProvider) RemoteRetrieve(ctx context.Context) (credentials.Value, error) { p.Log.Debug("Attempting to retrieve instance profile role") if iprCredentials, err := p.iprCredentials(ctx, p.SsmEndpoint); err != nil { errCode := sdkutil.GetAwsErrorCode(err) diff --git a/common/identity/credentialproviders/ec2roleprovider/ec2_role_provider_deps.go b/common/identity/credentialproviders/ec2roleprovider/ec2_role_provider_deps.go index 1f5ddcaac..f74f22ba1 100644 --- a/common/identity/credentialproviders/ec2roleprovider/ec2_role_provider_deps.go +++ b/common/identity/credentialproviders/ec2roleprovider/ec2_role_provider_deps.go @@ -67,5 +67,5 @@ type IEC2RoleProvider interface { ShareProfile() string SharesCredentials() bool RetrieveWithContext(ctx context.Context) (credentials.Value, error) - RemoteRetrieveWithContext(ctx context.Context) (credentials.Value, error) + RemoteRetrieve(ctx context.Context) (credentials.Value, error) } diff --git a/common/identity/credentialproviders/ec2roleprovider/ec2_role_provider_test.go b/common/identity/credentialproviders/ec2roleprovider/ec2_role_provider_test.go index ff67e8d42..8f835ccd6 100644 --- a/common/identity/credentialproviders/ec2roleprovider/ec2_role_provider_test.go +++ b/common/identity/credentialproviders/ec2roleprovider/ec2_role_provider_test.go @@ -383,7 +383,7 @@ func TestEC2RoleProvider_RetrieveRemote_ReturnsEmptyCredentials(t *testing.T) { ec2RoleProvider := arrangeRetrieveEmptyTest(j) // Act - creds, err := ec2RoleProvider.RemoteRetrieveWithContext(context.Background()) + creds, err := ec2RoleProvider.RemoteRetrieve(context.Background()) //Assert assert.Error(t, err) diff --git a/common/identity/credentialproviders/ec2roleprovider/mocks/IEC2RoleProvider.go b/common/identity/credentialproviders/ec2roleprovider/mocks/IEC2RoleProvider.go index 84b17204c..50629e6d2 100644 --- a/common/identity/credentialproviders/ec2roleprovider/mocks/IEC2RoleProvider.go +++ b/common/identity/credentialproviders/ec2roleprovider/mocks/IEC2RoleProvider.go @@ -90,8 +90,8 @@ func (_m *IEC2RoleProvider) RemoteExpiresAt() time.Time { return r0 } -// RemoteRetrieveWithContext provides a mock function with given fields: ctx -func (_m *IEC2RoleProvider) RemoteRetrieveWithContext(ctx context.Context) (credentials.Value, error) { +// RemoteRetrieve provides a mock function with given fields: ctx +func (_m *IEC2RoleProvider) RemoteRetrieve(ctx context.Context) (credentials.Value, error) { ret := _m.Called(ctx) var r0 credentials.Value diff --git a/common/identity/credentialproviders/iirprovider/iir_role_provider_deps.go b/common/identity/credentialproviders/iirprovider/iir_role_provider_deps.go index 8ab0f715d..7d6125ef9 100644 --- a/common/identity/credentialproviders/iirprovider/iir_role_provider_deps.go +++ b/common/identity/credentialproviders/iirprovider/iir_role_provider_deps.go @@ -19,12 +19,6 @@ import ( ) const ( - // EarlyExpiryTimeWindow set a short amount of time that will mark the credentials as expired, this can avoid - // calls being made with expired credentials. This value should not be too big that's greater than the default token - // expiry time. For example, the token expires after 30 min and we set it to 40 min which expires the token - // immediately. The value should also not be too small that it should trigger credential rotation before it expires. - EarlyExpiryTimeWindow = 1 * time.Minute - // ProviderName is the role provider name that is returned with credentials ProviderName = "EC2IdentityRoleProvider" iirCredentialsPath = "identity-credentials/ec2/security-credentials/ec2-instance" diff --git a/common/identity/credentialproviders/iirprovider/iir_role_provider_test.go b/common/identity/credentialproviders/iirprovider/iir_role_provider_test.go index 663759b03..41d3672cb 100644 --- a/common/identity/credentialproviders/iirprovider/iir_role_provider_test.go +++ b/common/identity/credentialproviders/iirprovider/iir_role_provider_test.go @@ -58,10 +58,9 @@ func TestRetrieve_ReturnsCredentials(t *testing.T) { mockIMDSClient.On("GetMetadata", iirCredentialsPath).Return(string(respJSONBytes), nil) roleProvider := &IIRRoleProvider{ - IMDSClient: mockIMDSClient, - ExpiryWindow: EarlyExpiryTimeWindow, - Config: &ssmConfig, - Log: logger, + IMDSClient: mockIMDSClient, + Config: &ssmConfig, + Log: logger, } result, err := roleProvider.Retrieve() diff --git a/common/identity/credentialproviders/mocks/IRemoteProvider.go b/common/identity/credentialproviders/mocks/IRemoteProvider.go index 9c82ce9b3..c72c49499 100644 --- a/common/identity/credentialproviders/mocks/IRemoteProvider.go +++ b/common/identity/credentialproviders/mocks/IRemoteProvider.go @@ -59,8 +59,8 @@ func (_m *IRemoteProvider) RemoteExpiresAt() time.Time { return r0 } -// RemoteRetrieveWithContext provides a mock function with given fields: ctx -func (_m *IRemoteProvider) RemoteRetrieveWithContext(ctx context.Context) (credentials.Value, error) { +// RemoteRetrieve provides a mock function with given fields: ctx +func (_m *IRemoteProvider) RemoteRetrieve(ctx context.Context) (credentials.Value, error) { ret := _m.Called(ctx) var r0 credentials.Value diff --git a/common/identity/credentialproviders/onpremprovider/role_provider.go b/common/identity/credentialproviders/onpremprovider/role_provider.go index a1abd244d..600699b21 100644 --- a/common/identity/credentialproviders/onpremprovider/role_provider.go +++ b/common/identity/credentialproviders/onpremprovider/role_provider.go @@ -92,13 +92,13 @@ func shouldRetryAwsRequest(err error) bool { // Error will be returned if the request fails, or unable to extract // the desired credentials. func (m *onpremCredentialsProvider) Retrieve() (credentials.Value, error) { - return m.RemoteRetrieveWithContext(context.Background()) + return m.RemoteRetrieve(context.Background()) } -// RemoteRetrieveWithContext retrieves OnPrem credentials from the SSM Auth service. +// RemoteRetrieve retrieves OnPrem credentials from the SSM Auth service. // Error will be returned if the request fails, or unable to extract // the desired credentials. -func (m *onpremCredentialsProvider) RemoteRetrieveWithContext(ctx context.Context) (credentials.Value, error) { +func (m *onpremCredentialsProvider) RemoteRetrieve(ctx context.Context) (credentials.Value, error) { var err error var roleCreds *ssm.RequestManagedInstanceRoleTokenOutput @@ -165,13 +165,6 @@ func (m *onpremCredentialsProvider) RemoteRetrieveWithContext(ctx context.Contex }, nil } -// RemoteRetrieve retrieves OnPrem credentials from the SSM Auth service. -// Error will be returned if the request fails, or unable to extract -// the desired credentials. -func (m *onpremCredentialsProvider) RemoteRetrieve() (credentials.Value, error) { - return m.RemoteRetrieveWithContext(context.Background()) -} - func (m *onpremCredentialsProvider) RemoteExpiresAt() time.Time { return m.ExpiresAt() } diff --git a/common/identity/credentialproviders/ssmclient/ssmclient.go b/common/identity/credentialproviders/ssmclient/ssmclient.go index 0bbe57b93..913146ae6 100644 --- a/common/identity/credentialproviders/ssmclient/ssmclient.go +++ b/common/identity/credentialproviders/ssmclient/ssmclient.go @@ -29,7 +29,6 @@ import ( // ISSMClient defines the functions needed for role providers send health pings to Systems Manager type ISSMClient interface { - UpdateInstanceInformation(input *ssm.UpdateInstanceInformationInput) (*ssm.UpdateInstanceInformationOutput, error) UpdateInstanceInformationWithContext(ctx context.Context, input *ssm.UpdateInstanceInformationInput, opts ...request.Option) (*ssm.UpdateInstanceInformationOutput, error) } diff --git a/common/identity/credentialproviders/ssmec2roleprovider/ssm_ec2_role_provider.go b/common/identity/credentialproviders/ssmec2roleprovider/ssm_ec2_role_provider.go index a7fe6b493..910c46e17 100644 --- a/common/identity/credentialproviders/ssmec2roleprovider/ssm_ec2_role_provider.go +++ b/common/identity/credentialproviders/ssmec2roleprovider/ssm_ec2_role_provider.go @@ -17,7 +17,6 @@ package ssmec2roleprovider import ( "context" "fmt" - "sync" "time" "github.com/aws/amazon-ssm-agent/agent/appconfig" @@ -41,7 +40,6 @@ var ( getStoredPrivateKey = registration.PrivateKey getStoredPublicKey = registration.PublicKey getStoredPrivateKeyType = registration.PrivateKeyType - loadRegistrationLock = &sync.Mutex{} ) // SSMEC2RoleProvider sends requests for credentials to systems manager signed with AWS SigV4 @@ -75,8 +73,6 @@ func (p *SSMEC2RoleProvider) isEC2InstanceRegistered() bool { return false } - loadRegistrationLock.Lock() - defer loadRegistrationLock.Unlock() p.registrationInfo = registrationInfo } diff --git a/common/identity/credentialproviders/ssmec2roleprovider/ssm_ec2_role_provider_deps.go b/common/identity/credentialproviders/ssmec2roleprovider/ssm_ec2_role_provider_deps.go index 5571ffbff..ca9f3ca0d 100644 --- a/common/identity/credentialproviders/ssmec2roleprovider/ssm_ec2_role_provider_deps.go +++ b/common/identity/credentialproviders/ssmec2roleprovider/ssm_ec2_role_provider_deps.go @@ -14,10 +14,8 @@ package ssmec2roleprovider -const ( - // ProviderName is the role provider name that is returned with credentials - ProviderName = "SSMEC2RoleProvider" -) +// ProviderName is the role provider name that is returned with credentials +const ProviderName = "SSMEC2RoleProvider" // InstanceInfo contains information about current EC2 instance type InstanceInfo struct { diff --git a/common/identity/interface.go b/common/identity/interface.go index afe4396ee..e3ae42fb7 100644 --- a/common/identity/interface.go +++ b/common/identity/interface.go @@ -34,7 +34,7 @@ type IAgentIdentity interface { // Registrar identity registers the agent on startup type Registrar interface { - RegisterWithContext(context.Context) error + Register(context.Context) error } type IInnerIdentityGetter interface { GetInner() IAgentIdentityInner diff --git a/common/identity/mocks/Registrar.go b/common/identity/mocks/Registrar.go index 4a9eed9ca..70c0f9f8e 100644 --- a/common/identity/mocks/Registrar.go +++ b/common/identity/mocks/Registrar.go @@ -13,8 +13,8 @@ type Registrar struct { mock.Mock } -// RegisterWithContext provides a mock function with given fields: _a0 -func (_m *Registrar) RegisterWithContext(_a0 context.Context) error { +// Register provides a mock function with given fields: _a0 +func (_m *Registrar) Register(_a0 context.Context) error { ret := _m.Called(_a0) var r0 error diff --git a/common/runtimeconfig/identity_runtimeconfig.go b/common/runtimeconfig/identity_runtimeconfig.go index 1eb1fdca5..f628224d3 100644 --- a/common/runtimeconfig/identity_runtimeconfig.go +++ b/common/runtimeconfig/identity_runtimeconfig.go @@ -16,10 +16,10 @@ package runtimeconfig import ( "encoding/json" "fmt" - "github.com/cenkalti/backoff/v4" "time" rch "github.com/aws/amazon-ssm-agent/common/runtimeconfig/runtimeconfighandler" + "github.com/cenkalti/backoff/v4" ) const ( @@ -87,6 +87,12 @@ func (i *identityRuntimeConfigClient) GetConfigWithRetry() (out IdentityRuntimeC // Attempts GetConfig up to 6 times with exponential backoff backoffConfig.MaxElapsedTime = time.Second * 4 err = backoff.Retry(func() error { + if configExists, existsError := i.ConfigExists(); err != nil { + return fmt.Errorf("failed to check whether config extists. Err: %w", existsError) + } else if !configExists { + return nil + } + out, err = i.GetConfig() return err }, backoffConfig) diff --git a/core/app/credentialrefresher/credentialrefresher.go b/core/app/credentialrefresher/credentialrefresher.go index 991e1d092..7cedbb01e 100644 --- a/core/app/credentialrefresher/credentialrefresher.go +++ b/core/app/credentialrefresher/credentialrefresher.go @@ -198,7 +198,7 @@ func getBackoffRetryJitterSleepDuration(retryCount int) time.Duration { func (c *credentialsRefresher) retrieveCredsWithRetry(ctx context.Context) (credentials.Value, bool) { retryCount := 0 for { - creds, err := c.provider.RemoteRetrieveWithContext(ctx) + creds, err := c.provider.RemoteRetrieve(ctx) if err == nil { return creds, false } diff --git a/core/app/credentialrefresher/credentialrefresher_test.go b/core/app/credentialrefresher/credentialrefresher_test.go index db6b7e822..56bb4f01d 100644 --- a/core/app/credentialrefresher/credentialrefresher_test.go +++ b/core/app/credentialrefresher/credentialrefresher_test.go @@ -48,7 +48,7 @@ var ( func init() { newSharedCredentials = func(_, _ string) *credentials.Credentials { provider := &credentialmocks.Provider{} - provider.On("RemoteRetrieveWithContext", mock.Anything).Return(credentials.Value{}, nil).Once() + provider.On("RemoteRetrieve", mock.Anything).Return(credentials.Value{}, nil).Once() provider.On("RemoteExpiresAt").Return(time.Now().Add(1 * time.Hour)).Once() provider.On("ShareFile").Return("", nil).Times(2) provider.On("CredentialSource").Return("SSM").Times(3) @@ -263,7 +263,7 @@ func Test_credentialsRefresher_credentialRefresherRoutine_CredentialsNotExpired_ } provider := &credentialmocks.IRemoteProvider{} - provider.On("RemoteRetrieveWithContext", mock.Anything).Return(func(context.Context) credentials.Value { return credentials.Value{} }, func(context.Context) error { + provider.On("RemoteRetrieve", mock.Anything).Return(func(context.Context) credentials.Value { return credentials.Value{} }, func(context.Context) error { // Sleep here because we know that if we reach this point and have not got message in credentialsReadyChan, the time is set correctly time.Sleep(time.Second) return fmt.Errorf("SomeRetrieveErr") @@ -318,7 +318,7 @@ func Test_credentialsRefresher_credentialRefresherRoutine_CredentialsExist_CallS provider := &credentialmocks.IRemoteProvider{} provider.On("Retrieve").Return(credentials.Value{}, nil).Repeatability = 0 - provider.On("RemoteRetrieveWithContext", mock.Anything).Return(credentials.Value{}, nil).Repeatability = 0 + provider.On("RemoteRetrieve", mock.Anything).Return(credentials.Value{}, nil).Repeatability = 0 provider.On("RemoteExpiresAt").Return(time.Now().Add(1 * time.Hour)).Repeatability = 0 provider.On("ShareFile").Return("SomeShareFile", nil).Repeatability = 0 provider.On("CredentialSource").Return("SSM").Repeatability = 0 @@ -436,7 +436,7 @@ func Test_credentialsRefresher_credentialRefresherRoutine_Purge(t *testing.T) { runtimeConfigClient.On("SaveConfig", mock.Anything).Return(nil).Once() provider := &credentialmocks.IRemoteProvider{} provider.On("ShareFile").Return(tc.newShareFileLocation, nil).Once() - provider.On("RemoteRetrieveWithContext", mock.Anything).Return(credentials.Value{}, nil).Once() + provider.On("RemoteRetrieve", mock.Anything).Return(credentials.Value{}, nil).Once() provider.On("RemoteExpiresAt").Return(time.Now().Add(1 * time.Hour)).Once() provider.On("CredentialSource").Return("").Once() @@ -526,7 +526,7 @@ func Test_credentialsRefresher_credentialRefresherRoutine_CredentialsDontExist(t provider := &credentialmocks.IRemoteProvider{} provider.On("ShareFile").Return("SomeShareFile", nil).Times(2) provider.On("Retrieve").Return(credentials.Value{}, fmt.Errorf("share file doesn't exist")).Once() - provider.On("RemoteRetrieveWithContext", mock.Anything).Return(credentials.Value{}, nil).Once() + provider.On("RemoteRetrieve", mock.Anything).Return(credentials.Value{}, nil).Once() provider.On("RemoteExpiresAt").Return(time.Now().Add(1 * time.Hour)).Once() provider.On("CredentialSource").Return("SSM").Once() @@ -585,7 +585,7 @@ func (a awsTestError) Code() string { return a.errCode } func Test_credentialsRefresher_retrieveCredsWithRetry_NonActionableErr(t *testing.T) { for _, awsErr := range []error{awsTestError{"AccessDeniedException"}, awsTestError{"MachineFingerprintDoesNotMatch"}} { provider := &credentialmocks.IRemoteProvider{} - provider.On("RemoteRetrieveWithContext", mock.Anything).Return(credentials.Value{}, awsErr).Once() + provider.On("RemoteRetrieve", mock.Anything).Return(credentials.Value{}, awsErr).Once() var timeAfterParamVal time.Duration c := &credentialsRefresher{ @@ -626,9 +626,9 @@ func Test_credentialsRefresher_retrieveCredsWithRetry_NonActionableErr(t *testin func Test_credentialsRefresher_retrieveCredsWithRetry_Retry2000TimesNoExitUntilSuccess(t *testing.T) { provider := &credentialmocks.IRemoteProvider{} - provider.On("RemoteRetrieveWithContext", mock.Anything).Return(credentials.Value{}, awsTestError{"PotentiallyRecoverableAWSError"}).Times(1000) - provider.On("RemoteRetrieveWithContext", mock.Anything).Return(credentials.Value{}, fmt.Errorf("SomeRandomNonAwsErr")).Times(1000) - provider.On("RemoteRetrieveWithContext", mock.Anything).Return(credentials.Value{}, nil).Once() + provider.On("RemoteRetrieve", mock.Anything).Return(credentials.Value{}, awsTestError{"PotentiallyRecoverableAWSError"}).Times(1000) + provider.On("RemoteRetrieve", mock.Anything).Return(credentials.Value{}, fmt.Errorf("SomeRandomNonAwsErr")).Times(1000) + provider.On("RemoteRetrieve", mock.Anything).Return(credentials.Value{}, nil).Once() numSleeps := 0 c := &credentialsRefresher{ diff --git a/core/app/registrar/registrar.go b/core/app/registrar/registrar.go index a5e510c30..0494fa733 100644 --- a/core/app/registrar/registrar.go +++ b/core/app/registrar/registrar.go @@ -69,7 +69,7 @@ func NewRetryableRegistrar(agentCtx agentCtx.ICoreAgentContext) *RetryableRegist func (r *RetryableRegistrar) Start() error { r.log.Info("Starting registrar module") - r.isRegistrarRunning = true + r.setIsRegistrarRunning(true) go r.RegisterWithRetry() return nil } @@ -113,8 +113,7 @@ func (r *RetryableRegistrar) RegisterWithRetry() { } }() - errChan <- r.identityRegistrar.RegisterWithContext(ctx) - defer close(errChan) + errChan <- r.identityRegistrar.Register(ctx) }() select { case err := <-errChan: @@ -132,7 +131,7 @@ func (r *RetryableRegistrar) RegisterWithRetry() { case <-r.stopRegistrarChan: cancel() r.log.Info("Stopping registrar") - r.getIsRegistrarRunning() + r.setIsRegistrarRunning(false) r.log.Flush() return } diff --git a/core/app/registrar/registrar_test.go b/core/app/registrar/registrar_test.go index 813044dd4..1bdcab432 100644 --- a/core/app/registrar/registrar_test.go +++ b/core/app/registrar/registrar_test.go @@ -15,7 +15,7 @@ import ( func TestRetryableRegistrar_RegisterWithRetry_Success(t *testing.T) { // Arrange identityRegistrar := &identitymocks.Registrar{} - identityRegistrar.On("RegisterWithContext", mock.Anything).Return(nil) + identityRegistrar.On("Register", mock.Anything).Return(nil) timeAfterFunc := func(duration time.Duration) <-chan time.Time { assert.Fail(t, "expected no registration retry or sleep")