diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index be5cbc6255..f5d81b0636 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -476,10 +476,13 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: # Recursively construct the dataclass_type which contains the literal type of each field literal_type = {} + hints = typing.get_type_hints(t) # Get the type of each field from dataclass for field in t.__dataclass_fields__.values(): # type: ignore try: - literal_type[field.name] = TypeEngine.to_literal_type(field.type) + name = field.name + python_type = hints.get(name, field.type) + literal_type[name] = TypeEngine.to_literal_type(python_type) except Exception as e: logger.warning( "Field {} of type {} cannot be converted to a literal type. Error: {}".format( diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 57f6cddecf..9ff40b57d5 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -2860,9 +2860,25 @@ class MyDataClass: assert literal_type is not None invalid_json_str = "{ unbalanced_braces" + with pytest.raises(Exception): Literal(scalar=Scalar(generic=_json_format.Parse(invalid_json_str, _struct.Struct()))) + @dataclass + class Fruit(DataClassJSONMixin): + name: str + + @dataclass + class NestedFruit(DataClassJSONMixin): + sub_fruit: Fruit + name: str + + literal_type = de.get_literal_type(NestedFruit) + dataclass_type = literal_type.structure.dataclass_type + assert dataclass_type["sub_fruit"].simple == SimpleType.STRUCT + assert dataclass_type["sub_fruit"].structure.dataclass_type["name"].simple == SimpleType.STRING + assert dataclass_type["name"].simple == SimpleType.STRING + def test_DataclassTransformer_to_literal(): @dataclass