Skip to content

Commit 597d844

Browse files
committed
Support resource references in model target_vocab
1 parent b7e0aa4 commit 597d844

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

pkg/workloads/tf_api/api.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
"trans_impls": {},
5151
"required_inputs": None,
5252
"metadata": None,
53+
"target_vocab_populated": None,
5354
}
5455

5556
DTYPE_TO_VALUE_KEY = {
@@ -197,7 +198,7 @@ def parse_response_proto(response_proto):
197198
and estimator["name"] in target_vocab_estimators
198199
and model["input"].get("target_vocab") is not None
199200
):
200-
prediction = model["input"]["target_vocab"][int(prediction)]
201+
prediction = local_cache["target_vocab_populated"][int(prediction)]
201202
else:
202203
prediction = util.cast(prediction, target_col_type)
203204

@@ -400,6 +401,11 @@ def start(args):
400401

401402
local_cache["required_inputs"] = tf_lib.get_base_input_columns(model["name"], ctx)
402403

404+
if model["input"].get("target_vocab") is not None:
405+
local_cache["target_vocab_populated"] = ctx.populate_values(
406+
model["input"]["target_vocab"], None, False
407+
)
408+
403409
else:
404410
if not os.path.isdir(args.model_dir):
405411
ctx.storage.download_and_unzip_external(api["external_model"]["path"], args.model_dir)

0 commit comments

Comments
 (0)