Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Elastic/Artifacts] Pass through model card #2575

Merged
merged 8 commits into from
Jul 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import flytekit
from flytekit import PythonFunctionTask, Resources, lazy_module
from flytekit.configuration import SerializationSettings
from flytekit.core.context_manager import FlyteContextManager, OutputMetadata
from flytekit.core.pod_template import PodTemplate
from flytekit.core.resources import convert_resources_to_resource_model
from flytekit.exceptions.user import FlyteRecoverableException
Expand Down Expand Up @@ -240,6 +241,7 @@ class ElasticWorkerResult(NamedTuple):

return_value: Any
decks: List[flytekit.Deck]
om: OutputMetadata


def spawn_helper(
Expand Down Expand Up @@ -270,18 +272,21 @@ def spawn_helper(
raw_output_data_prefix=raw_output_prefix,
checkpoint_path=checkpoint_dest,
prev_checkpoint=checkpoint_src,
):
) as ctx:
fn = cloudpickle.loads(fn)

try:
return_val = fn(**kwargs)
omt = ctx.output_metadata_tracker
om = None
if omt:
om = omt.get(return_val)
except Exception as e:
# See explanation in `create_recoverable_error_file` why we check
# for recoverable errors here in the worker processes.
if isinstance(e, FlyteRecoverableException):
create_recoverable_error_file()
raise
return ElasticWorkerResult(return_value=return_val, decks=flytekit.current_context().decks)
return ElasticWorkerResult(return_value=return_val, decks=flytekit.current_context().decks, om=om)


def _convert_run_policy_to_flyte_idl(run_policy: RunPolicy) -> kubeflow_common.RunPolicy:
Expand Down Expand Up @@ -460,10 +465,12 @@ def fn_partial():
# Rank 0 returns the result of the task function
if 0 in out:
# For rank 0, we transfer the decks created in the worker process to the parent process
ctx = flytekit.current_context()
ctx = FlyteContextManager.current_context()
for deck in out[0].decks:
if not isinstance(deck, flytekit.deck.deck.TimeLineDeck):
ctx.decks.append(deck)
if out[0].om:
ctx.output_metadata_tracker.add(out[0].return_value, out[0].om)

return out[0].return_value
else:
Expand Down
Loading