Skip to content

Commit

Permalink
[Feature] Support positional arguments (flyteorg#2522)
Browse files Browse the repository at this point in the history
- Change the `inputs` and `outputs` attributes in the `Interface` class to `OrderedDict` to preserve the order.
- Write values in positional arguments to `kwargs`.
Resolves: flyteorg/flyte#5320
Signed-off-by: Chi-Sheng Liu <chishengliu@chishengliu.com>
Signed-off-by: mao3267 <chenvincent610@gmail.com>
  • Loading branch information
MortalHappiness authored and mao3267 committed Jul 29, 2024
1 parent 9513e87 commit 2e5532b
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 20 deletions.
40 changes: 20 additions & 20 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -1202,19 +1202,22 @@ def flyte_entity_call_handler(
#. Start a local execution - This means that we're not already in a local workflow execution, which means that
we should expect inputs to be native Python values and that we should return Python native values.
"""
# Sanity checks
# Only keyword args allowed
if len(args) > 0:
raise _user_exceptions.FlyteAssertion(
f"When calling tasks, only keyword args are supported. "
f"Aborting execution as detected {len(args)} positional args {args}"
)
# Make sure arguments are part of interface
for k, v in kwargs.items():
if k not in cast(SupportsNodeCreation, entity).python_interface.inputs:
raise AssertionError(
f"Received unexpected keyword argument '{k}' in function '{cast(SupportsNodeCreation, entity).name}'"
)
if k not in entity.python_interface.inputs:
raise AssertionError(f"Received unexpected keyword argument '{k}' in function '{entity.name}'")

# Check if we have more arguments than expected
if len(args) > len(entity.python_interface.inputs):
raise AssertionError(
f"Received more arguments than expected in function '{entity.name}'. Expected {len(entity.python_interface.inputs)} but got {len(args)}"
)

# Convert args to kwargs
for arg, input_name in zip(args, entity.python_interface.inputs.keys()):
if input_name in kwargs:
raise AssertionError(f"Got multiple values for argument '{input_name}' in function '{entity.name}'")
kwargs[input_name] = arg

ctx = FlyteContextManager.current_context()
if ctx.execution_state and (
Expand All @@ -1234,15 +1237,12 @@ def flyte_entity_call_handler(
child_ctx.execution_state
and child_ctx.execution_state.branch_eval_mode == BranchEvalMode.BRANCH_SKIPPED
):
if (
len(cast(SupportsNodeCreation, entity).python_interface.inputs) > 0
or len(cast(SupportsNodeCreation, entity).python_interface.outputs) > 0
):
output_names = list(cast(SupportsNodeCreation, entity).python_interface.outputs.keys())
if len(entity.python_interface.inputs) > 0 or len(entity.python_interface.outputs) > 0:
output_names = list(entity.python_interface.outputs.keys())
if len(output_names) == 0:
return VoidPromise(entity.name)
vals = [Promise(var, None) for var in output_names]
return create_task_output(vals, cast(SupportsNodeCreation, entity).python_interface)
return create_task_output(vals, entity.python_interface)
else:
return None
return cast(LocallyExecutable, entity).local_execute(ctx, **kwargs)
Expand All @@ -1255,7 +1255,7 @@ def flyte_entity_call_handler(
cast(ExecutionParameters, child_ctx.user_space_params)._decks = []
result = cast(LocallyExecutable, entity).local_execute(child_ctx, **kwargs)

expected_outputs = len(cast(SupportsNodeCreation, entity).python_interface.outputs)
expected_outputs = len(entity.python_interface.outputs)
if expected_outputs == 0:
if result is None or isinstance(result, VoidPromise):
return None
Expand All @@ -1268,10 +1268,10 @@ def flyte_entity_call_handler(
if (1 < expected_outputs == len(cast(Tuple[Promise], result))) or (
result is not None and expected_outputs == 1
):
return create_native_named_tuple(ctx, result, cast(SupportsNodeCreation, entity).python_interface)
return create_native_named_tuple(ctx, result, entity.python_interface)

raise AssertionError(
f"Expected outputs and actual outputs do not match."
f"Result {result}. "
f"Python interface: {cast(SupportsNodeCreation, entity).python_interface}"
f"Python interface: {entity.python_interface}"
)
123 changes: 123 additions & 0 deletions tests/flytekit/unit/core/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -943,3 +943,126 @@ def wf_with_input() -> typing.Optional[typing.List[int]]:
)

assert wf_with_input() == input_val

def test_positional_args_task():
arg1 = 5
arg2 = 6
ret = 17

@task
def t1(x: int, y: int) -> int:
return x + y * 2

@workflow
def wf_pure_positional_args() -> int:
return t1(arg1, arg2)

@workflow
def wf_mixed_positional_and_keyword_args() -> int:
return t1(arg1, y=arg2)

wf_pure_positional_args_spec = get_serializable(OrderedDict(), serialization_settings, wf_pure_positional_args)
wf_mixed_positional_and_keyword_args_spec = get_serializable(OrderedDict(), serialization_settings, wf_mixed_positional_and_keyword_args)

arg1_binding = Scalar(primitive=Primitive(integer=arg1))
arg2_binding = Scalar(primitive=Primitive(integer=arg2))
output_type = LiteralType(simple=SimpleType.INTEGER)

assert wf_pure_positional_args_spec.template.nodes[0].inputs[0].binding.value == arg1_binding
assert wf_pure_positional_args_spec.template.nodes[0].inputs[1].binding.value == arg2_binding
assert wf_pure_positional_args_spec.template.interface.outputs["o0"].type == output_type


assert wf_mixed_positional_and_keyword_args_spec.template.nodes[0].inputs[0].binding.value == arg1_binding
assert wf_mixed_positional_and_keyword_args_spec.template.nodes[0].inputs[1].binding.value == arg2_binding
assert wf_mixed_positional_and_keyword_args_spec.template.interface.outputs["o0"].type == output_type

assert wf_pure_positional_args() == ret
assert wf_mixed_positional_and_keyword_args() == ret

def test_positional_args_workflow():
arg1 = 5
arg2 = 6
ret = 17

@task
def t1(x: int, y: int) -> int:
return x + y * 2

@workflow
def sub_wf(x: int, y: int) -> int:
return t1(x=x, y=y)

@workflow
def wf_pure_positional_args() -> int:
return sub_wf(arg1, arg2)

@workflow
def wf_mixed_positional_and_keyword_args() -> int:
return sub_wf(arg1, y=arg2)

wf_pure_positional_args_spec = get_serializable(OrderedDict(), serialization_settings, wf_pure_positional_args)
wf_mixed_positional_and_keyword_args_spec = get_serializable(OrderedDict(), serialization_settings, wf_mixed_positional_and_keyword_args)

arg1_binding = Scalar(primitive=Primitive(integer=arg1))
arg2_binding = Scalar(primitive=Primitive(integer=arg2))
output_type = LiteralType(simple=SimpleType.INTEGER)

assert wf_pure_positional_args_spec.template.nodes[0].inputs[0].binding.value == arg1_binding
assert wf_pure_positional_args_spec.template.nodes[0].inputs[1].binding.value == arg2_binding
assert wf_pure_positional_args_spec.template.interface.outputs["o0"].type == output_type

assert wf_mixed_positional_and_keyword_args_spec.template.nodes[0].inputs[0].binding.value == arg1_binding
assert wf_mixed_positional_and_keyword_args_spec.template.nodes[0].inputs[1].binding.value == arg2_binding
assert wf_mixed_positional_and_keyword_args_spec.template.interface.outputs["o0"].type == output_type

assert wf_pure_positional_args() == ret
assert wf_mixed_positional_and_keyword_args() == ret

def test_positional_args_chained_tasks():
@task
def t1(x: int, y: int) -> int:
return x + y * 2

@workflow
def wf() -> int:
x = t1(2, y = 3)
y = t1(3, 4)
return t1(x, y = y)

assert wf() == 30

def test_positional_args_task_inputs_from_workflow_args():
@task
def t1(x: int, y: int, z: int) -> int:
return x + y * 2 + z * 3

@workflow
def wf(x: int, y: int) -> int:
return t1(x, y=y, z=3)

assert wf(1, 2) == 14

def test_unexpected_kwargs_task_raises_error():
@task
def t1(a: int) -> int:
return a

with pytest.raises(AssertionError, match="Received unexpected keyword argument"):
t1(b=6)

def test_too_many_positional_args_task_raises_error():
@task
def t1(a: int) -> int:
return a

with pytest.raises(AssertionError, match="Received more arguments than expected"):
t1(1, 2)

def test_both_positional_and_keyword_args_task_raises_error():
@task
def t1(a: int) -> int:
return a

with pytest.raises(AssertionError, match="Got multiple values for argument"):
t1(1, a=2)

0 comments on commit 2e5532b

Please sign in to comment.