Skip to content

Commit 8e033a0

Browse files
committed
Round #4
1 parent f075695 commit 8e033a0

File tree

6 files changed

+122
-7
lines changed

6 files changed

+122
-7
lines changed

cli/cmd/errors.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ const (
6969
ErrClusterAlreadyDeleted = "cli.cluster_already_deleted"
7070
ErrFailedClusterStatus = "cli.failed_cluster_status"
7171
ErrClusterDoesNotExist = "cli.cluster_does_not_exist"
72+
ErrTensorFlowDirTooManyFiles = "cli.tensorflow_dir_too_many_files"
7273
)
7374

7475
func ErrorCLINotConfigured(env string) error {
@@ -299,3 +300,10 @@ func ErrorFailedClusterStatus(status clusterstate.Status, clusterName string, re
299300
Message: fmt.Sprintf("cluster %s in %s encountered an unexpected status %s, please try to delete the cluster with `cortex cluster down` or delete the cloudformation stacks manually in your AWS console %s", clusterName, region, string(status), getCloudFormationURL(clusterName, region)),
300301
})
301302
}
303+
304+
func ErrorTensorFlowDirTooManyFiles(count int32) error {
305+
return errors.WithStack(&errors.Error{
306+
Kind: ErrTensorFlowDirTooManyFiles,
307+
Message: fmt.Sprintf("more than %d many files found in tensorflow directory", count),
308+
})
309+
}

