Skip to content

Commit

Permalink
Added RetrieveWithContext to ec2 and onprem role providers
Browse files Browse the repository at this point in the history
  • Loading branch information
gianniLesl committed Jul 3, 2023
1 parent 08e0bbf commit 5996b37
Show file tree
Hide file tree
Showing 18 changed files with 198 additions and 68 deletions.
4 changes: 2 additions & 2 deletions agent/ssm/authtokenrequest/authtokenrequest_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (
func TestSSMAuthTokenService_RequestManagedInstanceRoleToken_Success(t *testing.T) {
sdk := &mocks.ISsmSdk{}
response := &ssm.RequestManagedInstanceRoleTokenOutput{}
sdk.On("RequestManagedInstanceRoleToken", mock.Anything).Return(response, nil)
sdk.On("RequestManagedInstanceRoleTokenWithContext", mock.Anything, mock.Anything).Return(response, nil)
authTokenService := NewClient(sdk)
result, err := authTokenService.RequestManagedInstanceRoleToken("SomeFingerprint")
assert.NoError(t, err)
Expand All @@ -36,7 +36,7 @@ func TestSSMAuthTokenService_RequestManagedInstanceRoleToken_Success(t *testing.
func TestSSMAuthTokenService_UpdateManagedInstancePublicKey_Success(t *testing.T) {
sdk := &mocks.ISsmSdk{}
response := &ssm.UpdateManagedInstancePublicKeyOutput{}
sdk.On("UpdateManagedInstancePublicKey", mock.Anything).Return(response, nil)
sdk.On("UpdateManagedInstancePublicKeyWithContext", mock.Anything, mock.Anything).Return(response, nil)
authTokenService := NewClient(sdk)
result, err := authTokenService.UpdateManagedInstancePublicKey("publicKey", "publicKeyType")
assert.NoError(t, err)
Expand Down
10 changes: 5 additions & 5 deletions agent/updateutil/updateutil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -578,14 +578,14 @@ func TestUtility_setShareCredsEnvironment_SetsCommandAWSEnvironmentVariables_Whe
ctx.On("Identity").Return(agentIdentity)
ctx.On("Log").Return(log.NewMockLog())

remoteProvier := &mocks.IRemoteProvider{}
remoteProvier.On("SharesCredentials").Return(true)
remoteProvider := &mocks.IRemoteProvider{}
remoteProvider.On("SharesCredentials").Return(true)
expectedShareProfile := "SomeShareFileLocation"
expectedShareFile := "SomeShareFileLocation"
remoteProvier.On("ShareProfile").Return(expectedShareProfile)
remoteProvier.On("ShareFile").Return(expectedShareFile)
remoteProvider.On("ShareProfile").Return(expectedShareProfile)
remoteProvider.On("ShareFile").Return(expectedShareFile)
getRemoteProvider = func(agentIdentity identity.IAgentIdentity) (credentialproviders.IRemoteProvider, bool) {
return remoteProvier, true
return remoteProvider, true
}

utility := &Utility{
Expand Down
15 changes: 13 additions & 2 deletions common/identity/availableidentities/ec2/stubs/provider_stub.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package stubs

import (
"context"
"time"

"github.com/aws/amazon-ssm-agent/common/identity/credentialproviders/ec2roleprovider"
Expand All @@ -25,14 +26,24 @@ func (p *ProviderStub) SetExpiration(expiration time.Time, window time.Duration)
return
}

func (p *ProviderStub) Retrieve() (credentials.Value, error) {
func (p *ProviderStub) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
return credentials.Value{
ProviderName: p.ProviderName,
}, nil
}

func (p *ProviderStub) Retrieve() (credentials.Value, error) {
return p.RetrieveWithContext(context.Background())
}

func (p *ProviderStub) RemoteRetrieve() (credentials.Value, error) {
return p.Retrieve()
return p.RemoteRetrieveWithContext(context.Background())
}

func (p *ProviderStub) RemoteRetrieveWithContext(ctx context.Context) (credentials.Value, error) {
return credentials.Value{
ProviderName: p.ProviderName,
}, nil
}

func (p *ProviderStub) IsExpired() bool {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package credentialproviders

import (
"context"
"time"

"github.com/aws/aws-sdk-go/aws/credentials"
Expand All @@ -27,6 +28,6 @@ type IRemoteProvider interface {
ShareFile() string
SharesCredentials() bool
CredentialSource() string
RemoteRetrieve() (credentials.Value, error)
RemoteRetrieveWithContext(ctx context.Context) (credentials.Value, error)
RemoteExpiresAt() time.Time
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package ec2roleprovider

import (
"context"
"fmt"
"runtime"
"sync"
Expand Down Expand Up @@ -85,14 +86,14 @@ func (p *EC2RoleProvider) GetInnerProvider() IInnerProvider {
return p.InnerProviders.IPRProvider
}

// Retrieve returns shared credentials if specified in runtime config
// RetrieveWithContext returns shared credentials if specified in runtime config
// and returns instance profile role credentials otherwise.
// If neither can be retrieved then empty credentials are returned
func (p *EC2RoleProvider) Retrieve() (credentials.Value, error) {
func (p *EC2RoleProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
if runtimeConfig, err := p.RuntimeConfigClient.GetConfig(); err != nil {
p.Log.Errorf("Failed to read runtime config for ShareFile information")
} else if runtimeConfig.ShareFile != "" {
sharedCreds, err := p.InnerProviders.SharedCredentialsProvider.Retrieve()
sharedCreds, err := p.InnerProviders.SharedCredentialsProvider.RetrieveWithContext(ctx)
if err != nil {
err = fmt.Errorf("unable to load shared credentials. Err: %w", err)
p.Log.Error(err)
Expand All @@ -104,7 +105,7 @@ func (p *EC2RoleProvider) Retrieve() (credentials.Value, error) {
}

p.credentialSource = CredentialSourceEC2
iprCredentials, err := p.InnerProviders.IPRProvider.Retrieve()
iprCredentials, err := p.InnerProviders.IPRProvider.RetrieveWithContext(ctx)
if err != nil {
err = fmt.Errorf("failed to retrieve instance profile role credentials. Err: %w", err)
p.Log.Error(err)
Expand All @@ -115,7 +116,7 @@ func (p *EC2RoleProvider) Retrieve() (credentials.Value, error) {
}

// RemoteRetrieve uses network calls to retrieve credentials for EC2 instances
func (p *EC2RoleProvider) RemoteRetrieve() (credentials.Value, error) {
func (p *EC2RoleProvider) RemoteRetrieveWithContext(ctx context.Context) (credentials.Value, error) {
p.Log.Debug("Attempting to retrieve instance profile role")
if iprCredentials, err := p.iprCredentials(p.SsmEndpoint); err != nil {
errCode := sdkutil.GetAwsErrorCode(err)
Expand Down Expand Up @@ -144,6 +145,12 @@ func (p *EC2RoleProvider) RemoteRetrieve() (credentials.Value, error) {
return iprEmptyCredential, fmt.Errorf("no valid credentials could be retrieved for ec2 identity")
}

// Retrieve returns instance profile role credentials if it has sufficient systems manager permissions and
// returns ssm provided credentials otherwise. If neither can be retrieved then empty credentials are returned
func (p *EC2RoleProvider) Retrieve() (credentials.Value, error) {
return p.RetrieveWithContext(context.Background())
}

// iprCredentials retrieves instance profile role credentials and returns an error if the returned credentials cannot
// connect to Systems Manager
func (p *EC2RoleProvider) iprCredentials(ssmEndpoint string) (*credentials.Credentials, error) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package ec2roleprovider

import (
"context"
"time"

"github.com/aws/amazon-ssm-agent/common/identity/credentialproviders"
Expand Down Expand Up @@ -46,7 +47,8 @@ var (
type IInnerProvider interface {
credentials.Provider
credentials.Expirer

Retrieve() (credentials.Value, error)
RetrieveWithContext(ctx context.Context) (credentials.Value, error)
SetExpiration(expiration time.Time, window time.Duration)
}

Expand All @@ -64,4 +66,6 @@ type IEC2RoleProvider interface {
ShareFile() string
ShareProfile() string
SharesCredentials() bool
RetrieveWithContext(ctx context.Context) (credentials.Value, error)
RemoteRetrieveWithContext(ctx context.Context) (credentials.Value, error)
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package ec2roleprovider

import (
"context"
"fmt"
"sync"
"testing"
Expand Down Expand Up @@ -413,7 +414,7 @@ func TestEC2RoleProvider_RetrieveRemote_ReturnsEmptyCredentials(t *testing.T) {
ec2RoleProvider := arrangeRetrieveEmptyTest(j)

// Act
creds, err := ec2RoleProvider.RemoteRetrieve()
creds, err := ec2RoleProvider.RemoteRetrieveWithContext(context.Background())

//Assert
assert.Error(t, err)
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package stubs

import (
"context"
"time"

"github.com/aws/aws-sdk-go/aws/credentials"
Expand All @@ -12,7 +13,7 @@ type InnerProvider struct {
Expiry time.Time
}

func (p *InnerProvider) Retrieve() (credentials.Value, error) {
func (p *InnerProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
if p.RetrieveErr != nil {
return credentials.Value{}, p.RetrieveErr
}
Expand All @@ -22,6 +23,10 @@ func (p *InnerProvider) Retrieve() (credentials.Value, error) {
}, nil
}

func (p *InnerProvider) Retrieve() (credentials.Value, error) {
return p.RetrieveWithContext(context.Background())
}

func (p *InnerProvider) IsExpired() bool {
return p.RetrieveErr != nil
}
Expand Down
Loading

0 comments on commit 5996b37

Please sign in to comment.