|
4 | 4 | package awsutil
|
5 | 5 |
|
6 | 6 | import (
|
7 |
| - "errors" |
| 7 | + "context" |
8 | 8 | "fmt"
|
9 | 9 |
|
10 |
| - "github.com/aws/aws-sdk-go/aws/session" |
11 |
| - "github.com/aws/aws-sdk-go/service/iam" |
12 |
| - "github.com/aws/aws-sdk-go/service/iam/iamiface" |
13 |
| - "github.com/aws/aws-sdk-go/service/sts" |
14 |
| - "github.com/aws/aws-sdk-go/service/sts/stsiface" |
| 10 | + "github.com/aws/aws-sdk-go-v2/aws" |
| 11 | + "github.com/aws/aws-sdk-go-v2/service/iam" |
| 12 | + "github.com/aws/aws-sdk-go-v2/service/sts" |
15 | 13 | )
|
16 | 14 |
|
17 | 15 | // IAMAPIFunc is a factory function for returning an IAM interface,
|
18 |
| -// useful for supplying mock interfaces for testing IAM. The session |
19 |
| -// is passed into the function in the same way as done with the |
20 |
| -// standard iam.New() constructor. |
21 |
| -type IAMAPIFunc func(sess *session.Session) (iamiface.IAMAPI, error) |
| 16 | +// useful for supplying mock interfaces for testing IAM. |
| 17 | +type IAMAPIFunc func(awsConfig *aws.Config) (IAMClient, error) |
| 18 | + |
| 19 | +// IAMClient represents an iam.Client |
| 20 | +type IAMClient interface { |
| 21 | + CreateAccessKey(context.Context, *iam.CreateAccessKeyInput, ...func(*iam.Options)) (*iam.CreateAccessKeyOutput, error) |
| 22 | + DeleteAccessKey(context.Context, *iam.DeleteAccessKeyInput, ...func(*iam.Options)) (*iam.DeleteAccessKeyOutput, error) |
| 23 | + ListAccessKeys(context.Context, *iam.ListAccessKeysInput, ...func(*iam.Options)) (*iam.ListAccessKeysOutput, error) |
| 24 | + GetUser(context.Context, *iam.GetUserInput, ...func(*iam.Options)) (*iam.GetUserOutput, error) |
| 25 | +} |
22 | 26 |
|
23 | 27 | // STSAPIFunc is a factory function for returning a STS interface,
|
24 |
| -// useful for supplying mock interfaces for testing STS. The session |
25 |
| -// is passed into the function in the same way as done with the |
26 |
| -// standard sts.New() constructor. |
27 |
| -type STSAPIFunc func(sess *session.Session) (stsiface.STSAPI, error) |
| 28 | +// useful for supplying mock interfaces for testing STS. |
| 29 | +type STSAPIFunc func(awsConfig *aws.Config) (STSClient, error) |
| 30 | + |
| 31 | +// STSClient represents an sts.Client |
| 32 | +type STSClient interface { |
| 33 | + AssumeRole(context.Context, *sts.AssumeRoleInput, ...func(*sts.Options)) (*sts.AssumeRoleOutput, error) |
| 34 | + GetCallerIdentity(context.Context, *sts.GetCallerIdentityInput, ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) |
| 35 | +} |
28 | 36 |
|
29 | 37 | // IAMClient returns an IAM client.
|
30 | 38 | //
|
31 |
| -// Supported options: WithSession, WithIAMAPIFunc. |
| 39 | +// Supported options: WithAwsConfig, WithIAMAPIFunc, WithIamEndpointResolver. |
32 | 40 | //
|
33 | 41 | // If WithIAMAPIFunc is supplied, the included function is used as
|
34 | 42 | // the IAM client constructor instead. This can be used for Mocking
|
35 | 43 | // the IAM API.
|
36 |
| -func (c *CredentialsConfig) IAMClient(opt ...Option) (iamiface.IAMAPI, error) { |
| 44 | +func (c *CredentialsConfig) IAMClient(ctx context.Context, opt ...Option) (IAMClient, error) { |
37 | 45 | opts, err := getOpts(opt...)
|
38 | 46 | if err != nil {
|
39 | 47 | return nil, fmt.Errorf("error reading options: %w", err)
|
40 | 48 | }
|
41 | 49 |
|
42 |
| - sess := opts.withAwsSession |
43 |
| - if sess == nil { |
44 |
| - sess, err = c.GetSession(opt...) |
| 50 | + cfg := opts.withAwsConfig |
| 51 | + if cfg == nil { |
| 52 | + cfg, err = c.GenerateCredentialChain(ctx, opt...) |
45 | 53 | if err != nil {
|
46 |
| - return nil, fmt.Errorf("error calling GetSession: %w", err) |
| 54 | + return nil, fmt.Errorf("error calling GenerateCredentialChain: %w", err) |
47 | 55 | }
|
48 | 56 | }
|
49 | 57 |
|
50 | 58 | if opts.withIAMAPIFunc != nil {
|
51 |
| - return opts.withIAMAPIFunc(sess) |
| 59 | + return opts.withIAMAPIFunc(cfg) |
52 | 60 | }
|
53 | 61 |
|
54 |
| - client := iam.New(sess) |
55 |
| - if client == nil { |
56 |
| - return nil, errors.New("could not obtain iam client from session") |
| 62 | + var iamOpts []func(*iam.Options) |
| 63 | + if c.IAMEndpointResolver != nil { |
| 64 | + iamOpts = append(iamOpts, iam.WithEndpointResolverV2(c.IAMEndpointResolver)) |
57 | 65 | }
|
58 | 66 |
|
59 |
| - return client, nil |
| 67 | + return iam.NewFromConfig(*cfg, iamOpts...), nil |
60 | 68 | }
|
61 | 69 |
|
62 | 70 | // STSClient returns a STS client.
|
63 | 71 | //
|
64 |
| -// Supported options: WithSession, WithSTSAPIFunc. |
| 72 | +// Supported options: WithAwsConfig, WithSTSAPIFunc, WithStsEndpointResolver. |
65 | 73 | //
|
66 | 74 | // If WithSTSAPIFunc is supplied, the included function is used as
|
67 | 75 | // the STS client constructor instead. This can be used for Mocking
|
68 | 76 | // the STS API.
|
69 |
| -func (c *CredentialsConfig) STSClient(opt ...Option) (stsiface.STSAPI, error) { |
| 77 | +func (c *CredentialsConfig) STSClient(ctx context.Context, opt ...Option) (STSClient, error) { |
70 | 78 | opts, err := getOpts(opt...)
|
71 | 79 | if err != nil {
|
72 | 80 | return nil, fmt.Errorf("error reading options: %w", err)
|
73 | 81 | }
|
74 | 82 |
|
75 |
| - sess := opts.withAwsSession |
76 |
| - if sess == nil { |
77 |
| - sess, err = c.GetSession(opt...) |
| 83 | + cfg := opts.withAwsConfig |
| 84 | + if cfg == nil { |
| 85 | + cfg, err = c.GenerateCredentialChain(ctx, opt...) |
78 | 86 | if err != nil {
|
79 |
| - return nil, fmt.Errorf("error calling GetSession: %w", err) |
| 87 | + return nil, fmt.Errorf("error calling GenerateCredentialChain: %w", err) |
80 | 88 | }
|
81 | 89 | }
|
82 | 90 |
|
83 | 91 | if opts.withSTSAPIFunc != nil {
|
84 |
| - return opts.withSTSAPIFunc(sess) |
| 92 | + return opts.withSTSAPIFunc(cfg) |
85 | 93 | }
|
86 | 94 |
|
87 |
| - client := sts.New(sess) |
88 |
| - if client == nil { |
89 |
| - return nil, errors.New("could not obtain sts client from session") |
| 95 | + var stsOpts []func(*sts.Options) |
| 96 | + if c.STSEndpointResolver != nil { |
| 97 | + stsOpts = append(stsOpts, sts.WithEndpointResolverV2(c.STSEndpointResolver)) |
90 | 98 | }
|
91 | 99 |
|
92 |
| - return client, nil |
| 100 | + return sts.NewFromConfig(*cfg, stsOpts...), nil |
93 | 101 | }
|
0 commit comments