cli/cmd/local.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@ import (
55
"fmt"
66
"path"
77
"path/filepath"
8+
"strings"
9+
"time"
810

911
"github.com/cortexlabs/cortex/cli/local"
12+
"github.com/cortexlabs/cortex/pkg/lib/aws"
1013
"github.com/cortexlabs/cortex/pkg/lib/debug"
1114
"github.com/cortexlabs/cortex/pkg/lib/errors"
1215
"github.com/cortexlabs/cortex/pkg/lib/exit"
@@ -18,6 +21,7 @@ import (
1821
"github.com/cortexlabs/cortex/pkg/lib/zip"
1922
"github.com/cortexlabs/cortex/pkg/operator/schema"
2023
"github.com/cortexlabs/cortex/pkg/types/spec"
24+
"github.com/cortexlabs/cortex/pkg/types/userconfig"
2125
"github.com/docker/docker/api/types"
2226
dockertypes "github.com/docker/docker/api/types"
2327
"github.com/docker/docker/api/types/filters"
@@ -140,6 +144,78 @@ var localCmd = &cobra.Command{
140144
},
141145
}
142146

147+
func cacheModel(api *userconfig.API) {
148+
if strings.HasPrefix(*api.Predictor.Model, "s3://") {
149+
150+
} else {
151+
152+
}
153+
}
154+
155+
func cacheModelFromS3(api *userconfig.API) (string, error) {
156+
awsClient, err := aws.NewFromEnvS3Path(*api.Predictor.Model)
157+
if err != nil {
158+
return "", err
159+
}
160+
161+
s3Objects, err := awsClient.ListPathPrefix(*api.Predictor.Model, 1001)
162+
if err != nil {
163+
return "", err
164+
}
165+
166+
if len(s3Objects) == 1001 {
167+
return "", ErrorTensorFlowDirTooManyFiles(1000)
168+
}
169+
var mostRecentUpdateDate *time.Time
170+
for _, obj := range s3Objects {
171+
mostRecentUpdateDate = obj.LastModified
172+
}
173+
174+
modelPathHash := hash.String(*api.Predictor.Model)
175+
modelDir := filepath.Join(*api.Predictor.Model, modelPathHash)
176+
modelVersionDir := filepath.Join(modelDir, mostRecentUpdateDate.Format("2006-01-02T15:04:05"))
177+
178+
if files.IsFile(filepath.Join(modelVersionDir, "_SUCCESS")) {
179+
return modelVersionDir, nil
180+
}
181+
182+
err = files.DeleteDir(modelDir)
183+
if err != nil {
184+
return "", err
185+
}
186+
187+
_, err = files.CreateDirIfMissing(modelVersionDir)
188+
if err != nil {
189+
return "", nil
190+
}
191+
192+
bucket, fullPathKey, err := aws.SplitS3Path(*api.Predictor.Model)
193+
if err != nil {
194+
return "", err
195+
}
196+
for _, obj := range s3Objects {
197+
if *obj.Size == 0 { // TODO test creation of empty files
198+
continue
199+
}
200+
201+
if strings.HasSuffix(*obj.Key, "/") {
202+
continue
203+
}
204+
205+
localKey := (*obj.Key)[len(fullPathKey):]
206+
fileBytes, err := awsClient.ReadBytesFromS3(bucket, *obj.Key)
207+
if err != nil {
208+
return "", err
209+
}
210+
211+
err = files.WriteFile(fileBytes, filepath.Join(modelVersionDir, localKey))
212+
if err != nil {
213+
return "", err
214+
}
215+
}
216+
return modelVersionDir, nil
217+
}
218+
143219
var localGet = &cobra.Command{
144220
Use: "local-get",
145221
Short: "local an application",

cli/local/config.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ package local
1818

1919
import (
2020
"context"
21-
"fmt"
2221
"os"
2322
"path/filepath"
2423
"strings"
@@ -35,6 +34,7 @@ var _cachedDockerClient *dockerclient.Client
3534
var CWD string
3635
var LocalDir string
3736
var LocalWorkspace string
37+
var ModelCacheDir string
3838

3939
func init() {
4040
cwd, err := os.Getwd()
@@ -58,7 +58,13 @@ func init() {
5858
}
5959

6060
LocalWorkspace = filepath.Join(LocalDir, "local_workspace")
61-
fmt.Println(LocalWorkspace)
61+
err = os.MkdirAll(LocalWorkspace, os.ModePerm)
62+
if err != nil {
63+
err := errors.Wrap(err, "unable to write to home directory", LocalDir)
64+
exit.Error(err)
65+
}
66+
67+
ModelCacheDir = filepath.Join(LocalDir, "model_cache")
6268
err = os.MkdirAll(LocalWorkspace, os.ModePerm)
6369
if err != nil {
6470
err := errors.Wrap(err, "unable to write to home directory", LocalDir)

pkg/operator/operator/validations.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ func ValidateClusterAPIs(apis []userconfig.API, projectFileMap map[string][]byte
4242
}
4343

4444
for i := range apis {
45-
if err := spec.ValidateAPI(&apis[i], projectFileMap); err != nil {
45+
if err := spec.ValidateAPI(&apis[i], projectFileMap, "aws"); err != nil {
4646
return err
4747
}
4848
if err := validateK8s(&apis[i], config.Cluster, virtualServices, maxMem); err != nil {

pkg/types/spec/errors.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ const (
4949
ErrFieldNotSupportedByPredictorType = "spec.field_not_supported_by_predictor_type"
5050
ErrNoAvailableNodeComputeLimit = "spec.no_available_node_compute_limit"
5151
ErrCortexPrefixedEnvVarNotAllowed = "spec.cortex_prefixed_env_var_not_allowed"
52-
ErrLocalPathNotSupportedByProvider = "spec.local_path_not_supported_by_provider"
52+
ErrLocalPathNotSupportedByAWSProvider = "spec.local_path_not_supported_by_aws_provider"
5353
)
5454

5555
func ErrorMalformedConfig() error {
@@ -236,9 +236,9 @@ func ErrorCortexPrefixedEnvVarNotAllowed() error {
236236
})
237237
}
238238

239-
func ErrorLocalPathNotSupportedByProvider(provider string) error {
239+
func ErrorLocalModelPathNotSupportedByAWSProvider() error {
240240
return errors.WithStack(&errors.Error{
241-
Kind: ErrLocalPathNotSupportedByProvider,
242-
Message: fmt.Sprintf("environment variables starting with CORTEX_ are reserved"),
241+
Kind: ErrLocalPathNotSupportedByAWSProvider,
242+
Message: fmt.Sprintf("local model paths are not supported for aws provider, please specify an S3 path"),
243243
})
244244
}

pkg/types/spec/validations.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,23 @@ func validateTensorFlowPredictor(predictor *userconfig.Predictor, providerType s
469469
predictor.Model = pointer.String(path)
470470
}
471471
} else {
472+
if providerType == "aws" {
473+
return errors.Wrap(ErrorLocalModelPathNotSupportedByAWSProvider(), model, userconfig.ModelKey)
474+
}
472475

476+
if strings.HasPrefix(model, ".zip") {
477+
if err := files.CheckFile(model); err != nil {
478+
return errors.Wrap(err, userconfig.ModelKey)
479+
}
480+
} else {
481+
path, err := getTFServingExportFromLocalPath(model)
482+
if err != nil {
483+
return errors.Wrap(err, userconfig.ModelKey)
484+
} else if path == "" {
485+
return errors.Wrap(ErrorInvalidTensorFlowDir(model), userconfig.ModelKey)
486+
}
487+
predictor.Model = pointer.String(path)
488+
}
473489
}
474490

475491
return nil
@@ -492,7 +508,13 @@ func validateONNXPredictor(predictor *userconfig.Predictor, providerType string)
492508
return errors.Wrap(ErrorS3FileNotFound(model), userconfig.ModelKey)
493509
}
494510
} else {
511+
if providerType == "aws" {
512+
return errors.Wrap(ErrorLocalModelPathNotSupportedByAWSProvider(), model, userconfig.ModelKey)
513+
}
495514

515+
if err := files.CheckFile(model); err != nil {
516+
return errors.Wrap(err, userconfig.ModelKey)
517+
}
496518
}
497519

498520
if predictor.SignatureKey != nil {
@@ -567,6 +589,9 @@ func isValidTensorFlowS3Directory(path string, awsClient *aws.Client) bool {
567589
}
568590

569591
func getTFServingExportFromLocalPath(path string) (string, error) {
592+
if !files.IsDir(path) {
593+
return "", ErrorDirNotFoundOrEmpty(path)
594+
}
570595
paths, err := files.ListDirRecursive(path, true, files.IgnoreHiddenFiles, files.IgnoreHiddenFolders)
571596
if err != nil {
572597
return "", err

0 commit comments

Comments
 (0)