Skip to content

Improve signature def detection #460

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Sep 16, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions pkg/consts/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,4 @@ var (
TelemetryURL = "https://telemetry.cortexlabs.dev"

MaxClassesPerRequest = 75 // cloudwatch.GetMetricData can get up to 100 metrics per request, avoid multiple requests and have room for other stats

DefaultTFServingSignatureKey = "predict"
)
9 changes: 1 addition & 8 deletions pkg/operator/api/userconfig/apis.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"strings"

"github.com/aws/aws-sdk-go/service/s3"
"github.com/cortexlabs/cortex/pkg/consts"
"github.com/cortexlabs/cortex/pkg/lib/aws"
cr "github.com/cortexlabs/cortex/pkg/lib/configreader"
"github.com/cortexlabs/cortex/pkg/lib/errors"
Expand Down Expand Up @@ -115,7 +114,7 @@ var apiValidation = &cr.StructValidation{
{
StructField: "SignatureKey",
StringValidation: &cr.StringValidation{
Default: consts.DefaultTFServingSignatureKey,
Required: true,
},
},
},
Expand Down Expand Up @@ -277,12 +276,6 @@ func (api *API) Validate(projectFileMap map[string][]byte) error {
}
}

if api.ModelFormat == TensorFlowModelFormat && api.TFServing == nil {
api.TFServing = &TFServingOptions{
SignatureKey: consts.DefaultTFServingSignatureKey,
}
}

if api.ModelFormat != TensorFlowModelFormat && api.TFServing != nil {
return errors.Wrap(ErrorTFServingOptionsForTFOnly(api.ModelFormat), Identify(api))
}
Expand Down
106 changes: 65 additions & 41 deletions pkg/workloads/cortex/tf_api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,9 @@
"ctx": None,
"stub": None,
"api": None,
"metadata": None,
"signature_key": None,
"parsed_signature": None,
"model_metadata": None,
"request_handler": None,
"class_set": set(),
}
Expand Down Expand Up @@ -115,8 +117,8 @@ def after_request(response):


def create_prediction_request(sample):
signature_def = local_cache["metadata"]["signatureDef"]
signature_key = local_cache["api"]["tf_serving"]["signature_key"]
signature_def = local_cache["model_metadata"]["signatureDef"]
signature_key = local_cache["signature_key"]
prediction_request = predict_pb2.PredictRequest()
prediction_request.model_spec.name = "model"
prediction_request.model_spec.signature_name = signature_key
Expand Down Expand Up @@ -179,7 +181,7 @@ def run_predict(sample, debug=False):
if request_handler is not None and util.has_function(request_handler, "pre_inference"):
try:
prepared_sample = request_handler.pre_inference(
sample, local_cache["metadata"]["signatureDef"]
sample, local_cache["model_metadata"]["signatureDef"]
)
debug_obj("pre_inference", prepared_sample, debug)
except Exception as e:
Expand All @@ -196,7 +198,9 @@ def run_predict(sample, debug=False):

