Skip to content

Commit

Permalink
Credential code cosmetic cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
gianniLesl committed Jul 3, 2023
1 parent 4bcceb3 commit 42a14ff
Show file tree
Hide file tree
Showing 25 changed files with 62 additions and 90 deletions.
7 changes: 3 additions & 4 deletions agent/ssm/authregister/authregister_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,9 @@ type RegistrationInfo struct {
func NewClientWithConfig(log logger.T, appConfig appconfig.SsmagentConfig, imdsClient iirprovider.IEC2MdsSdkClient, awsConfig aws.Config) IClient {
if imdsClient != nil {
awsConfig.Credentials = credentials.NewCredentials(&iirprovider.IIRRoleProvider{
ExpiryWindow: iirprovider.EarlyExpiryTimeWindow,
Config: &appConfig,
Log: log,
IMDSClient: imdsClient,
Config: &appConfig,
Log: log,
IMDSClient: imdsClient,
})
} else {
awsConfig.Credentials = credentialproviders.GetRemoteCreds()
Expand Down
7 changes: 3 additions & 4 deletions agent/ssm/rsaauth/rsa_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,9 @@ func NewRsaClient(log log.T, appConfig *appconfig.SsmagentConfig, serverId, regi
func NewIirRsaClient(log log.T, appConfig *appconfig.SsmagentConfig, imdsClient iirprovider.IEC2MdsSdkClient, region, encodedPrivateKey string) authtokenrequest.IClient {
awsConfig := deps.AwsConfig(log, *appConfig, "ssm", region)
awsConfig.Credentials = deps.NewCredentials(&iirprovider.IIRRoleProvider{
ExpiryWindow: iirprovider.EarlyExpiryTimeWindow, // Triggers credential refresh, updated on Retrieve()
Config: appConfig,
Log: log,
IMDSClient: imdsClient,
Config: appConfig,
Log: log,
IMDSClient: imdsClient,
})

if appConfig.Ssm.Endpoint != "" {
Expand Down
4 changes: 2 additions & 2 deletions common/identity/availableidentities/ec2/ec2_identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,8 @@ func (i *Identity) CredentialProvider() credentialproviders.IRemoteProvider {
return i.credentialsProvider
}

// RegisterWithContext registers the EC2 identity with Systems Manager
func (i *Identity) RegisterWithContext(ctx context.Context) error {
// Register registers the EC2 identity with Systems Manager
func (i *Identity) Register(ctx context.Context) error {
region, err := i.RegionWithContext(ctx)
if err != nil {
return fmt.Errorf("unable to get region for identity %w", err)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func TestEC2Identity_Register_CancelTest(t *testing.T) {
complete := make(chan struct{})

go func() {
err := identity.RegisterWithContext(ctx)
err := identity.Register(ctx)
assert.Error(t, err)
complete <- struct{}{}
close(complete)
Expand Down
12 changes: 6 additions & 6 deletions common/identity/availableidentities/ec2/ec2_identity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ func TestEC2Identity_Register_RegistersEC2InstanceWithSSM_WhenNotRegistered(t *t
}

// Act
err := identity.RegisterWithContext(context.Background())
err := identity.Register(context.Background())

//Assert
assert.NoError(t, err)
Expand Down Expand Up @@ -387,7 +387,7 @@ func TestEC2Identity_Register_New_WhenAlreadyRegisteredWithOldInstanceId(t *test
}

// Act
err := identity.RegisterWithContext(context.Background())
err := identity.Register(context.Background())

// Assert
assert.NoError(t, err)
Expand Down Expand Up @@ -442,7 +442,7 @@ func TestEC2Identity_ReRegister_InfoPublicKey_NotBlank(t *testing.T) {
}

// Act
err := identity.RegisterWithContext(context.Background())
err := identity.Register(context.Background())

// Assert
assert.NoError(t, err)
Expand Down Expand Up @@ -493,7 +493,7 @@ func TestEC2Identity_ReRegister_InfoPublicKey_Blank(t *testing.T) {
}

// Act
err := identity.RegisterWithContext(context.Background())
err := identity.Register(context.Background())

// Assert
assert.NoError(t, err)
Expand Down Expand Up @@ -529,7 +529,7 @@ func TestEC2Identity_Register_ReturnsRegistrationInfo_WhenAlreadyRegistered(t *t
}

// Act
err := identity.RegisterWithContext(context.Background())
err := identity.Register(context.Background())

// Assert
assert.NoError(t, err)
Expand Down Expand Up @@ -585,7 +585,7 @@ func TestEC2Identity_Register_ReturnsNil_WhenInstanceAlreadyRegistered(t *testin
}

// Act
err := identity.RegisterWithContext(context.Background())
err := identity.Register(context.Background())

// Assert
assert.NoError(t, err)
Expand Down
11 changes: 1 addition & 10 deletions common/identity/availableidentities/ec2/stubs/provider_stub.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,6 @@ import (
"github.com/aws/aws-sdk-go/aws/credentials"
)

const (
SharedProviderName = "SharedProvider"
NonSharedProviderName = "NonSharedProvider"
)

type ProviderStub struct {
ProviderName string
Profile string
Expand All @@ -36,11 +31,7 @@ func (p *ProviderStub) Retrieve() (credentials.Value, error) {
return p.RetrieveWithContext(context.Background())
}

func (p *ProviderStub) RemoteRetrieve() (credentials.Value, error) {
return p.RemoteRetrieveWithContext(context.Background())
}

func (p *ProviderStub) RemoteRetrieveWithContext(ctx context.Context) (credentials.Value, error) {
func (p *ProviderStub) RemoteRetrieve(ctx context.Context) (credentials.Value, error) {
return credentials.Value{
ProviderName: p.ProviderName,
}, nil
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,6 @@ type IRemoteProvider interface {
ShareFile() string
SharesCredentials() bool
CredentialSource() string
RemoteRetrieveWithContext(ctx context.Context) (credentials.Value, error)
RemoteRetrieve(ctx context.Context) (credentials.Value, error)
RemoteExpiresAt() time.Time
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,18 @@ import (
"sync"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"

"github.com/aws/amazon-ssm-agent/common/runtimeconfig"
"github.com/aws/aws-sdk-go/service/ssm"

"github.com/aws/amazon-ssm-agent/agent/appconfig"
"github.com/aws/amazon-ssm-agent/agent/log"
"github.com/aws/amazon-ssm-agent/agent/sdkutil"
"github.com/aws/amazon-ssm-agent/agent/version"
"github.com/aws/amazon-ssm-agent/common/identity/credentialproviders/ssmec2roleprovider"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/service/ssm"
"github.com/aws/amazon-ssm-agent/common/runtimeconfig"
)

// EC2RoleProvider provides credentials for the agent when on an EC2 instance
Expand Down Expand Up @@ -115,8 +114,8 @@ func (p *EC2RoleProvider) RetrieveWithContext(ctx context.Context) (credentials.
return iprCredentials, nil
}

// RemoteRetrieveWithContext uses network calls to retrieve credentials for EC2 instances
func (p *EC2RoleProvider) RemoteRetrieveWithContext(ctx context.Context) (credentials.Value, error) {
// RemoteRetrieve uses network calls to retrieve credentials for EC2 instances
func (p *EC2RoleProvider) RemoteRetrieve(ctx context.Context) (credentials.Value, error) {
p.Log.Debug("Attempting to retrieve instance profile role")
if iprCredentials, err := p.iprCredentials(ctx, p.SsmEndpoint); err != nil {
errCode := sdkutil.GetAwsErrorCode(err)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,5 +67,5 @@ type IEC2RoleProvider interface {
ShareProfile() string
SharesCredentials() bool
RetrieveWithContext(ctx context.Context) (credentials.Value, error)
RemoteRetrieveWithContext(ctx context.Context) (credentials.Value, error)
RemoteRetrieve(ctx context.Context) (credentials.Value, error)
}
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ func TestEC2RoleProvider_RetrieveRemote_ReturnsEmptyCredentials(t *testing.T) {
ec2RoleProvider := arrangeRetrieveEmptyTest(j)

// Act
creds, err := ec2RoleProvider.RemoteRetrieveWithContext(context.Background())
creds, err := ec2RoleProvider.RemoteRetrieve(context.Background())

//Assert
assert.Error(t, err)
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,6 @@ import (
)

const (
// EarlyExpiryTimeWindow set a short amount of time that will mark the credentials as expired, this can avoid
// calls being made with expired credentials. This value should not be too big that's greater than the default token
// expiry time. For example, the token expires after 30 min and we set it to 40 min which expires the token
// immediately. The value should also not be too small that it should trigger credential rotation before it expires.
EarlyExpiryTimeWindow = 1 * time.Minute

// ProviderName is the role provider name that is returned with credentials
ProviderName = "EC2IdentityRoleProvider"
iirCredentialsPath = "identity-credentials/ec2/security-credentials/ec2-instance"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,9 @@ func TestRetrieve_ReturnsCredentials(t *testing.T) {
mockIMDSClient.On("GetMetadata", iirCredentialsPath).Return(string(respJSONBytes), nil)

roleProvider := &IIRRoleProvider{
IMDSClient: mockIMDSClient,
ExpiryWindow: EarlyExpiryTimeWindow,
Config: &ssmConfig,
Log: logger,
IMDSClient: mockIMDSClient,
Config: &ssmConfig,
Log: logger,
}

result, err := roleProvider.Retrieve()
Expand Down
4 changes: 2 additions & 2 deletions common/identity/credentialproviders/mocks/IRemoteProvider.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,13 @@ func shouldRetryAwsRequest(err error) bool {
// Error will be returned if the request fails, or unable to extract
// the desired credentials.
func (m *onpremCredentialsProvider) Retrieve() (credentials.Value, error) {
return m.RemoteRetrieveWithContext(context.Background())
return m.RemoteRetrieve(context.Background())
}

// RemoteRetrieveWithContext retrieves OnPrem credentials from the SSM Auth service.
// RemoteRetrieve retrieves OnPrem credentials from the SSM Auth service.
// Error will be returned if the request fails, or unable to extract
// the desired credentials.
func (m *onpremCredentialsProvider) RemoteRetrieveWithContext(ctx context.Context) (credentials.Value, error) {
func (m *onpremCredentialsProvider) RemoteRetrieve(ctx context.Context) (credentials.Value, error) {
var err error
var roleCreds *ssm.RequestManagedInstanceRoleTokenOutput

Expand Down Expand Up @@ -165,13 +165,6 @@ func (m *onpremCredentialsProvider) RemoteRetrieveWithContext(ctx context.Contex
}, nil
}

// RemoteRetrieve retrieves OnPrem credentials from the SSM Auth service.
// Error will be returned if the request fails, or unable to extract
// the desired credentials.
func (m *onpremCredentialsProvider) RemoteRetrieve() (credentials.Value, error) {
return m.RemoteRetrieveWithContext(context.Background())
}

func (m *onpremCredentialsProvider) RemoteExpiresAt() time.Time {
return m.ExpiresAt()
}
Expand Down
1 change: 0 additions & 1 deletion common/identity/credentialproviders/ssmclient/ssmclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import (

// ISSMClient defines the functions needed for role providers send health pings to Systems Manager
type ISSMClient interface {
UpdateInstanceInformation(input *ssm.UpdateInstanceInformationInput) (*ssm.UpdateInstanceInformationOutput, error)
UpdateInstanceInformationWithContext(ctx context.Context, input *ssm.UpdateInstanceInformationInput, opts ...request.Option) (*ssm.UpdateInstanceInformationOutput, error)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ package ssmec2roleprovider
import (
"context"
"fmt"
"sync"
"time"

"github.com/aws/amazon-ssm-agent/agent/appconfig"
Expand All @@ -41,7 +40,6 @@ var (
getStoredPrivateKey = registration.PrivateKey
getStoredPublicKey = registration.PublicKey
getStoredPrivateKeyType = registration.PrivateKeyType
loadRegistrationLock = &sync.Mutex{}
)

// SSMEC2RoleProvider sends requests for credentials to systems manager signed with AWS SigV4
Expand Down Expand Up @@ -75,8 +73,6 @@ func (p *SSMEC2RoleProvider) isEC2InstanceRegistered() bool {
return false
}

loadRegistrationLock.Lock()
defer loadRegistrationLock.Unlock()
p.registrationInfo = registrationInfo
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,8 @@

package ssmec2roleprovider

const (
// ProviderName is the role provider name that is returned with credentials
ProviderName = "SSMEC2RoleProvider"
)
// ProviderName is the role provider name that is returned with credentials
const ProviderName = "SSMEC2RoleProvider"

// InstanceInfo contains information about current EC2 instance
type InstanceInfo struct {
Expand Down
2 changes: 1 addition & 1 deletion common/identity/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ type IAgentIdentity interface {

// Registrar identity registers the agent on startup
type Registrar interface {
RegisterWithContext(context.Context) error
Register(context.Context) error
}
type IInnerIdentityGetter interface {
GetInner() IAgentIdentityInner
Expand Down
4 changes: 2 additions & 2 deletions common/identity/mocks/Registrar.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 7 additions & 1 deletion common/runtimeconfig/identity_runtimeconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ package runtimeconfig
import (
"encoding/json"
"fmt"
"github.com/cenkalti/backoff/v4"
"time"

rch "github.com/aws/amazon-ssm-agent/common/runtimeconfig/runtimeconfighandler"
"github.com/cenkalti/backoff/v4"
)

const (
Expand Down Expand Up @@ -87,6 +87,12 @@ func (i *identityRuntimeConfigClient) GetConfigWithRetry() (out IdentityRuntimeC
// Attempts GetConfig up to 6 times with exponential backoff
backoffConfig.MaxElapsedTime = time.Second * 4
err = backoff.Retry(func() error {
if configExists, existsError := i.ConfigExists(); err != nil {
return fmt.Errorf("failed to check whether config extists. Err: %w", existsError)
} else if !configExists {
return nil
}

out, err = i.GetConfig()
return err
}, backoffConfig)
Expand Down
2 changes: 1 addition & 1 deletion core/app/credentialrefresher/credentialrefresher.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ func getBackoffRetryJitterSleepDuration(retryCount int) time.Duration {
func (c *credentialsRefresher) retrieveCredsWithRetry(ctx context.Context) (credentials.Value, bool) {
retryCount := 0
for {
creds, err := c.provider.RemoteRetrieveWithContext(ctx)
creds, err := c.provider.RemoteRetrieve(ctx)
if err == nil {
return creds, false
}
Expand Down
Loading

0 comments on commit 42a14ff

Please sign in to comment.