Skip to content

Commit

Permalink
update error
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <pingsutw@apache.org>
  • Loading branch information
pingsutw committed Sep 3, 2024
1 parent 404c189 commit 8a36178
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 6 deletions.
20 changes: 16 additions & 4 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
TypeEngine,
TypeTransformerFailedError,
UnionTransformer,
_is_union_type,
)
from flytekit.exceptions import user as _user_exceptions
from flytekit.exceptions.user import FlytePromiseAttributeResolveException
Expand Down Expand Up @@ -765,15 +766,26 @@ def binding_data_from_python_std(
# This handles the case where the given value is the output of another task
if isinstance(t_value, Promise):
if not t_value.is_ready:
upstream_type = t_value.ref.node.flyte_entity.interface.outputs[t_value.ref.var].type
upstream_type = t_value.ref.node.flyte_entity.python_interface.outputs[t_value.ref.var]
# if upstream type is a list of unions, make sure the downstream type is a list of unions
# this is just a very limited test case for handling common map task type mis-matches so that we can show
# the user more information without relying on the user to register with Admin to trigger the compiler
if upstream_type.collection_type and upstream_type.collection_type.union_type:
if not (expected_literal_type.collection_type and expected_literal_type.collection_type.union_type):
if upstream_type is not t_value_type:
if _is_union_type(t_value_type):
sub_types = get_args(t_value_type)
if not any(upstream_type == t for t in sub_types):
raise AssertionError(
f"Expected type '{t_value_type}' does not include upstream type '{upstream_type}'"
)
else:
raise AssertionError(
f"Expected type {expected_literal_type}\n does not match upstream type {upstream_type}"
f"Expected type '{t_value_type}' does not match upstream type '{upstream_type}'"
)
# if upstream_type.collection_type and upstream_type.collection_type.union_type:
# if not (expected_literal_type.collection_type and expected_literal_type.collection_type.union_type):
# raise AssertionError(
# f"Expected type {expected_literal_type}\n does not match upstream type {upstream_type}"
# )
nodes.append(t_value.ref.node) # keeps track of upstream nodes
return _literals_models.BindingData(promise=t_value.ref)

Expand Down
2 changes: 1 addition & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1477,7 +1477,7 @@ def _is_union_type(t):
else:
UnionType = None

return t is typing.Union or get_origin(t) is Union or UnionType and isinstance(t, UnionType)
return t is typing.Union or get_origin(t) is typing.Union or UnionType and isinstance(t, UnionType)


class UnionTransformer(TypeTransformer[T]):
Expand Down
4 changes: 3 additions & 1 deletion tests/flytekit/unit/core/test_array_node_map_task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
import os
import typing
from collections import OrderedDict
from typing import List
Expand Down Expand Up @@ -461,4 +462,5 @@ def wf():
dirs = mt(word=["one", "two", "three"])
consume_directories(dirs=dirs)

wf.compile()
with pytest.raises(AssertionError):
wf.compile()
27 changes: 27 additions & 0 deletions tests/flytekit/unit/core/test_promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,30 @@ def test_prom_with_union_literals():
assert bd.scalar.union.stored_type.structure.tag == "int"
bd = binding_data_from_python_std(ctx, lt, "hello", pt, [])
assert bd.scalar.union.stored_type.structure.tag == "str"


def test_pro_with_mismatch_type():
@task
def t1(a: int) -> int:
return a

@task
def t2(a: str) -> str:
return a

@workflow
def wf1():
t2(a=t1(a=123))

with pytest.raises(AssertionError):
wf1.compile()

@task
def t3(a: typing.Union[int, str]) -> int:
return a

@workflow
def wf2():
t3(a=t1(a=123))

wf2.compile()

0 comments on commit 8a36178

Please sign in to comment.