if request_handler is not None and util.has_function(request_handler, "post_inference"):
try:
result = request_handler.post_inference(result, local_cache["metadata"]["signatureDef"])
result = request_handler.post_inference(
result, local_cache["model_metadata"]["signatureDef"]
)
debug_obj("post_inference", result, debug)
except Exception as e:
raise UserRuntimeException(
Expand All @@ -207,9 +211,7 @@ def run_predict(sample, debug=False):


def validate_sample(sample):
signature = extract_signature(
local_cache["metadata"]["signatureDef"], local_cache["api"]["tf_serving"]["signature_key"]
)
signature = local_cache["parsed_signature"]
for input_name, _ in signature.items():
if input_name not in sample:
raise UserException('missing key "{}"'.format(input_name))
Expand Down Expand Up @@ -252,38 +254,56 @@ def predict(deployment_name, api_name):


def extract_signature(signature_def, signature_key):
    """Pick the signature def to serve and summarize its inputs.

    Args:
        signature_def: Mapping of signature key -> signature definition, as
            read from the model metadata's "signatureDef" entry.
        signature_key: Key configured by the user, or None to auto-select.
            Auto-selection only succeeds when the model exposes exactly one
            signature def.

    Returns:
        A ``(signature_key, parsed_signature)`` tuple. ``parsed_signature``
        maps each input name to ``{"shape": [int, ...], "type": str}``, where
        the shape comes from the input's ``tensorShape.dim`` sizes and the
        type is the name of ``DTYPE_TO_TF_TYPE[dtype]``.

    Raises:
        UserException: if the model has no signature defs, if no key was
            configured and several are available, if the configured key is
            not present, or if the selected signature def has no "inputs".
    """
    logger.info("signature defs found in model: {}".format(signature_def))

    # Tolerate a missing/None signature def map instead of failing with an
    # AttributeError on .keys().
    available_keys = list(signature_def or {})
    if not available_keys:
        raise UserException("unable to find signature defs in model")

    if signature_key is None:
        # Nothing configured: only unambiguous when there is a single candidate.
        if len(available_keys) != 1:
            raise UserException(
                "signature_key was not configured by user, please specify one of the following keys '{}' found in signature def map".format(
                    "', '".join(available_keys)
                )
            )
        logger.info(
            "signature_key was not configured by user, using signature key '{}' found in signature def map".format(
                available_keys[0]
            )
        )
        signature_key = available_keys[0]
    elif signature_def.get(signature_key) is None:
        if len(available_keys) > 1:
            possibilities_str = "keys: '{}'".format("', '".join(available_keys))
        else:
            possibilities_str = "key: '{}'".format(available_keys[0])

        raise UserException(
            "signature_key '{}' was not found in signature def map, but found the following {}".format(
                signature_key, possibilities_str
            )
        )

    signature_def_val = signature_def.get(signature_key)

    # The auto-selected entry may itself be None; treat that like a signature
    # def without inputs rather than crashing on .get().
    if signature_def_val is None or signature_def_val.get("inputs") is None:
        raise UserException("unable to find 'inputs' in signature def '{}'".format(signature_key))

    parsed_signature = {
        input_name: {
            "shape": [int(dim["size"]) for dim in input_metadata["tensorShape"]["dim"]],
            "type": DTYPE_TO_TF_TYPE[input_metadata["dtype"]].name,
        }
        for input_name, input_metadata in signature_def_val["inputs"].items()
    }
    return signature_key, parsed_signature


@app.route("/<app_name>/<api_name>/signature", methods=["GET"])
def get_signature(app_name, api_name):
    """Serve the model's parsed input signature as JSON.

    The signature is read from ``local_cache["parsed_signature"]``; it is not
    recomputed per request. ``app_name`` and ``api_name`` come from the URL
    and are not used in the lookup.

    Returns:
        A JSON response of the form ``{"signature": <parsed signature>}``.
    """
    return jsonify({"signature": local_cache["parsed_signature"]})


Expand Down Expand Up @@ -375,7 +395,7 @@ def start(args):
limit = 60
for i in range(limit):
try:
local_cache["metadata"] = run_get_model_metadata()
local_cache["model_metadata"] = run_get_model_metadata()
break
except Exception as e:
if i > 6:
Expand All @@ -385,14 +405,18 @@ def start(args):
sys.exit(1)

time.sleep(5)
logger.info(
"model_signature: {}".format(
extract_signature(
local_cache["metadata"]["signatureDef"],
local_cache["api"]["tf_serving"]["signature_key"],
)
)

signature_key = None
if api.get("tf_serving") is not None and api["tf_serving"].get("signature_key") is not None:
signature_key = api["tf_serving"]["signature_key"]

key, parsed_signature = extract_signature(
local_cache["model_metadata"]["signatureDef"], signature_key
)

local_cache["signature_key"] = key
local_cache["parsed_signature"] = parsed_signature
logger.info("model_signature: {}".format(local_cache["parsed_signature"]))
serve(app, listen="*:{}".format(args.port))


Expand Down