Skip to content

Commit

Permalink
Merge branch 'flyteorg:master' into input-through-file-and-pipe
Browse files Browse the repository at this point in the history
  • Loading branch information
mao3267 authored Aug 2, 2024
2 parents 1372b6a + bcfbb80 commit aa49865
Show file tree
Hide file tree
Showing 26 changed files with 1,086 additions and 21 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pythonbuild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ jobs:
# onnx-tensorflow needs a version of tensorflow that does not work with protobuf>4.
# The issue is being tracked on the tensorflow side in https://github.com/tensorflow/tensorflow/issues/53234#issuecomment-1330111693
# flytekit-onnx-tensorflow
- flytekit-omegaconf
- flytekit-openai
- flytekit-pandera
- flytekit-papermill
Expand Down
11 changes: 8 additions & 3 deletions flytekit/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,9 @@ def transform_interface_to_list_interface(
return Interface(inputs=map_inputs, outputs=map_outputs)


def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Docstring] = None) -> Interface:
def transform_function_to_interface(
fn: typing.Callable, docstring: Optional[Docstring] = None, is_reference_entity: bool = False
) -> Interface:
"""
From the annotations on a task function that the user should have provided, and the output names they want to use
for each output parameter, construct the TypedInterface object
Expand All @@ -382,9 +384,12 @@ def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Doc
return_annotation = type_hints.get("return", None)

ctx = FlyteContextManager.current_context()

# Check if the function has a return statement at compile time locally.
# Skip it if the function is a reference task/workflow since it doesn't have a body.
if (
ctx.execution_state
# Only check if the task/workflow has a return statement at compile time locally.
not is_reference_entity
and ctx.execution_state
and ctx.execution_state.mode is None
# inspect module does not work correctly with Python <3.10.10. https://github.com/flyteorg/flyte/issues/5608
and sys.version_info >= (3, 10, 10)
Expand Down
6 changes: 3 additions & 3 deletions flytekit/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from flytekit.core import launch_plan as _annotated_launchplan
from flytekit.core import workflow as _annotated_workflow
from flytekit.core.base_task import TaskMetadata, TaskResolverMixin
from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin
from flytekit.core.interface import transform_function_to_interface
from flytekit.core.pod_template import PodTemplate
from flytekit.core.python_function_task import PythonFunctionTask
Expand Down Expand Up @@ -371,7 +371,7 @@ def wrapper(fn: Callable[P, Any]) -> PythonFunctionTask[T]:
return wrapper


class ReferenceTask(ReferenceEntity, PythonFunctionTask): # type: ignore
class ReferenceTask(ReferenceEntity, PythonTask): # type: ignore
"""
This is a reference task, the body of the function passed in through the constructor will never be used, only the
signature of the function will be. The signature should also match the signature of the task you're referencing,
Expand Down Expand Up @@ -412,7 +412,7 @@ def reference_task(
"""

def wrapper(fn) -> ReferenceTask:
interface = transform_function_to_interface(fn)
interface = transform_function_to_interface(fn, is_reference_entity=True)
return ReferenceTask(project, domain, name, version, interface.inputs, interface.outputs)

return wrapper
2 changes: 1 addition & 1 deletion flytekit/extend/backend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def is_terminal_phase(phase: TaskExecution.Phase) -> bool:


def get_agent_secret(secret_key: str) -> str:
return flytekit.current_context().secrets.get(secret_key)
return flytekit.current_context().secrets.get(key=secret_key)


def render_task_template(tt: TaskTemplate, file_prefix: str) -> TaskTemplate:
Expand Down
13 changes: 10 additions & 3 deletions flytekit/image_spec/default_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,14 +159,21 @@ def create_docker_context(image_spec: ImageSpec, tmp_dir: Path):
requirements_uv_path = tmp_dir / "requirements_uv.txt"
requirements_uv_path.write_text("\n".join(uv_requirements))

pip_extra = f"--index-url {image_spec.pip_index}" if image_spec.pip_index else ""
uv_python_install_command = UV_PYTHON_INSTALL_COMMAND_TEMPLATE.substitute(PIP_EXTRA=pip_extra)
# Gather every index option as its own element and join them with single
# spaces. Plain string concatenation would fuse "--index-url <url>" with the
# first "--extra-index-url <url>" whenever both settings are present, handing
# the installer one malformed argument.
pip_extra_flags = []
if image_spec.pip_index:
    pip_extra_flags.append(f"--index-url {image_spec.pip_index}")
if image_spec.pip_extra_index_url:
    pip_extra_flags.extend(f"--extra-index-url {url}" for url in image_spec.pip_extra_index_url)
pip_extra_args = " ".join(pip_extra_flags)

uv_python_install_command = UV_PYTHON_INSTALL_COMMAND_TEMPLATE.substitute(PIP_EXTRA=pip_extra_args)

if pip_requirements:
requirements_uv_path = tmp_dir / "requirements_pip.txt"
requirements_uv_path.write_text(os.linesep.join(pip_requirements))

pip_python_install_command = PIP_PYTHON_INSTALL_COMMAND_TEMPLATE.substitute(PIP_EXTRA=pip_extra)
pip_python_install_command = PIP_PYTHON_INSTALL_COMMAND_TEMPLATE.substitute(PIP_EXTRA=pip_extra_args)
else:
pip_python_install_command = ""

Expand Down
12 changes: 9 additions & 3 deletions plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import flytekit
from flytekit import PythonFunctionTask, Resources, lazy_module
from flytekit.configuration import SerializationSettings
from flytekit.core.context_manager import OutputMetadata
from flytekit.core.context_manager import FlyteContextManager, OutputMetadata
from flytekit.core.pod_template import PodTemplate
from flytekit.core.resources import convert_resources_to_resource_model
from flytekit.exceptions.user import FlyteRecoverableException
Expand Down Expand Up @@ -429,13 +429,18 @@ def fn_partial():
"""Closure of the task function with kwargs already bound."""
try:
return_val = self._task_function(**kwargs)
core_context = FlyteContextManager.current_context()
omt = core_context.output_metadata_tracker
om = None
if omt:
om = omt.get(return_val)
except Exception as e:
# See explanation in `create_recoverable_error_file` why we check
# for recoverable errors here in the worker processes.
if isinstance(e, FlyteRecoverableException):
create_recoverable_error_file()
raise
return ElasticWorkerResult(return_value=return_val, decks=flytekit.current_context().decks, om=None)
return ElasticWorkerResult(return_value=return_val, decks=flytekit.current_context().decks, om=om)

launcher_target_func = fn_partial
launcher_args = ()
Expand Down Expand Up @@ -470,7 +475,8 @@ def fn_partial():
if not isinstance(deck, flytekit.deck.deck.TimeLineDeck):
ctx.decks.append(deck)
if out[0].om:
ctx.output_metadata_tracker.add(out[0].return_value, out[0].om)
core_context = FlyteContextManager.current_context()
core_context.output_metadata_tracker.add(out[0].return_value, out[0].om)

return out[0].return_value
else:
Expand Down
40 changes: 40 additions & 0 deletions plugins/flytekit-kf-pytorch/tests/test_elastic_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
import typing
from dataclasses import dataclass
from unittest import mock
from typing_extensions import Annotated, cast
from flytekitplugins.kfpytorch.task import Elastic

from flytekit import Artifact

import pytest
import torch
Expand All @@ -11,6 +15,7 @@

import flytekit
from flytekit import task, workflow
from flytekit.core.context_manager import FlyteContext, FlyteContextManager, ExecutionState, ExecutionParameters, OutputMetadataTracker
from flytekit.configuration import SerializationSettings
from flytekit.exceptions.user import FlyteRecoverableException

Expand Down Expand Up @@ -159,6 +164,41 @@ def wf():
assert "Hello Flyte Deck viewer from worker process 0" in test_deck.html


class Card:
    """Tiny card-like helper for the artifact test; it only stores a text payload."""

    def __init__(self, text: str):
        self.text = text

    def serialize_to_string(self, ctx: FlyteContext, variable_name: str):
        """Print the identity of ``ctx`` and return the fixed pair ``("card", "card")``."""
        # The printed id shows which context object the serialization ran with.
        print(f"In serialize_to_string: {id(ctx)}")
        serialized = ("card", "card")
        return serialized


@pytest.mark.parametrize("start_method", ["spawn", "fork"])
def test_output_metadata_passing(start_method: str) -> None:
    """Output metadata attached inside an Elastic task must be retrievable by the caller.

    The task returns an artifact created with an extra ``Card``; after running
    the task, the tracker installed on the surrounding context is expected to
    hold exactly one additional item for the returned value.
    """
    elastic_artifact = Artifact(name="elastic-artf")

    @task(
        task_config=Elastic(start_method=start_method),
    )
    def train2() -> Annotated[str, elastic_artifact]:
        return elastic_artifact.create_from("hello flyte", Card("## card"))

    @workflow
    def wf():
        train2()

    parent_ctx = FlyteContext.current_context()
    tracker = OutputMetadataTracker()

    # Build the child context step by step: local task execution mode plus the
    # tracker that the assertion below inspects.
    local_state = parent_ctx.new_execution_state().with_params(mode=ExecutionState.Mode.LOCAL_TASK_EXECUTION)
    ctx_builder = parent_ctx.with_execution_state(local_state).with_output_metadata_tracker(tracker)

    with FlyteContextManager.with_context(ctx_builder) as child_ctx:
        cast(ExecutionParameters, child_ctx.user_space_params)._decks = []
        # Call execute directly so that the very same FlyteContext object is used.
        result = train2.execute()
        metadata = child_ctx.output_metadata_tracker.get(result)
        assert len(metadata.additional_items) == 1


@pytest.mark.parametrize(
"recoverable,start_method",
[
Expand Down
69 changes: 69 additions & 0 deletions plugins/flytekit-omegaconf/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Flytekit OmegaConf Plugin

Flytekit python natively supports serialization of many data types for exchanging information between tasks.
The Flytekit OmegaConf Plugin extends this set with the `DictConfig` type from the
[OmegaConf package](https://omegaconf.readthedocs.io/) as well as related types
that are used by the [hydra package](https://hydra.cc/) for configuration management.

## Task example
```python
from dataclasses import dataclass

import flytekitplugins.omegaconf  # noqa F401
from flytekit import task, workflow
from omegaconf import DictConfig


@dataclass
class MySimpleConf:
    _target_: str = "lightning_module.MyEncoderModule"
    learning_rate: float = 0.0001


@task
def my_task(cfg: DictConfig) -> None:
    print(f"Doing things with {cfg.learning_rate=}")


@workflow
def pipeline(cfg: DictConfig) -> None:
    my_task(cfg=cfg)


if __name__ == "__main__":
    from omegaconf import OmegaConf

    cfg = OmegaConf.structured(MySimpleConf)
    pipeline(cfg=cfg)
```

## Transformer configuration

The transformer can be set to one of three modes:

`DataClass` - This mode should be used with a StructuredConfig and will reconstruct the config from the matching dataclass
during deserialisation in order to make typing information from the dataclass and continued validation thereof available.
This requires the dataclass definition to be available via python import in the Flyte execution environment in which
objects are (de-)serialised.

`DictConfig` - This mode will deserialize the config into a DictConfig object. In particular, dataclasses are translated
into DictConfig objects and only primitive types are being checked. The definition of underlying dataclasses for
structured configs is only required during the initial serialization for this mode.

`Auto` - This mode will try to deserialize according to the Dataclass mode and fall back to the DictConfig mode if the
dataclass definition is not available. This is the default mode.

You can set the transformer mode either globally or for the current context only, in the following ways:
```python
from flytekitplugins.omegaconf import set_transformer_mode, local_transformer_mode, OmegaConfTransformerMode

# Set the global transformer mode
set_transformer_mode(OmegaConfTransformerMode.DictConfig)

# You can also set the mode for the current context only
with local_transformer_mode(OmegaConfTransformerMode.DataClass):
    # This will use the DataClass mode
    pass
```

```note
Since the DictConfig is flattened and keys transformed into dot notation, the keys of the DictConfig must not contain
dots.
```
33 changes: 33 additions & 0 deletions plugins/flytekit-omegaconf/flytekitplugins/omegaconf/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from contextlib import contextmanager

from flytekitplugins.omegaconf.config import OmegaConfTransformerMode
from flytekitplugins.omegaconf.dictconfig_transformer import DictConfigTransformer # noqa: F401
from flytekitplugins.omegaconf.listconfig_transformer import ListConfigTransformer # noqa: F401

_TRANSFORMER_MODE = OmegaConfTransformerMode.Auto


def set_transformer_mode(mode: OmegaConfTransformerMode) -> None:
    """Set the global serialization mode for OmegaConf objects.

    Rebinds the module-level ``_TRANSFORMER_MODE`` to ``mode``; the new value
    stays in effect until it is set again.

    :param mode: The transformer mode to install as the global mode.
    """
    # ``global`` is required because the module-level name is reassigned here.
    global _TRANSFORMER_MODE
    _TRANSFORMER_MODE = mode


def get_transformer_mode() -> OmegaConfTransformerMode:
    """Get the global serialization mode for OmegaConf objects.

    :return: The current value of the module-level ``_TRANSFORMER_MODE``.
    """
    return _TRANSFORMER_MODE


@contextmanager
def local_transformer_mode(mode: OmegaConfTransformerMode):
    """Temporarily switch the OmegaConf serialization mode for a ``with`` block.

    The mode active on entry is remembered and reinstalled on exit, including
    when the body raises.

    :param mode: The transformer mode to use inside the ``with`` block.
    """
    # Only read here; every write goes through set_transformer_mode.
    saved_mode = _TRANSFORMER_MODE
    set_transformer_mode(mode)
    try:
        yield
    finally:
        set_transformer_mode(saved_mode)


__all__ = ["set_transformer_mode", "get_transformer_mode", "local_transformer_mode", "OmegaConfTransformerMode"]
15 changes: 15 additions & 0 deletions plugins/flytekit-omegaconf/flytekitplugins/omegaconf/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from enum import Enum


class OmegaConfTransformerMode(Enum):
    """
    Operation mode indicating whether a (potentially unannotated) DictConfig object or a structured config using the
    underlying dataclass is returned.

    Note: Returning a structured config backed by the dataclass requires the use of structured configs.
    Note: We define a single shared config across all transformers as recursive calls should refer to the same config.
    """

    # Return a DictConfig object.
    DictConfig = "DictConfig"
    # Return a structured config built from the underlying dataclass.
    DataClass = "DataClass"
    # NOTE(review): presumably tries DataClass first and falls back to DictConfig, per the plugin README — confirm
    # against the transformer implementation.
    Auto = "Auto"
Loading

0 comments on commit aa49865

Please sign in to comment.