Skip to content

Commit b5b8766

Browse files
vishalbolludeliahu
authored andcommitted
Read and validate from buckets in any regions (#1059)
(cherry picked from commit 468dce0)
1 parent 657531d commit b5b8766

File tree

6 files changed

+87
-28
lines changed

6 files changed

+87
-28
lines changed

cli/local/api.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ func UpdateAPI(apiConfig *userconfig.API, cortexYAMLPath string, projectID strin
5050
if apiConfig.Predictor.Model != nil {
5151
localModelCache, err := CacheModel(*apiConfig.Predictor.Model, awsClient)
5252
if err != nil {
53-
return nil, "", errors.Wrap(err, userconfig.ModelKey, userconfig.PredictorKey, apiConfig.Identify())
53+
return nil, "", errors.Wrap(err, apiConfig.Identify(), userconfig.PredictorKey, userconfig.ModelKey)
5454
}
5555
apiSpec.LocalModelCache = localModelCache
5656
}

cli/local/model_cache.go

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,17 @@ import (
3636
func CacheModel(modelPath string, awsClient *aws.Client) (*spec.LocalModelCache, error) {
3737
localModelCache := spec.LocalModelCache{}
3838

39+
awsClientForBucket, err := aws.NewFromClientS3Path(modelPath, awsClient)
40+
if err != nil {
41+
return nil, err
42+
}
43+
3944
if strings.HasPrefix(modelPath, "s3://") {
4045
bucket, prefix, err := aws.SplitS3Path(modelPath)
4146
if err != nil {
4247
return nil, err
4348
}
44-
hash, err := awsClient.HashS3Dir(bucket, prefix, nil)
49+
hash, err := awsClientForBucket.HashS3Dir(bucket, prefix, nil)
4550
if err != nil {
4651
return nil, err
4752
}
@@ -61,13 +66,13 @@ func CacheModel(modelPath string, awsClient *aws.Client) (*spec.LocalModelCache,
6166
return &localModelCache, nil
6267
}
6368

64-
err := ResetModelCacheDir(modelDir)
69+
err = ResetModelCacheDir(modelDir)
6570
if err != nil {
6671
return nil, err
6772
}
6873

6974
if strings.HasPrefix(modelPath, "s3://") {
70-
err := downloadModel(modelPath, modelDir, awsClient)
75+
err := downloadModel(modelPath, modelDir, awsClientForBucket)
7176
if err != nil {
7277
return nil, err
7378
}
@@ -105,7 +110,7 @@ func CacheModel(modelPath string, awsClient *aws.Client) (*spec.LocalModelCache,
105110
return &localModelCache, nil
106111
}
107112

108-
func downloadModel(modelPath string, modelDir string, awsClient *aws.Client) error {
113+
func downloadModel(modelPath string, modelDir string, awsClientForBucket *aws.Client) error {
109114
fmt.Printf("○ downloading model %s ", modelPath)
110115
defer fmt.Print(" ✓\n")
111116
dotCron := cron.Run(print.Dot, nil, 2*time.Second)
@@ -118,7 +123,7 @@ func downloadModel(modelPath string, modelDir string, awsClient *aws.Client) err
118123

119124
if strings.HasSuffix(modelPath, ".zip") || strings.HasSuffix(modelPath, ".onnx") {
120125
localPath := filepath.Join(modelDir, filepath.Base(modelPath))
121-
err := awsClient.DownloadFileFromS3(bucket, prefix, localPath)
126+
err := awsClientForBucket.DownloadFileFromS3(bucket, prefix, localPath)
122127
if err != nil {
123128
return err
124129
}
@@ -134,7 +139,7 @@ func downloadModel(modelPath string, modelDir string, awsClient *aws.Client) err
134139
}
135140
} else {
136141
tfModelVersion := filepath.Base(prefix)
137-
err := awsClient.DownloadDirFromS3(bucket, prefix, filepath.Join(modelDir, tfModelVersion), true, nil)
142+
err := awsClientForBucket.DownloadDirFromS3(bucket, prefix, filepath.Join(modelDir, tfModelVersion), true, nil)
138143
if err != nil {
139144
return err
140145
}

pkg/lib/aws/aws.go

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"github.com/aws/aws-sdk-go/aws"
2424
"github.com/aws/aws-sdk-go/aws/credentials"
2525
"github.com/aws/aws-sdk-go/aws/session"
26+
"github.com/cortexlabs/cortex/pkg/lib/errors"
2627
)
2728

2829
type Client struct {
@@ -43,6 +44,22 @@ func NewFromCreds(region string, accessKeyID string, secretAccessKey string) (*C
4344
return New(region, creds)
4445
}
4546

47+
func NewFromClientS3Path(s3Path string, awsClient *Client) (*Client, error) {
48+
if !awsClient.IsAnonymous {
49+
if awsClient.AccessKeyID() == nil || awsClient.SecretAccessKey() == nil {
50+
return nil, ErrorUnexpectedMissingCredentials(awsClient.AccessKeyID(), awsClient.SecretAccessKey())
51+
}
52+
return NewFromCredsS3Path(s3Path, *awsClient.AccessKeyID(), *awsClient.SecretAccessKey())
53+
}
54+
55+
region, err := GetBucketRegionFromS3Path(s3Path)
56+
if err != nil {
57+
return nil, err
58+
}
59+
60+
return NewAnonymousClientWithRegion(region)
61+
}
62+
4663
func NewFromEnvS3Path(s3Path string) (*Client, error) {
4764
bucket, _, err := SplitS3Path(s3Path)
4865
if err != nil {
@@ -85,7 +102,7 @@ func New(region string, creds *credentials.Credentials) (*Client, error) {
85102
})
86103

87104
if err != nil {
88-
return nil, err
105+
return nil, errors.WithStack(err)
89106
}
90107

91108
return &Client{

pkg/lib/aws/errors.go

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,17 @@ import (
2525
)
2626

2727
const (
28-
ErrInvalidAWSCredentials = "aws.invalid_aws_credentials"
29-
ErrInvalidS3aPath = "aws.invalid_s3a_path"
30-
ErrInvalidS3Path = "aws.invalid_s3_path"
31-
ErrAuth = "aws.auth"
32-
ErrBucketInaccessible = "aws.bucket_inaccessible"
33-
ErrBucketNotFound = "aws.bucket_not_found"
34-
ErrInstanceTypeLimitIsZero = "aws.instance_type_limit_is_zero"
35-
ErrNoValidSpotPrices = "aws.no_valid_spot_prices"
36-
ErrReadCredentials = "aws.read_credentials"
37-
ErrECRExtractingCredentials = "aws.ecr_failed_credentials"
28+
ErrInvalidAWSCredentials = "aws.invalid_aws_credentials"
29+
ErrInvalidS3aPath = "aws.invalid_s3a_path"
30+
ErrInvalidS3Path = "aws.invalid_s3_path"
31+
ErrUnexpectedMissingCredentials = "aws.unexpected_missing_credentials"
32+
ErrAuth = "aws.auth"
33+
ErrBucketInaccessible = "aws.bucket_inaccessible"
34+
ErrBucketNotFound = "aws.bucket_not_found"
35+
ErrInstanceTypeLimitIsZero = "aws.instance_type_limit_is_zero"
36+
ErrNoValidSpotPrices = "aws.no_valid_spot_prices"
37+
ErrReadCredentials = "aws.read_credentials"
38+
ErrECRExtractingCredentials = "aws.ecr_failed_credentials"
3839
)
3940

4041
func IsNotFoundErr(err error) bool {
@@ -91,6 +92,22 @@ func ErrorInvalidS3Path(provided string) error {
9192
})
9293
}
9394

95+
func ErrorUnexpectedMissingCredentials(awsAccessKeyID *string, awsSecretAccessKey *string) error {
96+
var msg string
97+
if awsAccessKeyID == nil && awsSecretAccessKey == nil {
98+
msg = "aws access key id and aws secret access key are missing"
99+
} else if awsAccessKeyID == nil {
100+
msg = "aws access key id is missing"
101+
} else if awsSecretAccessKey == nil {
102+
msg = "aws secret access key is missing"
103+
}
104+
105+
return errors.WithStack(&errors.Error{
106+
Kind: ErrUnexpectedMissingCredentials,
107+
Message: msg,
108+
})
109+
}
110+
94111
func ErrorAuth() error {
95112
return errors.WithStack(&errors.Error{
96113
Kind: ErrAuth,

pkg/lib/aws/s3.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,15 @@ func IsValidS3aPath(s3aPath string) bool {
122122
return true
123123
}
124124

125+
func GetBucketRegionFromS3Path(s3Path string) (string, error) {
126+
bucket, _, err := SplitS3Path(s3Path)
127+
if err != nil {
128+
return "", err
129+
}
130+
131+
return GetBucketRegion(bucket)
132+
}
133+
125134
func GetBucketRegion(bucket string) (string, error) {
126135
sess := session.Must(session.NewSession()) // credentials are not necessary for this request, and will not be used
127136
region, err := s3manager.GetBucketRegion(aws.BackgroundContext(), sess, bucket, endpoints.UsWest2RegionID)

pkg/types/spec/validations.go

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -517,18 +517,24 @@ func validateTensorFlowPredictor(predictor *userconfig.Predictor, providerType t
517517
}
518518

519519
model := *predictor.Model
520+
521+
awsClientForBucket, err := aws.NewFromClientS3Path(model, awsClient)
522+
if err != nil {
523+
return errors.Wrap(err, userconfig.ModelKey)
524+
}
525+
520526
if strings.HasPrefix(model, "s3://") {
521527
model, err := cr.S3PathValidator(model)
522528
if err != nil {
523529
return errors.Wrap(err, userconfig.ModelKey)
524530
}
525531

526532
if strings.HasSuffix(model, ".zip") {
527-
if ok, err := awsClient.IsS3PathFile(model); err != nil || !ok {
533+
if ok, err := awsClientForBucket.IsS3PathFile(model); err != nil || !ok {
528534
return errors.Wrap(ErrorS3FileNotFound(model), userconfig.ModelKey)
529535
}
530536
} else {
531-
path, err := getTFServingExportFromS3Path(model, awsClient)
537+
path, err := getTFServingExportFromS3Path(model, awsClientForBucket)
532538
if err != nil {
533539
return errors.Wrap(err, userconfig.ModelKey)
534540
} else if path == "" {
@@ -584,13 +590,18 @@ func validateONNXPredictor(predictor *userconfig.Predictor, providerType types.P
584590
return errors.Wrap(ErrorInvalidONNXModelPath(), userconfig.ModelKey, model)
585591
}
586592

593+
awsClientForBucket, err := aws.NewFromClientS3Path(model, awsClient)
594+
if err != nil {
595+
return errors.Wrap(err, userconfig.ModelKey)
596+
}
597+
587598
if strings.HasPrefix(model, "s3://") {
588599
model, err := cr.S3PathValidator(model)
589600
if err != nil {
590601
return errors.Wrap(err, userconfig.ModelKey)
591602
}
592603

593-
if ok, err := awsClient.IsS3PathFile(model); err != nil || !ok {
604+
if ok, err := awsClientForBucket.IsS3PathFile(model); err != nil || !ok {
594605
return errors.Wrap(ErrorS3FileNotFound(model), userconfig.ModelKey)
595606
}
596607
} else {
@@ -620,8 +631,8 @@ func validateONNXPredictor(predictor *userconfig.Predictor, providerType types.P
620631
return nil
621632
}
622633

623-
func getTFServingExportFromS3Path(path string, awsClient *aws.Client) (string, error) {
624-
if isValidTensorFlowS3Directory(path, awsClient) {
634+
func getTFServingExportFromS3Path(path string, awsClientForBucket *aws.Client) (string, error) {
635+
if isValidTensorFlowS3Directory(path, awsClientForBucket) {
625636
return path, nil
626637
}
627638

@@ -630,7 +641,7 @@ func getTFServingExportFromS3Path(path string, awsClient *aws.Client) (string, e
630641
return "", err
631642
}
632643

633-
objects, err := awsClient.ListS3PathDir(path, false, pointer.Int64(1000))
644+
objects, err := awsClientForBucket.ListS3PathDir(path, false, pointer.Int64(1000))
634645
if err != nil {
635646
return "", err
636647
} else if len(objects) == 0 {
@@ -652,7 +663,7 @@ func getTFServingExportFromS3Path(path string, awsClient *aws.Client) (string, e
652663
}
653664

654665
possiblePath := "s3://" + filepath.Join(bucket, filepath.Join(keyParts[:len(keyParts)-1]...))
655-
if version >= highestVersion && isValidTensorFlowS3Directory(possiblePath, awsClient) {
666+
if version >= highestVersion && isValidTensorFlowS3Directory(possiblePath, awsClientForBucket) {
656667
highestVersion = version
657668
highestPath = possiblePath
658669
}
@@ -668,15 +679,15 @@ func getTFServingExportFromS3Path(path string, awsClient *aws.Client) (string, e
668679
// - variables/
669680
// - variables.index
670681
// - variables.data-00000-of-00001 (there are a variable number of these files)
671-
func isValidTensorFlowS3Directory(path string, awsClient *aws.Client) bool {
672-
if valid, err := awsClient.IsS3PathFile(
682+
func isValidTensorFlowS3Directory(path string, awsClientForBucket *aws.Client) bool {
683+
if valid, err := awsClientForBucket.IsS3PathFile(
673684
aws.JoinS3Path(path, "saved_model.pb"),
674685
aws.JoinS3Path(path, "variables/variables.index"),
675686
); err != nil || !valid {
676687
return false
677688
}
678689

679-
if valid, err := awsClient.IsS3PathPrefix(
690+
if valid, err := awsClientForBucket.IsS3PathPrefix(
680691
aws.JoinS3Path(path, "variables/variables.data-00000-of"),
681692
); err != nil || !valid {
682693
return false

0 commit comments

Comments
 (0)