Skip to content

Add region to external models #161

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions docs/applications/advanced/external-models.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ $ zip -r model.zip export/estimator
$ aws s3 cp model.zip s3://your-bucket/model.zip
```

3. Specify `model_path` in an API, e.g.
3. Specify `external_model` in an API, e.g.

```yaml
- kind: api
name: my-api
model_path: s3://your-bucket/model.zip
external_model:
path: s3://your-bucket/model.zip
region: us-west-2
compute:
replicas: 5
gpu: 1
Expand Down
4 changes: 3 additions & 1 deletion docs/applications/resources/apis.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ Serve models at scale and use them to build smarter applications.
- kind: api # (required)
name: <string> # API name (required)
model_name: <string> # name of a Cortex model (required)
model_path: <string> # path to a zipped model dir (optional)
external_model:
path: <string> # S3 path to a zipped model dir (required when external_model is specified)
region: <string> # S3 region of the external model (optional; defaults to the default S3 region)
compute:
replicas: <int> # number of replicas to launch (default: 1)
cpu: <string> # CPU request (default: Null)
Expand Down
4 changes: 3 additions & 1 deletion examples/external-model/app.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

- kind: api
name: iris
model_path: s3://cortex-examples/iris-model.zip
external_model:
path: s3://cortex-examples/iris-model.zip
region: us-west-2
compute:
replicas: 1
35 changes: 34 additions & 1 deletion pkg/lib/aws/s3.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,19 @@ func SplitS3aPath(s3aPath string) (string, string, error) {
if !IsValidS3aPath(s3aPath) {
return "", "", ErrorInvalidS3aPath(s3aPath)
}
fullPath := s3aPath[6:]
fullPath := s3aPath[len("s3a://"):]
slashIndex := strings.Index(fullPath, "/")
bucket := fullPath[0:slashIndex]
key := fullPath[slashIndex+1:]

return bucket, key, nil
}

