Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 89 additions & 93 deletions get_s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@ import (
"strings"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/credentials/ec2rolecreds"
"github.com/aws/aws-sdk-go-v2/feature/ec2/imds"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/hashicorp/aws-sdk-go-base/v2/endpoints"
)

// S3Getter is a Getter implementation that will download a module from
Expand Down Expand Up @@ -54,24 +55,27 @@ func (g *S3Getter) ClientMode(u *url.URL) (ClientMode, error) {
}

// List the object(s) at the given prefix
req := &s3.ListObjectsInput{
req := &s3.ListObjectsV2Input{
Bucket: aws.String(bucket),
Prefix: aws.String(path),
}
resp, err := client.ListObjectsWithContext(ctx, req)
if err != nil {
return 0, err
}

for _, o := range resp.Contents {
// Use file mode on exact match.
if *o.Key == path {
return ClientModeFile, nil
paginator := s3.NewListObjectsV2Paginator(client, req)
for paginator.HasMorePages() {
output, err := paginator.NextPage(ctx)
if err != nil {
return 0, err
}

// Use dir mode if child keys are found.
if strings.HasPrefix(*o.Key, path+"/") {
return ClientModeDir, nil
for _, o := range output.Contents {
// Use file mode on exact match.
if aws.ToString(o.Key) == path {
return ClientModeFile, nil
}

// Use dir mode if child keys are found.
if strings.HasPrefix(aws.ToString(o.Key), path+"/") {
return ClientModeDir, nil
}
}
}

Expand Down Expand Up @@ -119,28 +123,19 @@ func (g *S3Getter) Get(dst string, u *url.URL) error {
}

// List files in path, keep listing until no more objects are found
lastMarker := ""
hasMore := true
for hasMore {
req := &s3.ListObjectsInput{
Bucket: aws.String(bucket),
Prefix: aws.String(path),
}
if lastMarker != "" {
req.Marker = aws.String(lastMarker)
}

resp, err := client.ListObjectsWithContext(ctx, req)
req := &s3.ListObjectsV2Input{
Bucket: aws.String(bucket),
Prefix: aws.String(path),
}
paginator := s3.NewListObjectsV2Paginator(client, req)
for paginator.HasMorePages() {
output, err := paginator.NextPage(ctx)
if err != nil {
return err
}

hasMore = aws.BoolValue(resp.IsTruncated)

// Get each object storing each file relative to the destination path
for _, object := range resp.Contents {
lastMarker = aws.StringValue(object.Key)
objPath := aws.StringValue(object.Key)
for _, object := range output.Contents {
objPath := aws.ToString(object.Key)

// If the key ends with a backslash assume it is a directory and ignore
if strings.HasSuffix(objPath, "/") {
Expand Down Expand Up @@ -185,7 +180,7 @@ func (g *S3Getter) GetFile(dst string, u *url.URL) error {
return g.getObject(ctx, client, dst, bucket, path, version)
}

func (g *S3Getter) getObject(ctx context.Context, client *s3.S3, dst, bucket, key, version string) error {
func (g *S3Getter) getObject(ctx context.Context, client *s3.Client, dst, bucket, key, version string) error {
req := &s3.GetObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
Expand All @@ -194,7 +189,7 @@ func (g *S3Getter) getObject(ctx context.Context, client *s3.S3, dst, bucket, ke
req.VersionId = aws.String(version)
}

resp, err := client.GetObjectWithContext(ctx, req)
resp, err := client.GetObject(ctx, req)
if err != nil {
return err
}
Expand All @@ -208,57 +203,62 @@ func (g *S3Getter) getObject(ctx context.Context, client *s3.S3, dst, bucket, ke

if g.client != nil && g.client.ProgressListener != nil {
fn := filepath.Base(key)
body = g.client.ProgressListener.TrackProgress(fn, 0, *resp.ContentLength, resp.Body)
body = g.client.ProgressListener.TrackProgress(fn, 0, aws.ToInt64(resp.ContentLength), resp.Body)
}
defer func() { _ = body.Close() }()

// There is no limit set for the size of an object from S3
return copyReader(dst, body, 0666, g.client.umask(), 0)
}

func (g *S3Getter) getAWSConfig(region string, url *url.URL, creds *credentials.Credentials) (*aws.Config, error) {
conf := &aws.Config{}
metadataURLOverride := os.Getenv("AWS_METADATA_URL")
if creds == nil && metadataURLOverride != "" {
s, err := session.NewSession(&aws.Config{
Endpoint: aws.String(metadataURLOverride),
})
if err != nil {
return nil, err
}
func (g *S3Getter) getAWSConfig(region string, url *url.URL, staticCreds *credentials.StaticCredentialsProvider) (conf aws.Config, err error) {
var loadOptions []func(*config.LoadOptions) error
var creds aws.CredentialsProvider

creds = credentials.NewChainCredentials(
[]credentials.Provider{
&credentials.EnvProvider{},
&credentials.SharedCredentialsProvider{Filename: "", Profile: ""},
&ec2rolecreds.EC2RoleProvider{
Client: ec2metadata.New(s),
},
metadataURLOverride := os.Getenv("AWS_METADATA_URL")
if staticCreds == nil && metadataURLOverride != "" {
creds = ec2rolecreds.New(func(o *ec2rolecreds.Options) {
o.Client = imds.New(imds.Options{
Endpoint: metadataURLOverride,
ClientEnableState: imds.ClientEnabled,
})
})
} else if staticCreds != nil {
creds = staticCreds
}

if creds != nil {
conf.Endpoint = &url.Host
conf.S3ForcePathStyle = aws.Bool(true)
if url.Scheme == "http" {
conf.DisableSSL = aws.Bool(true)
}
loadOptions = append(loadOptions,
config.WithEC2IMDSClientEnableState(imds.ClientEnabled),
config.WithCredentialsProvider(creds),
config.WithEndpointResolverWithOptions(aws.EndpointResolverWithOptionsFunc(
func(service, region string, options ...interface{}) (aws.Endpoint, error) {
return aws.Endpoint{URL: url.Host}, nil
},
)))
}

conf.Credentials = creds
if region != "" {
conf.Region = aws.String(region)
loadOptions = append(loadOptions, config.WithRegion(region))
}

conf = conf.WithCredentialsChainVerboseErrors(true)
return conf, nil
return config.LoadDefaultConfig(g.Context(), loadOptions...)
}

func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, creds *credentials.Credentials, err error) {
func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, creds *credentials.StaticCredentialsProvider, err error) {
// This just check whether we are dealing with S3 or
// any other S3 compliant service. S3 has a predictable
// url as others do not
if strings.HasSuffix(u.Host, ".amazonaws.com") {
var awsDomain *string
for _, partition := range endpoints.DefaultPartitions() {
if strings.HasSuffix(u.Host, partition.DNSSuffix()) {
awsDomain = aws.String(partition.DNSSuffix())
break
}
}

if awsDomain != nil {
// Amazon S3 supports both virtual-hosted–style and path-style URLs to access a bucket, although path-style is deprecated
// In both cases few older regions supports dash-style region indication (s3-Region) even if AWS discourages their use.
// The same bucket could be reached with:
Expand All @@ -267,10 +267,10 @@ func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, c
// s3.amazonaws.com/bucket/path
// s3-region.amazonaws.com/bucket/path

hostParts := strings.Split(u.Host, ".")
hostParts := strings.Split(strings.TrimSuffix(u.Host, *awsDomain), ".")
switch len(hostParts) {
// path-style
case 3:
case 2:
// Parse the region out of the first part of the host
region = strings.TrimPrefix(strings.TrimPrefix(hostParts[0], "s3-"), "s3")
if region == "" {
Expand All @@ -284,7 +284,7 @@ func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, c
bucket = pathParts[1]
path = pathParts[2]
// vhost-style, dash region indication
case 4:
case 3:
// Parse the region out of the second part of the host
region = strings.TrimPrefix(strings.TrimPrefix(hostParts[1], "s3-"), "s3")
if region == "" {
Expand All @@ -299,7 +299,7 @@ func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, c
bucket = hostParts[0]
path = pathParts[1]
//vhost-style, dot region indication
case 5:
case 4:
region = hostParts[2]
pathParts := strings.SplitN(u.Path, "/", 2)
if len(pathParts) < 2 {
Expand All @@ -310,7 +310,7 @@ func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, c
path = pathParts[1]

}
if len(hostParts) < 3 || len(hostParts) > 5 {
if len(hostParts) < 2 || len(hostParts) > 4 {
err = fmt.Errorf("URL is not a valid S3 URL")
return
}
Expand All @@ -335,40 +335,36 @@ func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path, version string, c
_, hasAwsSecret := u.Query()["aws_access_key_secret"]
_, hasAwsToken := u.Query()["aws_access_token"]
if hasAwsId || hasAwsSecret || hasAwsToken {
creds = credentials.NewStaticCredentials(
provider := credentials.NewStaticCredentialsProvider(
u.Query().Get("aws_access_key_id"),
u.Query().Get("aws_access_key_secret"),
u.Query().Get("aws_access_token"),
)
creds = &provider
}

return
}

func (g *S3Getter) newS3Client(
region string, url *url.URL, creds *credentials.Credentials,
) (*s3.S3, error) {
var sess *session.Session
region string, url *url.URL, creds *credentials.StaticCredentialsProvider,
) (*s3.Client, error) {
var err error
var cfg aws.Config

if profile := url.Query().Get("aws_profile"); profile != "" {
var err error
sess, err = session.NewSessionWithOptions(session.Options{
Profile: profile,
SharedConfigState: session.SharedConfigEnable,
})
if err != nil {
return nil, err
}
cfg, err = config.LoadDefaultConfig(g.Context(),
config.WithSharedConfigProfile(profile),
)
} else {
config, err := g.getAWSConfig(region, url, creds)
if err != nil {
return nil, err
}
sess, err = session.NewSession(config)
if err != nil {
return nil, err
}
cfg, err = g.getAWSConfig(region, url, creds)
}

if err != nil {
return nil, err
}

return s3.New(sess), nil
return s3.NewFromConfig(cfg, func(opts *s3.Options) {
opts.UsePathStyle = true
}), nil
}
9 changes: 6 additions & 3 deletions get_s3_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
package getter

import (
"context"
"errors"
"net/url"
"os"
"path/filepath"
"testing"

"github.com/aws/aws-sdk-go/aws/awserr"
awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http"
)

// Note for external contributors: In order to run the S3 test suite, you will only be able to be run
Expand Down Expand Up @@ -87,7 +89,8 @@ func TestS3Getter_GetFile_badParams(t *testing.T) {
t.Fatalf("expected error, got none")
}

if reqerr, ok := err.(awserr.RequestFailure); !ok || reqerr.StatusCode() != 403 {
var respErr *awshttp.ResponseError
if errors.As(err, &respErr) && respErr.HTTPStatusCode() != 403 {
t.Fatalf("expected InvalidAccessKeyId error")
}
}
Expand Down Expand Up @@ -285,7 +288,7 @@ func TestS3Getter_Url(t *testing.T) {
return
}

credV, err := creds.Get()
credV, err := creds.Retrieve(context.Background())
if err != nil {
t.Fatalf("failed to get credentials: %s", err)
}
Expand Down
Loading