diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index d0af81340d..e515309fed 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -31,6 +31,7 @@ TypeEngine, TypeTransformerFailedError, UnionTransformer, + _is_union_type, ) from flytekit.exceptions import user as _user_exceptions from flytekit.exceptions.user import FlytePromiseAttributeResolveException @@ -765,15 +766,26 @@ def binding_data_from_python_std( # This handles the case where the given value is the output of another task if isinstance(t_value, Promise): if not t_value.is_ready: - upstream_type = t_value.ref.node.flyte_entity.interface.outputs[t_value.ref.var].type + upstream_type = t_value.ref.node.flyte_entity.python_interface.outputs[t_value.ref.var] # if upstream type is a list of unions, make sure the downstream type is a list of unions # this is just a very limited test case for handling common map task type mis-matches so that we can show # the user more information without relying on the user to register with Admin to trigger the compiler - if upstream_type.collection_type and upstream_type.collection_type.union_type: - if not (expected_literal_type.collection_type and expected_literal_type.collection_type.union_type): + if upstream_type is not t_value_type: + if _is_union_type(t_value_type): + sub_types = get_args(t_value_type) + if not any(upstream_type == t for t in sub_types): + raise AssertionError( + f"Expected type '{t_value_type}' does not include upstream type '{upstream_type}'" + ) + else: raise AssertionError( - f"Expected type {expected_literal_type}\n does not match upstream type {upstream_type}" + f"Expected type '{t_value_type}' does not match upstream type '{upstream_type}'" ) + # if upstream_type.collection_type and upstream_type.collection_type.union_type: + # if not (expected_literal_type.collection_type and expected_literal_type.collection_type.union_type): + # raise AssertionError( + # f"Expected type {expected_literal_type}\n does not match upstream type {upstream_type}" + # ) nodes.append(t_value.ref.node) # keeps track of upstream nodes return _literals_models.BindingData(promise=t_value.ref) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 2218ed430a..5948c0beef 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1477,7 +1477,7 @@ def _is_union_type(t): else: UnionType = None - return t is typing.Union or get_origin(t) is Union or UnionType and isinstance(t, UnionType) + return t is typing.Union or get_origin(t) is typing.Union or UnionType and isinstance(t, UnionType) class UnionTransformer(TypeTransformer[T]): diff --git a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py index 7ad081de57..fa964a71ef 100644 --- a/tests/flytekit/unit/core/test_array_node_map_task.py +++ b/tests/flytekit/unit/core/test_array_node_map_task.py @@ -1,4 +1,5 @@ import functools +import os import typing from collections import OrderedDict from typing import List @@ -461,4 +462,5 @@ def wf(): dirs = mt(word=["one", "two", "three"]) consume_directories(dirs=dirs) - wf.compile() + with pytest.raises(AssertionError): + wf.compile() diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index bd24d47bb8..d7edc4d1c1 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -254,3 +254,30 @@ def test_prom_with_union_literals(): assert bd.scalar.union.stored_type.structure.tag == "int" bd = binding_data_from_python_std(ctx, lt, "hello", pt, []) assert bd.scalar.union.stored_type.structure.tag == "str" + + +def test_pro_with_mismatch_type(): + @task + def t1(a: int) -> int: + return a + + @task + def t2(a: str) -> str: + return a + + @workflow + def wf1(): + t2(a=t1(a=123)) + + with pytest.raises(AssertionError): + wf1.compile() + + @task + def t3(a: typing.Union[int, str]) -> int: + return a + + @workflow + def wf2(): + t3(a=t1(a=123)) + + wf2.compile()