func SplitS3Path(s3Path string) (string, string, error) {
if !IsValidS3Path(s3Path) {
return "", "", ErrorInvalidS3aPath(s3Path)
}
fullPath := s3Path[len("s3://"):]
slashIndex := strings.Index(fullPath, "/")
bucket := fullPath[0:slashIndex]
key := fullPath[slashIndex+1:]
Expand All @@ -291,6 +303,27 @@ func IsS3PrefixExternal(bucket string, prefix string, region string) (bool, erro
return hasPrefix, nil
}

// IsS3FileExternal reports whether an object exists at key in bucket, using an
// S3 client configured for the given region.
//
// It returns (true, nil) when the HeadObject call succeeds, (false, nil) when
// the resulting error satisfies IsNotFoundErr, and (false, err) for any other
// failure, including a failure to create the AWS session.
func IsS3FileExternal(bucket string, key string, region string) (bool, error) {
	// Create the session with NewSession rather than session.Must so that a
	// session configuration problem is returned to the caller instead of
	// panicking.
	sess, err := session.NewSession(&aws.Config{
		Region: aws.String(region),
	})
	if err != nil {
		return false, errors.Wrap(err, key)
	}

	// A HEAD request is enough to check for existence.
	_, err = s3.New(sess).HeadObject(&s3.HeadObjectInput{
		Bucket: aws.String(bucket),
		Key:    aws.String(key),
	})

	if IsNotFoundErr(err) {
		return false, nil
	}

	if err != nil {
		return false, errors.Wrap(err, key)
	}

	return true, nil
}

func IsS3aPrefixExternal(s3aPath string, region string) (bool, error) {
bucket, prefix, err := SplitS3aPath(s3aPath)
if err != nil {
Expand Down
59 changes: 47 additions & 12 deletions pkg/operator/api/userconfig/apis.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package userconfig

import (
"github.com/cortexlabs/cortex/pkg/lib/aws"
cr "github.com/cortexlabs/cortex/pkg/lib/configreader"
"github.com/cortexlabs/cortex/pkg/lib/errors"
"github.com/cortexlabs/cortex/pkg/operator/api/resource"
Expand All @@ -26,10 +27,10 @@ type APIs []*API

type API struct {
ResourceFields
Model *string `json:"model" yaml:"model"`
ModelPath *string `json:"model_path" yaml:"model_path"`
Compute *APICompute `json:"compute" yaml:"compute"`
Tags Tags `json:"tags" yaml:"tags"`
Model *string `json:"model" yaml:"model"`
ExternalModel *ExternalModel `json:"external_model" yaml:"external_model"`
Compute *APICompute `json:"compute" yaml:"compute"`
Tags Tags `json:"tags" yaml:"tags"`
}

var apiValidation = &cr.StructValidation{
Expand All @@ -48,17 +49,40 @@ var apiValidation = &cr.StructValidation{
},
},
{
StructField: "ModelPath",
StringPtrValidation: &cr.StringPtrValidation{
Validator: cr.GetS3PathValidator(),
},
StructField: "ExternalModel",
StructValidation: externalModelFieldValidation,
},
apiComputeFieldValidation,
tagsFieldValidation,
typeFieldValidation,
},
}

// ExternalModel is the value of an API's `external_model` field: the location
// of a model given by its S3 path and region.
type ExternalModel struct {
// Path is the S3 path of the model; it is required and checked with the S3 path validator.
Path string `json:"path" yaml:"path"`
// Region is the S3 region used to access Path; it defaults to aws.DefaultS3Region
// and must be one of aws.S3Regions.
Region string `json:"region" yaml:"region"`
}

// externalModelFieldValidation describes how the fields of ExternalModel are
// validated when an API's `external_model` is read from configuration.
var externalModelFieldValidation = &cr.StructValidation{
// NOTE(review): DefaultNil presumably leaves the struct nil when the field is
// omitted from the config (API.Validate checks for a nil ExternalModel) —
// confirm against the configreader package.
DefaultNil: true,
StructFieldValidations: []*cr.StructFieldValidation{
{
// Path must be supplied and must pass the S3 path validator.
StructField: "Path",
StringValidation: &cr.StringValidation{
Validator: cr.GetS3PathValidator(),
Required: true,
},
},
{
// Region falls back to the default S3 region and is restricted to the
// known S3 regions.
StructField: "Region",
StringValidation: &cr.StringValidation{
Default: aws.DefaultS3Region,
AllowedValues: aws.S3Regions.Slice(),
},
},
},
}

func (apis APIs) Validate() error {
for _, api := range apis {
if err := api.Validate(); err != nil {
Expand All @@ -80,12 +104,23 @@ func (apis APIs) Validate() error {
}

func (api *API) Validate() error {
if api.ModelPath == nil && api.Model == nil {
return errors.Wrap(ErrorSpecifyOnlyOneMissing("model_name", "model_path"), Identify(api))
if api.ExternalModel == nil && api.Model == nil {
return errors.Wrap(ErrorSpecifyOnlyOneMissing(ModelKey, ExternalModelKey), Identify(api))
}

if api.ModelPath != nil && api.Model != nil {
return errors.Wrap(ErrorSpecifyOnlyOne("model_name", "model_path"), Identify(api))
if api.ExternalModel != nil && api.Model != nil {
return errors.Wrap(ErrorSpecifyOnlyOne(ModelKey, ExternalModelKey), Identify(api))
}

if api.ExternalModel != nil {
bucket, key, err := aws.SplitS3Path(api.ExternalModel.Path)
if err != nil {
return errors.Wrap(err, Identify(api), ExternalModelKey, PathKey)
}

if ok, err := aws.IsS3FileExternal(bucket, key, api.ExternalModel.Region); err != nil || !ok {
return errors.Wrap(ErrorExternalModelNotFound(api.ExternalModel.Path), Identify(api), ExternalModelKey, PathKey)
}
}

return nil
Expand Down
5 changes: 3 additions & 2 deletions pkg/operator/api/userconfig/config_key.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ const (
DatasetComputeKey = "dataset_compute"

// API
ModelKey = "model"
ModelNameKey = "model_name"
ModelKey = "model"
ModelNameKey = "model_name"
ExternalModelKey = "external_model"
)
11 changes: 10 additions & 1 deletion pkg/operator/api/userconfig/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ const (
ErrEnvSchemaMismatch
ErrExtraResourcesWithExternalAPIs
ErrImplDoesNotExist
ErrExternalModelNotFound
)

var errorKinds = []string{
Expand Down Expand Up @@ -124,9 +125,10 @@ var errorKinds = []string{
"err_env_schema_mismatch",
"err_extra_resources_with_external_a_p_is",
"err_impl_does_not_exist",
"err_external_model_not_found",
}

var _ = [1]int{}[int(ErrImplDoesNotExist)-(len(errorKinds)-1)] // Ensure list length matches
var _ = [1]int{}[int(ErrExternalModelNotFound)-(len(errorKinds)-1)] // Ensure list length matches

func (t ErrorKind) String() string {
return errorKinds[t]
Expand Down Expand Up @@ -575,3 +577,10 @@ func ErrorImplDoesNotExist(path string) error {
message: fmt.Sprintf("%s: implementation file does not exist", path),
}
}

// ErrorExternalModelNotFound builds an Error of kind ErrExternalModelNotFound
// whose message reports that the file at path was not found or is inaccessible.
func ErrorExternalModelNotFound(path string) error {
	msg := fmt.Sprintf("%s: file not found or inaccessible", path)

	notFound := Error{
		Kind:    ErrExternalModelNotFound,
		message: msg,
	}
	return notFound
}
7 changes: 4 additions & 3 deletions pkg/operator/context/apis.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,11 @@ func getAPIs(config *userconfig.Config,
buf.WriteString(model.ID)
}

if apiConfig.ModelPath != nil {
modelName = *apiConfig.ModelPath
if apiConfig.ExternalModel != nil {
modelName = apiConfig.ExternalModel.Path
buf.WriteString(datasetVersion)
buf.WriteString(*apiConfig.ModelPath)
buf.WriteString(apiConfig.ExternalModel.Path)
buf.WriteString(apiConfig.ExternalModel.Region)
}

id := hash.Bytes(buf.Bytes())
Expand Down
2 changes: 1 addition & 1 deletion pkg/workloads/tf_api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def start(args):

else:
if not os.path.isdir(args.model_dir):
ctx.storage.download_and_unzip_external(api["model_path"], args.model_dir)
ctx.storage.download_and_unzip_external(api["external_model"]["path"], args.model_dir)

channel = grpc.insecure_channel("localhost:" + str(args.tf_serve_port))
local_cache["stub"] = prediction_service_pb2_grpc.PredictionServiceStub(channel)
Expand Down