Skip to content

Commit febc293

Browse files
committed
Return expected input schema for prediction request errors
1 parent b0666fb commit febc293

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

pkg/workloads/tf_api/api.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,20 @@ def predict(deployment_name, api_name):
357357
api["name"]
358358
)
359359
)
360-
return prediction_failed(sample, str(e))
360+
361+
# Show signature def for external models (since we don't validate input)
362+
schemaStr = ""
363+
signature_def = local_cache["metadata"]["signatureDef"]
364+
if (
365+
not util.is_resource_ref(api["model"])
366+
and signature_def.get("predict") is not None # Just to be safe
367+
and signature_def["predict"].get("inputs") is not None # Just to be safe
368+
):
369+
schemaStr = "\n\nExpected shema:\n" + util.pp_str(
370+
signature_def["predict"]["inputs"]
371+
)
372+
373+
return prediction_failed(sample, str(e) + schemaStr)
361374

362375
predictions.append(result)
363376

0 commit comments

Comments
 (0)