From 357877f0fb55205f611a0fd37803903d078c459c Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Mon, 29 Jul 2024 10:57:59 -0700 Subject: [PATCH] [Elastic/Artifacts] Pass through model card (#2575) Signed-off-by: Yee Hing Tong Signed-off-by: mao3267 --- .../flytekitplugins/kfpytorch/task.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py index 0fab224fa2..ad9b5368b0 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py @@ -15,6 +15,7 @@ import flytekit from flytekit import PythonFunctionTask, Resources, lazy_module from flytekit.configuration import SerializationSettings +from flytekit.core.context_manager import FlyteContextManager, OutputMetadata from flytekit.core.pod_template import PodTemplate from flytekit.core.resources import convert_resources_to_resource_model from flytekit.exceptions.user import FlyteRecoverableException @@ -240,6 +241,7 @@ class ElasticWorkerResult(NamedTuple): return_value: Any decks: List[flytekit.Deck] + om: OutputMetadata def spawn_helper( @@ -270,18 +272,21 @@ def spawn_helper( raw_output_data_prefix=raw_output_prefix, checkpoint_path=checkpoint_dest, prev_checkpoint=checkpoint_src, - ): + ) as ctx: fn = cloudpickle.loads(fn) - try: return_val = fn(**kwargs) + omt = ctx.output_metadata_tracker + om = None + if omt: + om = omt.get(return_val) except Exception as e: # See explanation in `create_recoverable_error_file` why we check # for recoverable errors here in the worker processes. if isinstance(e, FlyteRecoverableException): create_recoverable_error_file() raise - return ElasticWorkerResult(return_value=return_val, decks=flytekit.current_context().decks) + return ElasticWorkerResult(return_value=return_val, decks=flytekit.current_context().decks, om=om) def _convert_run_policy_to_flyte_idl(run_policy: RunPolicy) -> kubeflow_common.RunPolicy: @@ -460,10 +465,12 @@ def fn_partial(): # Rank 0 returns the result of the task function if 0 in out: # For rank 0, we transfer the decks created in the worker process to the parent process - ctx = flytekit.current_context() + ctx = FlyteContextManager.current_context() for deck in out[0].decks: if not isinstance(deck, flytekit.deck.deck.TimeLineDeck): ctx.decks.append(deck) + if out[0].om: + ctx.output_metadata_tracker.add(out[0].return_value, out[0].om) return out[0].return_value else: