Skip to content

Commit 4f9fd93

Browse files
Ivan Zhang authored and deliahu committed
Use model directories for TensorFlow (#323)
1 parent 085a435 commit 4f9fd93

File tree

9 files changed

+141
-56
lines changed

9 files changed

+141
-56
lines changed

docs/deployments/packaging-models.md

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,23 +20,20 @@ train_spec = tf.estimator.TrainSpec(train_input_fn, max_steps=1000)
2020
eval_spec = tf.estimator.EvalSpec(eval_input_fn, exporters=[exporter], name="estimator-eval")
2121

2222
tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)
23-
24-
# zip the estimator export dir (the exported path looks like iris/export/estimator/1562353043/)
25-
shutil.make_archive("tensorflow", "zip", os.path.join("iris/export/estimator"))
2623
```
2724

28-
Upload the zipped file to Amazon S3 using the AWS web console or CLI:
25+
Upload the exported version directory to Amazon S3 using the AWS web console or CLI:
2926

3027
```text
31-
$ aws s3 cp model.zip s3://my-bucket/model.zip
28+
$ aws s3 sync ./iris/export/estimator/156293432 s3://my-bucket/iris/156293432
3229
```
3330

3431
Reference your model in an `api`:
3532

3633
```yaml
3734
- kind: api
3835
name: my-api
39-
model: s3://my-bucket/model.zip
36+
model: s3://my-bucket/iris/156293432
4037
```
4138
4239
## ONNX

examples/iris/cortex.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
- kind: api
55
name: tensorflow
6-
model: s3://cortex-examples/iris/tensorflow.zip
6+
model: s3://cortex-examples/iris/tensorflow/1560263532
77
request_handler: handlers/tensorflow.py
88

99
- kind: api

examples/iris/models/tensorflow_model.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,3 @@ def my_model(features, labels, mode, params):
8888

8989
# the estimator export dir to upload (the exported path looks like iris_tf_export/export/estimator/1562353043/)
9090
estimator_dir = EXPORT_DIR + "/export/estimator"
91-
shutil.make_archive("tensorflow", "zip", os.path.join(estimator_dir))
92-
93-
# clean up
94-
shutil.rmtree(EXPORT_DIR)

examples/sentiment/cortex.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@
33

44
- kind: api
55
name: classifier
6-
model: s3://cortex-examples/sentiment/bert.zip
6+
model: s3://cortex-examples/sentiment/1565392692
77
request_handler: sentiment.py

pkg/lib/aws/s3.go

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ func IsS3PrefixExternal(bucket string, prefix string) (bool, error) {
318318
return hasPrefix, nil
319319
}
320320

321-
func IsS3FileExternal(bucket string, key string) (bool, error) {
321+
func IsS3FileExternal(bucket string, keys ...string) (bool, error) {
322322
region, err := GetBucketRegion(bucket)
323323
if err != nil {
324324
return false, err
@@ -328,17 +328,19 @@ func IsS3FileExternal(bucket string, key string) (bool, error) {
328328
Region: aws.String(region),
329329
}))
330330

331-
_, err = s3.New(sess).HeadObject(&s3.HeadObjectInput{
332-
Bucket: aws.String(bucket),
333-
Key: aws.String(key),
334-
})
331+
for _, key := range keys {
332+
_, err = s3.New(sess).HeadObject(&s3.HeadObjectInput{
333+
Bucket: aws.String(bucket),
334+
Key: aws.String(key),
335+
})
335336

336-
if IsNotFoundErr(err) {
337-
return false, nil
338-
}
337+
if IsNotFoundErr(err) {
338+
return false, nil
339+
}
339340

340-
if err != nil {
341-
return false, errors.Wrap(err, bucket, key)
341+
if err != nil {
342+
return false, errors.Wrap(err, bucket, key)
343+
}
342344
}
343345

344346
return true, nil
@@ -352,11 +354,29 @@ func IsS3aPathPrefixExternal(s3aPath string) (bool, error) {
352354
return IsS3PrefixExternal(bucket, prefix)
353355
}
354356

355-
func IsS3PathFileExternal(s3Path string) (bool, error) {
356-
bucket, key, err := SplitS3Path(s3Path)
357+
func IsS3PathPrefixExternal(s3Path string) (bool, error) {
358+
bucket, prefix, err := SplitS3Path(s3Path)
357359
if err != nil {
358360
return false, err
359361
}
362+
return IsS3PrefixExternal(bucket, prefix)
363+
}
364+
365+
func IsS3PathFileExternal(s3Paths ...string) (bool, error) {
366+
for _, s3Path := range s3Paths {
367+
bucket, key, err := SplitS3Path(s3Path)
368+
if err != nil {
369+
return false, err
370+
}
371+
exists, err := IsS3FileExternal(bucket, key)
372+
if err != nil {
373+
return false, err
374+
}
360375

361-
return IsS3FileExternal(bucket, key)
376+
if !exists {
377+
return false, nil
378+
}
379+
}
380+
381+
return true, nil
362382
}

pkg/operator/api/userconfig/apis.go

Lines changed: 50 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,29 @@ var apiValidation = &cr.StructValidation{
7777
},
7878
}
7979

80+
// IsValidS3Directory checks that the path contains a valid S3 directory for Tensorflow models
81+
// Must contain the following structure:
82+
// - 1523423423/ (version prefix, usually a timestamp)
83+
// - saved_model.pb
84+
// - variables/
85+
// - variables.index
86+
// - variables.data-00000-of-00001 (there are a variable number of these files)
87+
func IsValidS3Directory(path string) bool {
88+
if valid, err := aws.IsS3PathFileExternal(
89+
fmt.Sprintf("%s/saved_model.pb", path),
90+
fmt.Sprintf("%s/variables/variables.index", path),
91+
); err != nil || !valid {
92+
return false
93+
}
94+
95+
if valid, err := aws.IsS3PathPrefixExternal(
96+
fmt.Sprintf("%s/variables/variables.data-00000-of", path),
97+
); err != nil || !valid {
98+
return false
99+
}
100+
return true
101+
}
102+
80103
func (api *API) UserConfigStr() string {
81104
var sb strings.Builder
82105
sb.WriteString(api.ResourceFields.UserConfigStr())
@@ -118,24 +141,39 @@ func (apis APIs) Validate() error {
118141
func (api *API) Validate() error {
119142
if yaml.StartsWithEscapedAtSymbol(api.Model) {
120143
api.ModelFormat = TensorFlowModelFormat
121-
} else {
122-
if !aws.IsValidS3Path(api.Model) {
123-
return errors.Wrap(ErrorInvalidS3PathOrResourceReference(api.Model), Identify(api), ModelKey)
144+
if err := api.Compute.Validate(); err != nil {
145+
return errors.Wrap(err, Identify(api), ComputeKey)
124146
}
125147

126-
if api.ModelFormat == UnknownModelFormat {
127-
if strings.HasSuffix(api.Model, ".onnx") {
128-
api.ModelFormat = ONNXModelFormat
129-
} else if strings.HasSuffix(api.Model, ".zip") {
130-
api.ModelFormat = TensorFlowModelFormat
131-
} else {
132-
return errors.Wrap(ErrorUnableToInferModelFormat(), Identify(api))
133-
}
134-
}
148+
return nil
149+
}
150+
151+
if !aws.IsValidS3Path(api.Model) {
152+
return errors.Wrap(ErrorInvalidS3PathOrResourceReference(api.Model), Identify(api), ModelKey)
153+
}
135154

155+
switch api.ModelFormat {
156+
case ONNXModelFormat:
136157
if ok, err := aws.IsS3PathFileExternal(api.Model); err != nil || !ok {
137158
return errors.Wrap(ErrorExternalNotFound(api.Model), Identify(api), ModelKey)
138159
}
160+
case TensorFlowModelFormat:
161+
if !IsValidS3Directory(api.Model) {
162+
return errors.Wrap(ErrorInvalidTensorflowDir(api.Model), Identify(api), ModelKey)
163+
}
164+
default:
165+
switch {
166+
case strings.HasSuffix(api.Model, ".onnx"):
167+
api.ModelFormat = ONNXModelFormat
168+
if ok, err := aws.IsS3PathFileExternal(api.Model); err != nil || !ok {
169+
return errors.Wrap(ErrorExternalNotFound(api.Model), Identify(api), ModelKey)
170+
}
171+
case IsValidS3Directory(api.Model):
172+
api.ModelFormat = TensorFlowModelFormat
173+
default:
174+
return errors.Wrap(ErrorUnableToInferModelFormat(), Identify(api))
175+
}
176+
139177
}
140178

141179
if err := api.Compute.Validate(); err != nil {

pkg/operator/api/userconfig/config.go

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -154,11 +154,6 @@ func (config *Config) ValidatePartial() error {
154154
}
155155

156156
func (config *Config) Validate(envName string) error {
157-
err := config.ValidatePartial()
158-
if err != nil {
159-
return err
160-
}
161-
162157
if config.App == nil {
163158
return ErrorMissingAppDefinition()
164159
}

pkg/operator/api/userconfig/errors.go

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ const (
8080
ErrInvalidS3PathOrResourceReference
8181
ErrUnableToInferModelFormat
8282
ErrExternalNotFound
83+
ErrInvalidTensorflowDir
8384
)
8485

8586
var errorKinds = []string{
@@ -133,9 +134,10 @@ var errorKinds = []string{
133134
"err_invalid_s3_path_or_resource_reference",
134135
"err_unable_to_infer_model_format",
135136
"err_external_not_found",
137+
"err_invalid_tensorflow_dir",
136138
}
137139

138-
var _ = [1]int{}[int(ErrExternalNotFound)-(len(errorKinds)-1)] // Ensure list length matches
140+
var _ = [1]int{}[int(ErrInvalidTensorflowDir)-(len(errorKinds)-1)] // Ensure list length matches
139141

140142
func (t ErrorKind) String() string {
141143
return errorKinds[t]
@@ -599,10 +601,30 @@ func ErrorExternalNotFound(path string) error {
599601
}
600602
}
601603

604+
var onnxExpectedStructMessage = `For ONNX models, the path should end in .onnx`
605+
606+
var tfExpectedStructMessage = `For TensorFlow models, the path should be a directory with the following structure:
607+
1523423423/ (version prefix, usually a timestamp)
608+
├── saved_model.pb
609+
└── variables/
610+
├── variables.index
611+
├── variables.data-00000-of-00003
612+
├── variables.data-00001-of-00003
613+
└── variables.data-00002-of-...`
614+
602615
func ErrorUnableToInferModelFormat() error {
616+
message := ModelFormatKey + " not specified, and could not be inferred\n" + onnxExpectedStructMessage + "\n" + tfExpectedStructMessage
603617
return Error{
604618
Kind: ErrUnableToInferModelFormat,
605-
message: "unable to infer " + ModelFormatKey + ": path to model should end in .zip for TensorFlow models, .onnx for ONNX models, or the " + ModelFormatKey + " key must be specified",
619+
message: message,
620+
}
621+
}
622+
func ErrorInvalidTensorflowDir(path string) error {
623+
message := "invalid TF export directory.\n"
624+
message += tfExpectedStructMessage
625+
return Error{
626+
Kind: ErrInvalidTensorflowDir,
627+
message: message,
606628
}
607629
}
608630

pkg/workloads/cortex/tf_api/api.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from cortex import consts
3232
from cortex.lib import util, tf_lib, package, Context
3333
from cortex.lib.log import get_logger
34+
from cortex.lib.storage import S3, LocalStorage
3435
from cortex.lib.exceptions import CortexException, UserRuntimeException, UserException
3536
from cortex.lib.context import create_transformer_inputs_from_map
3637

@@ -437,6 +438,23 @@ def get_signature(app_name, api_name):
437438
return jsonify(response)
438439

439440

441+
def download_dir_external(ctx, s3_path, local_path):
    """Download every object under an external S3 prefix into local_path/<version>/.

    The last component of the S3 prefix (the model version directory) is
    re-created under local_path, and each object below the prefix is written
    to the same relative location inside it.

    Args:
        ctx: context whose storage object provides deconstruct_s3_path() and
            download_file_external().
        s3_path: S3 path of the model version directory, with or without a
            trailing "/".
        local_path: local directory that will receive the version directory.

    Raises:
        ValueError: if s3_path has no key prefix, so no version directory
            name can be derived from it.
    """
    util.mkdir_p(local_path)
    bucket_name, prefix = ctx.storage.deconstruct_s3_path(s3_path)
    storage_client = S3(bucket_name, client_config={})

    # Normalize to exactly one trailing "/" BEFORE listing and slicing: this keeps
    # the relative names intact when the caller passed a trailing slash, and stops
    # "iris/123" from also matching sibling prefixes such as "iris/1234/".
    prefix = prefix.rstrip("/")
    if prefix == "":
        raise ValueError("model path must include a version directory: " + s3_path)
    prefix = prefix + "/"

    version = prefix.split("/")[-2]
    local_path = os.path.join(local_path, version)

    for key in storage_client.search(prefix=prefix):
        rel_path = key[len(prefix) :]
        if rel_path == "" or rel_path.endswith("/"):
            # "Folder" placeholder keys carry no file content to download.
            continue

        # Always create the parent directory under the download target; checking
        # the bare relative path would test the current working directory instead.
        util.mkdir_p(os.path.join(local_path, os.path.dirname(rel_path)))

        ctx.storage.download_file_external(
            bucket_name + "/" + prefix + rel_path, os.path.join(local_path, rel_path)
        )
456+
457+
440458
def validate_model_dir(model_dir):
441459
"""
442460
validates that model_dir has the expected directory tree.
@@ -489,23 +507,22 @@ def start(args):
489507
if api.get("request_handler") is not None:
490508
local_cache["request_handler"] = ctx.get_request_handler_impl(api["name"])
491509

492-
if not util.is_resource_ref(api["model"]):
493-
if not os.path.isdir(args.model_dir):
494-
ctx.storage.download_and_unzip_external(api["model"], args.model_dir)
510+
if not os.path.isdir(args.model_dir):
511+
if util.is_resource_ref(api["model"]):
512+
model_name = util.get_resource_ref(api["model"])
513+
model = ctx.models[model_name]
514+
ctx.storage.download_and_unzip(model["key"], args.model_dir)
515+
else:
516+
download_dir_external(ctx, api["model"], args.model_dir)
495517

496-
if args.only_download:
497-
return
498-
else:
518+
if args.only_download:
519+
return
520+
521+
if util.is_resource_ref(api["model"]):
499522
model_name = util.get_resource_ref(api["model"])
500523
model = ctx.models[model_name]
501524
estimator = ctx.estimators[model["estimator"]]
502525

503-
if not os.path.isdir(args.model_dir):
504-
ctx.storage.download_and_unzip(model["key"], args.model_dir)
505-
506-
if args.only_download:
507-
return
508-
509526
local_cache["model"] = model
510527
local_cache["estimator"] = estimator
511528
local_cache["target_col"] = ctx.columns[util.get_resource_ref(model["target_column"])]

0 commit comments

Comments
 (0)