Skip to content

Commit ca45d16

Browse files
ostermanclaude
andcommitted
Add AWS YAML functions: !aws.account_id, !aws.caller_identity_arn, !aws.caller_identity_user_id, and !aws.region
Implement four new AWS YAML functions that retrieve AWS identity and configuration information using the STS GetCallerIdentity API. These functions are equivalent to Terragrunt's corresponding helper functions and provide seamless integration with Atmos authentication contexts. All functions share a unified caching mechanism to minimize API calls. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 38eb63c commit ca45d16

File tree

13 files changed

+1658
-18
lines changed

13 files changed

+1658
-18
lines changed

.golangci.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ linters:
6666
- "!**/pkg/auth/factory/**"
6767
- "!**/pkg/auth/types/aws_credentials.go"
6868
- "!**/pkg/auth/types/github_oidc_credentials.go"
69+
- "!**/internal/aws_utils/**"
6970
- "$test"
7071
deny:
7172
# AWS: Identity and auth-related SDKs

errors/errors.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ var (
8787
ErrInvalidTerraformSingleComponentAndMultiComponentFlags = errors.New("the single-component flags (`--from-plan`, `--planfile`) can't be used with the multi-component (bulk operations) flags (`--affected`, `--all`, `--query`, `--components`)")
8888

8989
ErrYamlFuncInvalidArguments = errors.New("invalid number of arguments in the Atmos YAML function")
90+
ErrAwsGetCallerIdentity = errors.New("failed to get AWS caller identity")
9091
ErrDescribeComponent = errors.New("failed to describe component")
9192
ErrReadTerraformState = errors.New("failed to read Terraform state")
9293
ErrEvaluateTerraformBackendVariable = errors.New("failed to evaluate terraform backend variable")

internal/aws_utils/aws_utils.go

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ func LoadAWSConfigWithAuth(
9696
baseCfg, err := config.LoadDefaultConfig(ctx, cfgOpts...)
9797
if err != nil {
9898
log.Debug("Failed to load AWS config", "error", err)
99-
return aws.Config{}, fmt.Errorf("%w: %v", errUtils.ErrLoadAwsConfig, err)
99+
return aws.Config{}, fmt.Errorf("%w: %w", errUtils.ErrLoadAwsConfig, err)
100100
}
101101
log.Debug("Successfully loaded AWS SDK config", "region", baseCfg.Region)
102102

@@ -126,3 +126,54 @@ func LoadAWSConfig(ctx context.Context, region string, roleArn string, assumeRol
126126

127127
return LoadAWSConfigWithAuth(ctx, region, roleArn, assumeRoleDuration, nil)
128128
}
129+
130+
// AWSCallerIdentityResult holds the result of GetAWSCallerIdentity.
131+
type AWSCallerIdentityResult struct {
132+
Account string
133+
Arn string
134+
UserID string
135+
Region string
136+
}
137+
138+
// GetAWSCallerIdentity retrieves AWS caller identity using STS GetCallerIdentity API.
139+
// Returns account ID, ARN, user ID, and region.
140+
// This function keeps AWS SDK STS imports contained within aws_utils package.
141+
func GetAWSCallerIdentity(
142+
ctx context.Context,
143+
region string,
144+
roleArn string,
145+
assumeRoleDuration time.Duration,
146+
authContext *schema.AWSAuthContext,
147+
) (*AWSCallerIdentityResult, error) {
148+
defer perf.Track(nil, "aws_utils.GetAWSCallerIdentity")()
149+
150+
// Load AWS config.
151+
cfg, err := LoadAWSConfigWithAuth(ctx, region, roleArn, assumeRoleDuration, authContext)
152+
if err != nil {
153+
return nil, err
154+
}
155+
156+
// Create STS client and get caller identity.
157+
stsClient := sts.NewFromConfig(cfg)
158+
output, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{})
159+
if err != nil {
160+
return nil, fmt.Errorf("%w: %w", errUtils.ErrAwsGetCallerIdentity, err)
161+
}
162+
163+
result := &AWSCallerIdentityResult{
164+
Region: cfg.Region,
165+
}
166+
167+
// Extract values from pointers.
168+
if output.Account != nil {
169+
result.Account = *output.Account
170+
}
171+
if output.Arn != nil {
172+
result.Arn = *output.Arn
173+
}
174+
if output.UserId != nil {
175+
result.UserID = *output.UserId
176+
}
177+
178+
return result, nil
179+
}

internal/exec/aws_getter.go

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
package exec
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"sync"
7+
8+
awsUtils "github.com/cloudposse/atmos/internal/aws_utils"
9+
log "github.com/cloudposse/atmos/pkg/logger"
10+
"github.com/cloudposse/atmos/pkg/perf"
11+
"github.com/cloudposse/atmos/pkg/schema"
12+
)
13+
14+
// AWSCallerIdentity holds the information returned by AWS STS GetCallerIdentity.
15+
type AWSCallerIdentity struct {
16+
Account string
17+
Arn string
18+
UserID string
19+
Region string // The AWS region from the loaded config.
20+
}
21+
22+
// AWSGetter provides an interface for retrieving AWS caller identity information.
23+
// This interface enables dependency injection and testability.
24+
//
25+
//go:generate go run go.uber.org/mock/mockgen@v0.6.0 -source=$GOFILE -destination=mock_aws_getter_test.go -package=exec
26+
type AWSGetter interface {
27+
// GetCallerIdentity retrieves the AWS caller identity for the current credentials.
28+
// Returns the account ID, ARN, and user ID of the calling identity.
29+
GetCallerIdentity(
30+
ctx context.Context,
31+
atmosConfig *schema.AtmosConfiguration,
32+
authContext *schema.AWSAuthContext,
33+
) (*AWSCallerIdentity, error)
34+
}
35+
36+
// defaultAWSGetter is the production implementation that uses real AWS SDK calls.
37+
type defaultAWSGetter struct{}
38+
39+
// GetCallerIdentity retrieves the AWS caller identity using the STS GetCallerIdentity API.
40+
func (d *defaultAWSGetter) GetCallerIdentity(
41+
ctx context.Context,
42+
atmosConfig *schema.AtmosConfiguration,
43+
authContext *schema.AWSAuthContext,
44+
) (*AWSCallerIdentity, error) {
45+
defer perf.Track(atmosConfig, "exec.AWSGetter.GetCallerIdentity")()
46+
47+
log.Debug("Getting AWS caller identity")
48+
49+
// Use the aws_utils helper to get caller identity (keeps AWS SDK imports in aws_utils).
50+
result, err := awsUtils.GetAWSCallerIdentity(ctx, "", "", 0, authContext)
51+
if err != nil {
52+
return nil, err // Error already wrapped by aws_utils.
53+
}
54+
55+
identity := &AWSCallerIdentity{
56+
Account: result.Account,
57+
Arn: result.Arn,
58+
UserID: result.UserID,
59+
Region: result.Region,
60+
}
61+
62+
log.Debug("Retrieved AWS caller identity",
63+
"account", identity.Account,
64+
"arn", identity.Arn,
65+
"user_id", identity.UserID,
66+
"region", identity.Region,
67+
)
68+
69+
return identity, nil
70+
}
71+
72+
// awsGetter is the global instance used by YAML functions.
73+
// This allows test code to replace it with a mock.
74+
var awsGetter AWSGetter = &defaultAWSGetter{}
75+
76+
// SetAWSGetter allows tests to inject a mock AWSGetter.
77+
// Returns a function to restore the original getter.
78+
func SetAWSGetter(getter AWSGetter) func() {
79+
defer perf.Track(nil, "exec.SetAWSGetter")()
80+
81+
original := awsGetter
82+
awsGetter = getter
83+
return func() {
84+
awsGetter = original
85+
}
86+
}
87+
88+
// cachedAWSIdentity holds the cached AWS caller identity.
89+
// The cache is per-CLI-invocation (stored in memory) to avoid repeated STS calls.
90+
type cachedAWSIdentity struct {
91+
identity *AWSCallerIdentity
92+
err error
93+
}
94+
95+
var (
96+
awsIdentityCache map[string]*cachedAWSIdentity
97+
awsIdentityCacheMu sync.RWMutex
98+
)
99+
100+
func init() {
101+
awsIdentityCache = make(map[string]*cachedAWSIdentity)
102+
}
103+
104+
// getCacheKey generates a cache key based on the auth context.
105+
// Different auth contexts (different credentials) get different cache entries.
106+
func getCacheKey(authContext *schema.AWSAuthContext) string {
107+
if authContext == nil {
108+
return "default"
109+
}
110+
return fmt.Sprintf("%s:%s", authContext.Profile, authContext.CredentialsFile)
111+
}
112+
113+
// getAWSCallerIdentityCached retrieves the AWS caller identity with caching.
114+
// Results are cached per auth context to avoid repeated STS calls within the same CLI invocation.
115+
func getAWSCallerIdentityCached(
116+
ctx context.Context,
117+
atmosConfig *schema.AtmosConfiguration,
118+
authContext *schema.AWSAuthContext,
119+
) (*AWSCallerIdentity, error) {
120+
defer perf.Track(atmosConfig, "exec.getAWSCallerIdentityCached")()
121+
122+
cacheKey := getCacheKey(authContext)
123+
124+
// Check cache first (read lock).
125+
awsIdentityCacheMu.RLock()
126+
if cached, ok := awsIdentityCache[cacheKey]; ok {
127+
awsIdentityCacheMu.RUnlock()
128+
log.Debug("Using cached AWS caller identity", "cache_key", cacheKey)
129+
return cached.identity, cached.err
130+
}
131+
awsIdentityCacheMu.RUnlock()
132+
133+
// Cache miss - acquire write lock and fetch.
134+
awsIdentityCacheMu.Lock()
135+
defer awsIdentityCacheMu.Unlock()
136+
137+
// Double-check after acquiring write lock.
138+
if cached, ok := awsIdentityCache[cacheKey]; ok {
139+
log.Debug("Using cached AWS caller identity (double-check)", "cache_key", cacheKey)
140+
return cached.identity, cached.err
141+
}
142+
143+
// Fetch from AWS.
144+
identity, err := awsGetter.GetCallerIdentity(ctx, atmosConfig, authContext)
145+
146+
// Cache the result (including errors to avoid repeated failed calls).
147+
awsIdentityCache[cacheKey] = &cachedAWSIdentity{
148+
identity: identity,
149+
err: err,
150+
}
151+
152+
return identity, err
153+
}
154+
155+
// ClearAWSIdentityCache clears the AWS identity cache.
156+
// This is useful in tests or when credentials change during execution.
157+
func ClearAWSIdentityCache() {
158+
defer perf.Track(nil, "exec.ClearAWSIdentityCache")()
159+
160+
awsIdentityCacheMu.Lock()
161+
defer awsIdentityCacheMu.Unlock()
162+
awsIdentityCache = make(map[string]*cachedAWSIdentity)
163+
}

