Skip to content

Commit

Permalink
Fix: Make sure decks created in elastic task workers are transferred to parent process (#1837)
Browse files Browse the repository at this point in the history

* Transfer decks created in the worker process to the parent process

Signed-off-by: Fabio Graetz <fabiograetz@googlemail.com>

* Add test for decks in elastic tasks

Signed-off-by: Fabio Graetz <fabiograetz@googlemail.com>

* Update plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py

Signed-off-by: Fabio Graetz <fabiograetz@googlemail.com>

* Update plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py

Signed-off-by: Fabio Graetz <fabiograetz@googlemail.com>

---------

Signed-off-by: Fabio Graetz <fabiograetz@googlemail.com>
  • Loading branch information
fg91 authored Sep 19, 2023
1 parent 474ffd0 commit a1e110e
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 6 deletions.
36 changes: 30 additions & 6 deletions plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, Optional, Union
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union

import cloudpickle
from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common
Expand Down Expand Up @@ -203,7 +203,22 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
TaskPlugins.register_pythontask_plugin(PyTorch, PyTorchFunctionTask)


def spawn_helper(fn: bytes, raw_output_prefix: str, checkpoint_dest: str, checkpoint_src: str, kwargs) -> Any:
class ElasticWorkerResult(NamedTuple):
    """Result returned by a single torch elastic worker process.

    Bundles the task function's return value together with any Decks the
    worker created, so the parent process can pick the decks up after the
    elastic launch finishes (only rank 0's result is ultimately used).

    Attributes:
        return_value (Any): The value returned by the task function in the worker process.
        decks (List[flytekit.Deck]): Flytekit Deck objects created in the worker process.
    """

    return_value: Any
    decks: List[flytekit.Deck]


def spawn_helper(
fn: bytes, raw_output_prefix: str, checkpoint_dest: str, checkpoint_src: str, kwargs
) -> ElasticWorkerResult:
"""Help to spawn worker processes.
The purpose of this function is to 1) be pickleable so that it can be used with
Expand All @@ -220,7 +235,8 @@ def spawn_helper(fn: bytes, raw_output_prefix: str, checkpoint_dest: str, checkp
checkpoint_src (str): Location where the new checkpoint should be copied to.
Returns:
The return value of the received target function.
ElasticWorkerResult: A named tuple containing the return value of the task function and a list of
flytekit Deck objects created in the worker process.
"""
from flytekit.bin.entrypoint import setup_execution

Expand All @@ -231,7 +247,8 @@ def spawn_helper(fn: bytes, raw_output_prefix: str, checkpoint_dest: str, checkp
):
fn = cloudpickle.loads(fn)
return_val = fn(**kwargs)
return return_val

return ElasticWorkerResult(return_value=return_val, decks=flytekit.current_context().decks)


class PytorchElasticFunctionTask(PythonFunctionTask[Elastic]):
Expand Down Expand Up @@ -336,7 +353,8 @@ def _execute(self, **kwargs) -> Any:

def fn_partial():
"""Closure of the task function with kwargs already bound."""
return self._task_function(**kwargs)
return_val = self._task_function(**kwargs)
return ElasticWorkerResult(return_value=return_val, decks=flytekit.current_context().decks)

launcher_target_func = fn_partial
launcher_args = ()
Expand Down Expand Up @@ -365,7 +383,13 @@ def fn_partial():
# `out` is a dictionary of rank (not local rank) -> result
# Rank 0 returns the result of the task function
if 0 in out:
return out[0]
# For rank 0, we transfer the decks created in the worker process to the parent process
ctx = flytekit.current_context()
for deck in out[0].decks:
if not isinstance(deck, flytekit.deck.deck.TimeLineDeck):
ctx.decks.append(deck)

return out[0].return_value
else:
raise IgnoreOutputs()

Expand Down
38 changes: 38 additions & 0 deletions plugins/flytekit-kf-pytorch/tests/test_elastic_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,41 @@ def test_task():
with mock.patch("torch.distributed.launcher.api.LaunchConfig", side_effect=LaunchConfig) as mock_launch_config:
test_task()
assert mock_launch_config.call_args[1]["rdzv_configs"] == rdzv_configs


@pytest.mark.parametrize("start_method", ["spawn", "fork"])
def test_deck(start_method: str) -> None:
    """Verify that decks created inside the elastic worker process are visible in the parent process."""
    nproc = 2

    @task(
        task_config=Elastic(nnodes=1, nproc_per_node=nproc, start_method=start_method),
        disable_deck=False,
    )
    def train():
        import os

        ctx = flytekit.current_context()
        # Create one named deck and also write to the default deck in the worker.
        ctx.decks.append(
            flytekit.Deck("test-deck", f"Hello Flyte Deck viewer from worker process {os.environ.get('RANK')}")
        )
        ctx.default_deck.append("Hello from default deck")

    @workflow
    def wf():
        train()

    wf()

    parent_ctx = flytekit.current_context()

    # All three decks must have been transferred from the rank-0 worker to the parent.
    names_present = {deck.name for deck in parent_ctx.decks}
    assert {"timeline", "default", "test-deck"} <= names_present

    default_deck = next(deck for deck in parent_ctx.decks if deck.name == "default")
    assert default_deck.html.strip() == "Hello from default deck"

    named_deck = next(deck for deck in parent_ctx.decks if deck.name == "test-deck")
    assert "Hello Flyte Deck viewer from worker process 0" in named_deck.html

0 comments on commit a1e110e

Please sign in to comment.