Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Echo task #2654

Merged
merged 11 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion flytekit/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from flytekit.core import launch_plan as _annotated_launchplan
from flytekit.core import workflow as _annotated_workflow
from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin
from flytekit.core.interface import transform_function_to_interface
from flytekit.core.interface import Interface, output_name_generator, transform_function_to_interface
from flytekit.core.pod_template import PodTemplate
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.core.reference_entity import ReferenceEntity, TaskReference
Expand Down Expand Up @@ -416,3 +416,35 @@ def wrapper(fn) -> ReferenceTask:
return ReferenceTask(project, domain, name, version, interface.inputs, interface.outputs)

return wrapper


class Echo(PythonTask):
_TASK_TYPE = "echo"

def __init__(self, name: str, inputs: Optional[Dict[str, Type]] = None, **kwargs):
"""
A task that simply echoes the inputs back to the user.
The task's inputs and outputs interface are the same.
FlytePropeller won't create a pod for this task, it will simply pass the inputs to the outputs.
https://github.com/flyteorg/flyte/blob/master/flyteplugins/go/tasks/plugins/testing/echo.go
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a follow up, I think this should link out to a docs around "How to enable echo tasks". This way a user can look at the docstring and know to contact their platform engineer to enable the feature.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! good call.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will create another follow-up PR to address this bad error.

image

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just updated the docstring in this PR too


:param name: The name of the task.
:param inputs: Name and type of inputs specified as a dictionary.
e.g. {"a": int, "b": str}.
:param kwargs: All other args required by the parent type - PythonTask.

"""
outputs = dict(zip(output_name_generator(len(inputs)), inputs.values())) if inputs else None
super().__init__(
task_type=self._TASK_TYPE,
name=name,
interface=Interface(inputs=inputs, outputs=outputs),
**kwargs,
)

def execute(self, **kwargs) -> Any:
values = list(kwargs.values())
if len(values) == 1:
return values[0]
else:
return tuple(values)
39 changes: 39 additions & 0 deletions tests/flytekit/unit/core/test_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from flytekit import task, workflow
from flytekit.configuration import Image, ImageConfig, SerializationSettings
from flytekit.core.condition import conditional
from flytekit.core.task import Echo
from flytekit.models.core.workflow import Node
from flytekit.tools.translator import get_serializable

Expand Down Expand Up @@ -495,3 +496,41 @@ def multiplier_2(my_input: float) -> float:

res = multiplier_2(my_input=10.0)
assert res == 20


def test_echo_in_condition():
echo1 = Echo(name="echo", inputs={"a": typing.Optional[float]})

@task()
def t1(radius: float) -> typing.Optional[float]:
return 2 * 3.14 * radius

@workflow
def wf1(radius: float) -> typing.Optional[float]:
return (
conditional("shape_properties_with_multiple_branches")
.if_((radius >= 0.1) & (radius < 1.0))
.then(t1(radius=radius))
.else_()
.then(echo1(a=radius))
)

assert wf1(radius=1.8) == 1.8

echo2 = Echo(name="echo", inputs={"a": float, "b": float})

@task()
def t2(radius: float) -> typing.Tuple[float, float]:
return 2 * 3.14 * radius, 2 * 3.14 * radius

@workflow
def wf2(radius1: float, radius2: float) -> typing.Tuple[float, float]:
return (
conditional("shape_properties_with_multiple_branches")
.if_((radius1 >= 0.1) & (radius1 < 1.0))
.then(t2(radius=radius2))
.else_()
.then(echo2(a=radius1, b=radius2))
)

assert wf2(radius1=1.8, radius2=1.8) == (1.8, 1.8)
Loading