Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ linters:
- "!**/pkg/auth/factory/**"
- "!**/pkg/auth/types/aws_credentials.go"
- "!**/pkg/auth/types/github_oidc_credentials.go"
- "!**/internal/aws_utils/**"
- "$test"
deny:
# AWS: Identity and auth-related SDKs
Expand Down
1 change: 1 addition & 0 deletions errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ var (
ErrInvalidTerraformSingleComponentAndMultiComponentFlags = errors.New("the single-component flags (`--from-plan`, `--planfile`) can't be used with the multi-component (bulk operations) flags (`--affected`, `--all`, `--query`, `--components`)")

ErrYamlFuncInvalidArguments = errors.New("invalid number of arguments in the Atmos YAML function")
ErrAwsGetCallerIdentity = errors.New("failed to get AWS caller identity")
ErrDescribeComponent = errors.New("failed to describe component")
ErrReadTerraformState = errors.New("failed to read Terraform state")
ErrEvaluateTerraformBackendVariable = errors.New("failed to evaluate terraform backend variable")
Expand Down
53 changes: 52 additions & 1 deletion internal/aws_utils/aws_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func LoadAWSConfigWithAuth(
baseCfg, err := config.LoadDefaultConfig(ctx, cfgOpts...)
if err != nil {
log.Debug("Failed to load AWS config", "error", err)
return aws.Config{}, fmt.Errorf("%w: %v", errUtils.ErrLoadAwsConfig, err)
return aws.Config{}, fmt.Errorf("%w: %w", errUtils.ErrLoadAwsConfig, err)
}
log.Debug("Successfully loaded AWS SDK config", "region", baseCfg.Region)

Expand Down Expand Up @@ -126,3 +126,54 @@ func LoadAWSConfig(ctx context.Context, region string, roleArn string, assumeRol

return LoadAWSConfigWithAuth(ctx, region, roleArn, assumeRoleDuration, nil)
}

// AWSCallerIdentityResult holds the result of GetAWSCallerIdentity.
type AWSCallerIdentityResult struct {
Account string
Arn string
UserID string
Region string
}

// GetAWSCallerIdentity retrieves AWS caller identity using STS GetCallerIdentity API.
// Returns account ID, ARN, user ID, and region.
// This function keeps AWS SDK STS imports contained within aws_utils package.
func GetAWSCallerIdentity(
ctx context.Context,
region string,
roleArn string,
assumeRoleDuration time.Duration,
authContext *schema.AWSAuthContext,
) (*AWSCallerIdentityResult, error) {
defer perf.Track(nil, "aws_utils.GetAWSCallerIdentity")()

// Load AWS config.
cfg, err := LoadAWSConfigWithAuth(ctx, region, roleArn, assumeRoleDuration, authContext)
if err != nil {
return nil, err
}

// Create STS client and get caller identity.
stsClient := sts.NewFromConfig(cfg)
output, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{})
if err != nil {
return nil, fmt.Errorf("%w: %w", errUtils.ErrAwsGetCallerIdentity, err)
}

result := &AWSCallerIdentityResult{
Region: cfg.Region,
}

// Extract values from pointers.
if output.Account != nil {
result.Account = *output.Account
}
if output.Arn != nil {
result.Arn = *output.Arn
}
if output.UserId != nil {
result.UserID = *output.UserId
}

return result, nil
}
163 changes: 163 additions & 0 deletions internal/exec/aws_getter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
package exec

import (
"context"
"fmt"
"sync"

awsUtils "github.com/cloudposse/atmos/internal/aws_utils"
log "github.com/cloudposse/atmos/pkg/logger"
"github.com/cloudposse/atmos/pkg/perf"
"github.com/cloudposse/atmos/pkg/schema"
)

// AWSCallerIdentity holds the information returned by AWS STS GetCallerIdentity.
type AWSCallerIdentity struct {
Account string
Arn string
UserID string
Region string // The AWS region from the loaded config.
}

// AWSGetter provides an interface for retrieving AWS caller identity information.
// This interface enables dependency injection and testability.
//
//go:generate go run go.uber.org/mock/mockgen@v0.6.0 -source=$GOFILE -destination=mock_aws_getter_test.go -package=exec
type AWSGetter interface {
// GetCallerIdentity retrieves the AWS caller identity for the current credentials.
// Returns the account ID, ARN, and user ID of the calling identity.
GetCallerIdentity(
ctx context.Context,
atmosConfig *schema.AtmosConfiguration,
authContext *schema.AWSAuthContext,
) (*AWSCallerIdentity, error)
}

// defaultAWSGetter is the production implementation that uses real AWS SDK calls.
type defaultAWSGetter struct{}

// GetCallerIdentity retrieves the AWS caller identity using the STS GetCallerIdentity API.
func (d *defaultAWSGetter) GetCallerIdentity(
ctx context.Context,
atmosConfig *schema.AtmosConfiguration,
authContext *schema.AWSAuthContext,
) (*AWSCallerIdentity, error) {
defer perf.Track(atmosConfig, "exec.AWSGetter.GetCallerIdentity")()

log.Debug("Getting AWS caller identity")

// Use the aws_utils helper to get caller identity (keeps AWS SDK imports in aws_utils).
result, err := awsUtils.GetAWSCallerIdentity(ctx, "", "", 0, authContext)
if err != nil {
return nil, err // Error already wrapped by aws_utils.
}

identity := &AWSCallerIdentity{
Account: result.Account,
Arn: result.Arn,
UserID: result.UserID,
Region: result.Region,
}

log.Debug("Retrieved AWS caller identity",
"account", identity.Account,
"arn", identity.Arn,
"user_id", identity.UserID,
"region", identity.Region,
)

return identity, nil
}

// awsGetter is the global instance used by YAML functions.
// This allows test code to replace it with a mock.
var awsGetter AWSGetter = &defaultAWSGetter{}

// SetAWSGetter allows tests to inject a mock AWSGetter.
// Returns a function to restore the original getter.
func SetAWSGetter(getter AWSGetter) func() {
defer perf.Track(nil, "exec.SetAWSGetter")()

original := awsGetter
awsGetter = getter
return func() {
awsGetter = original
}
}

// cachedAWSIdentity holds the cached AWS caller identity.
// The cache is per-CLI-invocation (stored in memory) to avoid repeated STS calls.
type cachedAWSIdentity struct {
identity *AWSCallerIdentity
err error
}

var (
awsIdentityCache map[string]*cachedAWSIdentity
awsIdentityCacheMu sync.RWMutex
)

func init() {
awsIdentityCache = make(map[string]*cachedAWSIdentity)
}

// getCacheKey generates a cache key based on the auth context.
// Different auth contexts (different credentials) get different cache entries.
func getCacheKey(authContext *schema.AWSAuthContext) string {
if authContext == nil {
return "default"
}
return fmt.Sprintf("%s:%s", authContext.Profile, authContext.CredentialsFile)
}
Comment on lines +104 to +111
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

ast-grep --pattern $'type AWSAuthContext struct {
  $$$
}'