internal/exec/yaml_func_aws.go

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
package exec
2+
3+
import (
4+
"context"
5+
6+
errUtils "github.com/cloudposse/atmos/errors"
7+
log "github.com/cloudposse/atmos/pkg/logger"
8+
"github.com/cloudposse/atmos/pkg/perf"
9+
"github.com/cloudposse/atmos/pkg/schema"
10+
u "github.com/cloudposse/atmos/pkg/utils"
11+
)
12+
13+
const (
14+
execAWSYAMLFunction = "Executing Atmos YAML function"
15+
invalidYAMLFunction = "Invalid YAML function"
16+
failedGetIdentity = "Failed to get AWS caller identity"
17+
functionKey = "function"
18+
)
19+
20+
// processTagAwsValue is a shared helper for AWS YAML functions.
21+
// It validates the input tag, retrieves AWS caller identity, and returns the requested value.
22+
func processTagAwsValue(
23+
atmosConfig *schema.AtmosConfiguration,
24+
input string,
25+
expectedTag string,
26+
stackInfo *schema.ConfigAndStacksInfo,
27+
extractor func(*AWSCallerIdentity) string,
28+
) any {
29+
log.Debug(execAWSYAMLFunction, functionKey, input)
30+
31+
// Validate the tag matches expected.
32+
if input != expectedTag {
33+
log.Error(invalidYAMLFunction, functionKey, input, "expected", expectedTag)
34+
errUtils.CheckErrorPrintAndExit(errUtils.ErrYamlFuncInvalidArguments, "", "")
35+
return nil
36+
}
37+
38+
// Get auth context from stack info if available.
39+
var authContext *schema.AWSAuthContext
40+
if stackInfo != nil && stackInfo.AuthContext != nil && stackInfo.AuthContext.AWS != nil {
41+
authContext = stackInfo.AuthContext.AWS
42+
}
43+
44+
// Get the AWS caller identity (cached).
45+
ctx := context.Background()
46+
identity, err := getAWSCallerIdentityCached(ctx, atmosConfig, authContext)
47+
if err != nil {
48+
log.Error(failedGetIdentity, "error", err)
49+
errUtils.CheckErrorPrintAndExit(err, "", "")
50+
return nil
51+
}
52+
53+
// Extract the requested value.
54+
return extractor(identity)
55+
}
56+
57+
// processTagAwsAccountID processes the !aws.account_id YAML function.
58+
// It returns the AWS account ID of the current caller identity.
59+
// The function takes no parameters.
60+
//
61+
// Usage in YAML:
62+
//
63+
// account_id: !aws.account_id
64+
func processTagAwsAccountID(
65+
atmosConfig *schema.AtmosConfiguration,
66+
input string,
67+
stackInfo *schema.ConfigAndStacksInfo,
68+
) any {
69+
defer perf.Track(atmosConfig, "exec.processTagAwsAccountID")()
70+
71+
result := processTagAwsValue(atmosConfig, input, u.AtmosYamlFuncAwsAccountID, stackInfo, func(id *AWSCallerIdentity) string {
72+
return id.Account
73+
})
74+
75+
if result != nil {
76+
log.Debug("Resolved !aws.account_id", "account_id", result)
77+
}
78+
return result
79+
}
80+
81+
// processTagAwsCallerIdentityArn processes the !aws.caller_identity_arn YAML function.
82+
// It returns the ARN of the current AWS caller identity.
83+
// The function takes no parameters.
84+
//
85+
// Usage in YAML:
86+
//
87+
// caller_arn: !aws.caller_identity_arn
88+
func processTagAwsCallerIdentityArn(
89+
atmosConfig *schema.AtmosConfiguration,
90+
input string,
91+
stackInfo *schema.ConfigAndStacksInfo,
92+
) any {
93+
defer perf.Track(atmosConfig, "exec.processTagAwsCallerIdentityArn")()
94+
95+
result := processTagAwsValue(atmosConfig, input, u.AtmosYamlFuncAwsCallerIdentityArn, stackInfo, func(id *AWSCallerIdentity) string {
96+
return id.Arn
97+
})
98+
99+
if result != nil {
100+
log.Debug("Resolved !aws.caller_identity_arn", "arn", result)
101+
}
102+
return result
103+
}
104+
105+
// processTagAwsCallerIdentityUserID processes the !aws.caller_identity_user_id YAML function.
106+
// It returns the unique user ID of the current AWS caller identity.
107+
// The function takes no parameters.
108+
//
109+
// Usage in YAML:
110+
//
111+
// user_id: !aws.caller_identity_user_id
112+
func processTagAwsCallerIdentityUserID(
113+
atmosConfig *schema.AtmosConfiguration,
114+
input string,
115+
stackInfo *schema.ConfigAndStacksInfo,
116+
) any {
117+
defer perf.Track(atmosConfig, "exec.processTagAwsCallerIdentityUserID")()
118+
119+
result := processTagAwsValue(atmosConfig, input, u.AtmosYamlFuncAwsCallerIdentityUserID, stackInfo, func(id *AWSCallerIdentity) string {
120+
return id.UserID
121+
})
122+
123+
if result != nil {
124+
log.Debug("Resolved !aws.caller_identity_user_id", "user_id", result)
125+
}
126+
return result
127+
}
128+
129+
// processTagAwsRegion processes the !aws.region YAML function.
130+
// It returns the AWS region from the current configuration.
131+
// The function takes no parameters.
132+
//
133+
// Usage in YAML:
134+
//
135+
// region: !aws.region
136+
func processTagAwsRegion(
137+
atmosConfig *schema.AtmosConfiguration,
138+
input string,
139+
stackInfo *schema.ConfigAndStacksInfo,
140+
) any {
141+
defer perf.Track(atmosConfig, "exec.processTagAwsRegion")()
142+
143+
result := processTagAwsValue(atmosConfig, input, u.AtmosYamlFuncAwsRegion, stackInfo, func(id *AWSCallerIdentity) string {
144+
return id.Region
145+
})
146+
147+
if result != nil {
148+
log.Debug("Resolved !aws.region", "region", result)
149+
}
150+
return result
151+
}

0 commit comments

Comments
 (0)