Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Echo task #2654

Merged
merged 11 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion flytekit/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from flytekit.core import launch_plan as _annotated_launchplan
from flytekit.core import workflow as _annotated_workflow
from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin
from flytekit.core.interface import transform_function_to_interface
from flytekit.core.interface import Interface, output_name_generator, transform_function_to_interface
from flytekit.core.pod_template import PodTemplate
from flytekit.core.python_function_task import PythonFunctionTask
from flytekit.core.reference_entity import ReferenceEntity, TaskReference
Expand Down Expand Up @@ -416,3 +416,30 @@ def wrapper(fn) -> ReferenceTask:
return ReferenceTask(project, domain, name, version, interface.inputs, interface.outputs)

return wrapper


class Echo(PythonTask):
"""
A task that simply echoes the inputs back to the user.
The task's inputs and outputs interface are the same.
FlytePropeller won't create a pod for this task, it will simply pass the inputs to the outputs.
https://github.com/flyteorg/flyte/blob/master/flyteplugins/go/tasks/plugins/testing/echo.go
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
"""

_TASK_TYPE = "echo"

def __init__(self, name: str, inputs: Optional[Dict[str, Type]] = None, **kwargs):
outputs = dict(zip(list(output_name_generator(len(inputs.values()))), inputs.values())) if inputs else None
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
super().__init__(
task_type=self._TASK_TYPE,
name=name,
interface=Interface(inputs=inputs, outputs=outputs),
**kwargs,
)

def execute(self, **kwargs) -> Any:
values = list(kwargs.values())
if len(values) == 1:
return values[0]
else:
return tuple(values)
39 changes: 39 additions & 0 deletions tests/flytekit/unit/core/test_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from flytekit import task, workflow
from flytekit.configuration import Image, ImageConfig, SerializationSettings
from flytekit.core.condition import conditional
from flytekit.core.task import Echo
from flytekit.models.core.workflow import Node
from flytekit.tools.translator import get_serializable

Expand Down Expand Up @@ -495,3 +496,41 @@ def multiplier_2(my_input: float) -> float:

res = multiplier_2(my_input=10.0)
assert res == 20


def test_echo_in_condition():
echo1 = Echo(name="echo", inputs={"a": float})

@task()
def t1(radius: float) -> typing.Optional[float]:
return 2 * 3.14 * radius

@workflow
def wf1(radius: float) -> typing.Optional[float]:
return (
conditional("shape_properties_with_multiple_branches")
.if_((radius >= 0.1) & (radius < 1.0))
.then(t1(radius=radius))
.else_()
.then(echo1(a=radius))
)

assert wf1(radius=1.8) == 1.8

echo2 = Echo(name="echo", inputs={"a": float, "b": float})

@task()
def t2(radius: float) -> (float, float):
return 2 * 3.14 * radius, 2 * 3.14 * radius

@workflow
def wf2(radius1: float, radius2: float) -> (float, float):
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
return (
conditional("shape_properties_with_multiple_branches")
.if_((radius1 >= 0.1) & (radius1 < 1.0))
.then(t2(radius=radius2))
.else_()
.then(echo2(a=radius1, b=radius2))
)

assert wf2(radius1=1.8, radius2=1.8) == (1.8, 1.8)
Loading