Repository: cloudposse/atmos

Length of output: 1228


🏁 Script executed:

cat -n internal/exec/aws_getter.go | head -150

Repository: cloudposse/atmos

Length of output: 5752


🏁 Script executed:

find . -type f -name "*.go" | xargs grep -l "GetAWSCallerIdentity" | grep aws_utils

Repository: cloudposse/atmos

Length of output: 93


🏁 Script executed:

cat -n internal/aws_utils/aws_utils.go | head -200

Repository: cloudposse/atmos

Length of output: 6823


Add ConfigFile to cache key generation.

The cache key only uses Profile and CredentialsFile, but ConfigFile is also used during AWS credential loading (passed to config.WithSharedConfigFiles()). Different ConfigFile values can result in different AWS configurations and may affect identity resolution. Update getCacheKey() to include ConfigFile in the cache key to ensure different configurations don't incorrectly share cache entries.

Additionally, consider whether Region should be included, as it's used during config loading when not explicitly provided (line 82-83 in aws_utils.go).

🤖 Prompt for AI Agents
In internal/exec/aws_getter.go around lines 104 to 111, the cache key currently
only includes Profile and CredentialsFile which can cause cache collisions when
different shared AWS config files are used; update getCacheKey to also
incorporate authContext.ConfigFile (and include Region if authContext.Region may
influence loaded config) so the returned key reflects
Profile:CredentialsFile:ConfigFile (and optionally :Region) to ensure distinct
auth contexts produce distinct cache entries.


