Skip to content

Commit

Permalink
[Elastic/Artifacts] Pass through model card (#2575)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <wild-endeavor@users.noreply.github.com>
  • Loading branch information
wild-endeavor authored Jul 29, 2024
1 parent b79c7a3 commit 955ae33
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import flytekit
from flytekit import PythonFunctionTask, Resources, lazy_module
from flytekit.configuration import SerializationSettings
from flytekit.core.context_manager import FlyteContextManager, OutputMetadata
from flytekit.core.pod_template import PodTemplate
from flytekit.core.resources import convert_resources_to_resource_model
from flytekit.exceptions.user import FlyteRecoverableException
Expand Down Expand Up @@ -240,6 +241,7 @@ class ElasticWorkerResult(NamedTuple):

return_value: Any
decks: List[flytekit.Deck]
om: OutputMetadata


def spawn_helper(
Expand Down Expand Up @@ -270,18 +272,21 @@ def spawn_helper(
raw_output_data_prefix=raw_output_prefix,
checkpoint_path=checkpoint_dest,
prev_checkpoint=checkpoint_src,
):
) as ctx:
fn = cloudpickle.loads(fn)

try:
return_val = fn(**kwargs)
omt = ctx.output_metadata_tracker
om = None
if omt:
om = omt.get(return_val)
except Exception as e:
# See explanation in `create_recoverable_error_file` why we check
# for recoverable errors here in the worker processes.
if isinstance(e, FlyteRecoverableException):
create_recoverable_error_file()
raise
return ElasticWorkerResult(return_value=return_val, decks=flytekit.current_context().decks)
return ElasticWorkerResult(return_value=return_val, decks=flytekit.current_context().decks, om=om)


def _convert_run_policy_to_flyte_idl(run_policy: RunPolicy) -> kubeflow_common.RunPolicy:
Expand Down Expand Up @@ -460,10 +465,12 @@ def fn_partial():
# Rank 0 returns the result of the task function
if 0 in out:
# For rank 0, we transfer the decks created in the worker process to the parent process
ctx = flytekit.current_context()
ctx = FlyteContextManager.current_context()
for deck in out[0].decks:
if not isinstance(deck, flytekit.deck.deck.TimeLineDeck):
ctx.decks.append(deck)
if out[0].om:
ctx.output_metadata_tracker.add(out[0].return_value, out[0].om)

return out[0].return_value
else:
Expand Down

0 comments on commit 955ae33

Please sign in to comment.