-
Notifications
You must be signed in to change notification settings - Fork 297
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support default values in typing.List[dataclass] and typing.Dict[str, dataclass] #2603
Changes from 51 commits
cf8567c
d33fae2
958d6dc
5e2dd6e
06097a9
f20b466
2672d24
0e1dae1
73a90e4
e838dae
537a3d7
ac4be0e
45b04a6
39f2635
357877f
758736f
25381ce
dc575e4
87bfcc7
5614fe3
1d25882
e6606ff
20c8250
2ef9b08
83b2638
414052a
1e7bfbe
9129608
9895b28
89ba90f
399e19e
c95b987
b5e1a33
d6bcde5
8ff53b9
0b48963
17fa0a6
c01b0f6
8fa9a04
a6232a4
64cc5b8
30d1222
c285e34
4d2646c
9de3f07
ec7ae5e
6013ad1
b0162cd
d2ecd1f
7db88f3
27fc2d1
9f09010
cd8877f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -360,6 +360,7 @@ def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T): | |
|
||
expected_type = get_underlying_type(expected_type) | ||
expected_fields_dict = {} | ||
|
||
for f in dataclasses.fields(expected_type): | ||
expected_fields_dict[f.name] = f.type | ||
|
||
|
@@ -539,11 +540,13 @@ def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]: | |
field.type = self._get_origin_type_in_annotation(field.type) | ||
return python_type | ||
|
||
def _fix_structured_dataset_type(self, python_type: Type[T], python_val: typing.Any) -> T: | ||
def _fix_structured_dataset_type(self, python_type: Type[T], python_val: typing.Any) -> T | None: | ||
# In python 3.7, 3.8, DataclassJson will deserialize Annotated[StructuredDataset, kwtypes(..)] to a dict, | ||
# so here we convert it back to the Structured Dataset. | ||
from flytekit.types.structured import StructuredDataset | ||
|
||
if python_val is None: | ||
return python_val | ||
if python_type == StructuredDataset and type(python_val) == dict: | ||
return StructuredDataset(**python_val) | ||
elif get_origin(python_type) is list: | ||
|
@@ -575,9 +578,13 @@ def _make_dataclass_serializable(self, python_val: T, python_type: Type[T]) -> t | |
return self._make_dataclass_serializable(python_val, get_args(python_type)[0]) | ||
|
||
if hasattr(python_type, "__origin__") and get_origin(python_type) is list: | ||
if python_val is None: | ||
return None | ||
return [self._make_dataclass_serializable(v, get_args(python_type)[0]) for v in cast(list, python_val)] | ||
|
||
if hasattr(python_type, "__origin__") and get_origin(python_type) is dict: | ||
if python_val is None: | ||
return None | ||
Comment on lines
+581
to
+587
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is amazing, thank you for fixing this. |
||
return { | ||
k: self._make_dataclass_serializable(v, get_args(python_type)[1]) | ||
for k, v in cast(dict, python_val).items() | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,12 @@ | ||
import dataclasses | ||
import datetime | ||
import enum | ||
import json | ||
import logging | ||
import os | ||
import pathlib | ||
import typing | ||
from typing import cast | ||
from typing import cast, get_args | ||
|
||
import rich_click as click | ||
import yaml | ||
|
@@ -22,6 +23,7 @@ | |
from flytekit.types.file import FlyteFile | ||
from flytekit.types.iterator.json_iterator import JSONIteratorTransformer | ||
from flytekit.types.pickle.pickle import FlytePickleTransformer | ||
from flytekit.types.schema.types import FlyteSchema | ||
|
||
|
||
def is_pydantic_basemodel(python_type: typing.Type) -> bool: | ||
|
@@ -305,11 +307,48 @@ def convert( | |
if value is None: | ||
raise click.BadParameter("None value cannot be converted to a Json type.") | ||
|
||
FLYTE_TYPES = [FlyteFile, FlyteDirectory, StructuredDataset, FlyteSchema] | ||
|
||
def has_nested_dataclass(t: typing.Type) -> bool: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Add a comment for this function. If I pass a dataclass (not a Flyte type) to this function, it will return `True`. |
||
""" | ||
Recursively checks whether the given type or its nested types contain any dataclass. | ||
|
||
This function is typically called with a dictionary or list type and will return True if | ||
any of the nested types within the dictionary or list is a dataclass. | ||
|
||
Note: | ||
- A single dataclass will return True. | ||
- The function specifically excludes certain Flyte types like FlyteFile, FlyteDirectory, | ||
StructuredDataset, and FlyteSchema from being considered as dataclasses. This is because | ||
these types are handled separately by Flyte and do not need to be converted to dataclasses. | ||
|
||
Args: | ||
t (typing.Type): The type to check for nested dataclasses. | ||
|
||
Returns: | ||
bool: True if the type or its nested types contain a dataclass, False otherwise. | ||
""" | ||
|
||
if dataclasses.is_dataclass(t): | ||
# FlyteTypes is not supported now, we can support it in the future. | ||
return t not in FLYTE_TYPES | ||
|
||
return any(has_nested_dataclass(arg) for arg in get_args(t)) | ||
|
||
parsed_value = self._parse(value, param) | ||
|
||
# We compare the origin type because the json parsed value for list or dict is always a list or dict without | ||
# the covariant type information. | ||
if type(parsed_value) == typing.get_origin(self._python_type) or type(parsed_value) == self._python_type: | ||
if get_args(self._python_type) == (): | ||
return parsed_value | ||
Comment on lines
+345
to
+346
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What is this for? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Comment added by Vincent. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is for native list and dict types, which don't have covariant types. Therefore, we cannot index the return value from `get_args(self._python_type)`. |
||
elif isinstance(parsed_value, list) and has_nested_dataclass(get_args(self._python_type)[0]): | ||
j = JsonParamType(get_args(self._python_type)[0]) | ||
return [j.convert(v, param, ctx) for v in parsed_value] | ||
elif isinstance(parsed_value, dict) and has_nested_dataclass(get_args(self._python_type)[1]): | ||
j = JsonParamType(get_args(self._python_type)[1]) | ||
return {k: j.convert(v, param, ctx) for k, v in parsed_value.items()} | ||
|
||
return parsed_value | ||
|
||
if is_pydantic_basemodel(self._python_type): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When will `python_val` be None here? Could we add a small test for it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When dataclass members of type `typing.Dict` or `typing.List` default to `None`, this can lead to exceptions when iterating the list/dict from L552-588. We do need to add a test here.