// getAWSCallerIdentityCached retrieves the AWS caller identity with caching.
// Results are cached per auth context to avoid repeated STS calls within the same CLI invocation.
func getAWSCallerIdentityCached(
ctx context.Context,
atmosConfig *schema.AtmosConfiguration,
authContext *schema.AWSAuthContext,
) (*AWSCallerIdentity, error) {
defer perf.Track(atmosConfig, "exec.getAWSCallerIdentityCached")()

cacheKey := getCacheKey(authContext)

// Check cache first (read lock).
awsIdentityCacheMu.RLock()
if cached, ok := awsIdentityCache[cacheKey]; ok {
awsIdentityCacheMu.RUnlock()
log.Debug("Using cached AWS caller identity", "cache_key", cacheKey)
return cached.identity, cached.err
}
awsIdentityCacheMu.RUnlock()

// Cache miss - acquire write lock and fetch.
awsIdentityCacheMu.Lock()
defer awsIdentityCacheMu.Unlock()

// Double-check after acquiring write lock.
if cached, ok := awsIdentityCache[cacheKey]; ok {
log.Debug("Using cached AWS caller identity (double-check)", "cache_key", cacheKey)
return cached.identity, cached.err
}

// Fetch from AWS.
identity, err := awsGetter.GetCallerIdentity(ctx, atmosConfig, authContext)

// Cache the result (including errors to avoid repeated failed calls).
awsIdentityCache[cacheKey] = &cachedAWSIdentity{
identity: identity,
err: err,
}

return identity, err
}

// ClearAWSIdentityCache clears the AWS identity cache.
// This is useful in tests or when credentials change during execution.
func ClearAWSIdentityCache() {
defer perf.Track(nil, "exec.ClearAWSIdentityCache")()

awsIdentityCacheMu.Lock()
defer awsIdentityCacheMu.Unlock()
awsIdentityCache = make(map[string]*cachedAWSIdentity)
}
151 changes: 151 additions & 0 deletions internal/exec/yaml_func_aws.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
package exec

import (
"context"

errUtils "github.com/cloudposse/atmos/errors"
log "github.com/cloudposse/atmos/pkg/logger"
"github.com/cloudposse/atmos/pkg/perf"
"github.com/cloudposse/atmos/pkg/schema"
u "github.com/cloudposse/atmos/pkg/utils"
)

const (
execAWSYAMLFunction = "Executing Atmos YAML function"
invalidYAMLFunction = "Invalid YAML function"
failedGetIdentity = "Failed to get AWS caller identity"
functionKey = "function"
)

// processTagAwsValue is a shared helper for AWS YAML functions.
// It validates the input tag, retrieves AWS caller identity, and returns the requested value.
func processTagAwsValue(
atmosConfig *schema.AtmosConfiguration,
input string,
expectedTag string,
stackInfo *schema.ConfigAndStacksInfo,
extractor func(*AWSCallerIdentity) string,
) any {
log.Debug(execAWSYAMLFunction, functionKey, input)

// Validate the tag matches expected.
if input != expectedTag {
log.Error(invalidYAMLFunction, functionKey, input, "expected", expectedTag)
errUtils.CheckErrorPrintAndExit(errUtils.ErrYamlFuncInvalidArguments, "", "")
return nil
}

// Get auth context from stack info if available.
var authContext *schema.AWSAuthContext
if stackInfo != nil && stackInfo.AuthContext != nil && stackInfo.AuthContext.AWS != nil {
authContext = stackInfo.AuthContext.AWS
}

// Get the AWS caller identity (cached).
ctx := context.Background()
identity, err := getAWSCallerIdentityCached(ctx, atmosConfig, authContext)
if err != nil {
log.Error(failedGetIdentity, "error", err)
errUtils.CheckErrorPrintAndExit(err, "", "")
return nil
}

// Extract the requested value.
return extractor(identity)
}

// processTagAwsAccountID processes the !aws.account_id YAML function.
// It returns the AWS account ID of the current caller identity.
// The function takes no parameters.
//
// Usage in YAML:
//
// account_id: !aws.account_id
func processTagAwsAccountID(
atmosConfig *schema.AtmosConfiguration,
input string,
stackInfo *schema.ConfigAndStacksInfo,
) any {
defer perf.Track(atmosConfig, "exec.processTagAwsAccountID")()

result := processTagAwsValue(atmosConfig, input, u.AtmosYamlFuncAwsAccountID, stackInfo, func(id *AWSCallerIdentity) string {
return id.Account
})

if result != nil {
log.Debug("Resolved !aws.account_id", "account_id", result)
}
return result
}

// processTagAwsCallerIdentityArn processes the !aws.caller_identity_arn YAML function.
// It returns the ARN of the current AWS caller identity.
// The function takes no parameters.
//
// Usage in YAML:
//
// caller_arn: !aws.caller_identity_arn
func processTagAwsCallerIdentityArn(
atmosConfig *schema.AtmosConfiguration,
input string,
stackInfo *schema.ConfigAndStacksInfo,
) any {
defer perf.Track(atmosConfig, "exec.processTagAwsCallerIdentityArn")()

result := processTagAwsValue(atmosConfig, input, u.AtmosYamlFuncAwsCallerIdentityArn, stackInfo, func(id *AWSCallerIdentity) string {
return id.Arn
})

if result != nil {
log.Debug("Resolved !aws.caller_identity_arn", "arn", result)
}
return result
}

// processTagAwsCallerIdentityUserID processes the !aws.caller_identity_user_id YAML function.
// It returns the unique user ID of the current AWS caller identity.
// The function takes no parameters.
//
// Usage in YAML:
//
// user_id: !aws.caller_identity_user_id
func processTagAwsCallerIdentityUserID(
atmosConfig *schema.AtmosConfiguration,
input string,
stackInfo *schema.ConfigAndStacksInfo,
) any {
defer perf.Track(atmosConfig, "exec.processTagAwsCallerIdentityUserID")()

result := processTagAwsValue(atmosConfig, input, u.AtmosYamlFuncAwsCallerIdentityUserID, stackInfo, func(id *AWSCallerIdentity) string {
return id.UserID
})

if result != nil {
log.Debug("Resolved !aws.caller_identity_user_id", "user_id", result)
}
return result
}

// processTagAwsRegion processes the !aws.region YAML function.
// It returns the AWS region from the current configuration.
// The function takes no parameters.
//
// Usage in YAML:
//
// region: !aws.region
func processTagAwsRegion(
atmosConfig *schema.AtmosConfiguration,
input string,
stackInfo *schema.ConfigAndStacksInfo,
) any {
defer perf.Track(atmosConfig, "exec.processTagAwsRegion")()

result := processTagAwsValue(atmosConfig, input, u.AtmosYamlFuncAwsRegion, stackInfo, func(id *AWSCallerIdentity) string {
return id.Region
})

if result != nil {
log.Debug("Resolved !aws.region", "region", result)
}
return result
}
Loading
Loading