Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support default values in typing.List[dataclass] and typing.Dict[str, dataclass] #2603

Merged
merged 53 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from 51 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
cf8567c
fix: set dataclass member as optional if default value is provided
mao3267 Jul 23, 2024
d33fae2
lint
mao3267 Jul 23, 2024
958d6dc
feat: handle nested dataclass conversion in JsonParamType
mao3267 Jul 27, 2024
5e2dd6e
fix: handle errors caused by NoneType default value
mao3267 Jul 27, 2024
06097a9
test: add nested dataclass unit tests
mao3267 Jul 27, 2024
f20b466
Sagemaker dict determinism (#2597)
samhita-alla Jul 23, 2024
2672d24
refactor(core): Enhance return type extraction logic (#2598)
pingsutw Jul 23, 2024
0e1dae1
Feat: Make exception raised by external command authenticator more ac…
fg91 Jul 24, 2024
73a90e4
Fix: Properly re-raise non-grpc exceptions during refreshing of proxy…
fg91 Jul 25, 2024
e838dae
validate idempotence token length in subsequent tasks (#2604)
samhita-alla Jul 26, 2024
537a3d7
Add nvidia-l4 gpu accelerator (#2608)
eapolinario Jul 26, 2024
ac4be0e
eliminate redundant literal conversion for `Iterator[JSON]` type (#2602)
samhita-alla Jul 26, 2024
45b04a6
[FlyteSchema] Fix numpy problems (#2619)
Future-Outlier Jul 29, 2024
39f2635
add nim plugin (#2475)
samhita-alla Jul 29, 2024
357877f
[Elastic/Artifacts] Pass through model card (#2575)
wild-endeavor Jul 29, 2024
758736f
Remove pyarrow as a direct dependency (#2228)
thomasjpfan Jul 29, 2024
25381ce
Boolean flag to show local container logs to the terminal (#2521)
aditya7302 Jul 29, 2024
dc575e4
Enable Ray Fast Register (#2606)
fiedlerNr9 Jul 29, 2024
87bfcc7
[Artifacts/Elastic] Skip partitions (#2620)
wild-endeavor Jul 29, 2024
5614fe3
Install flyteidl from master in plugins tests (#2621)
eapolinario Jul 30, 2024
1d25882
Using ParamSpec to show underlying typehinting (#2617)
JackUrb Jul 30, 2024
e6606ff
Support ArrayNode mapping over Launch Plans (#2480)
pvditt Jul 31, 2024
20c8250
Richer printing for some artifact objects (#2624)
wild-endeavor Jul 31, 2024
2ef9b08
ci: Add Python 3.9 to build matrix (#2622)
pingsutw Jul 31, 2024
83b2638
bump (#2627)
wild-endeavor Jul 31, 2024
414052a
Added alt prefix head to FlyteFile.new_remote (#2601)
pryce-turner Jul 31, 2024
1e7bfbe
Feature gate for FlyteMissingReturnValueException (#2623)
pingsutw Jul 31, 2024
9129608
Remove use of multiprocessing from the OAuth client (#2626)
rdeaton-freenome Jul 31, 2024
9895b28
Update codespell in precommit to version 2.3.0 (#2630)
eapolinario Jul 31, 2024
89ba90f
Fix Snowflake Agent Bug (#2605)
Future-Outlier Jul 31, 2024
399e19e
run test_missing_return_value on python 3.10+ (#2637)
pingsutw Aug 1, 2024
c95b987
[Elastic] Fix context usage and apply fix to fork method (#2628)
wild-endeavor Aug 1, 2024
b5e1a33
Add flytekit-omegaconf plugin (#2299)
mg515 Aug 1, 2024
d6bcde5
Adds extra-index-url to default image builder (#2636)
thomasjpfan Aug 1, 2024
8ff53b9
reference_task should inherit from PythonTask (#2643)
pingsutw Aug 2, 2024
0b48963
Fix Get Agent Secret Using Key (#2644)
Future-Outlier Aug 2, 2024
17fa0a6
fix: prevent converting Flyte types as custom dataclasses
mao3267 Aug 2, 2024
c01b0f6
fix: add None to output type
mao3267 Aug 2, 2024
8fa9a04
Merge remote-tracking branch 'origin' into handle-dataclass-default-v…
mao3267 Aug 2, 2024
a6232a4
Merge branch 'master' of https://github.com/mao3267/flytekit into han…
mao3267 Aug 9, 2024
64cc5b8
test: add unit test for nested dataclass inputs
mao3267 Aug 13, 2024
30d1222
test: add unit tests for nested dataclass, dataclass default value as…
mao3267 Aug 13, 2024
c285e34
fix: handle NoneType as default value of list type dataclass members
mao3267 Aug 13, 2024
4d2646c
fix: add comments for `has_nested_dataclass` function
mao3267 Aug 13, 2024
9de3f07
fix: make lint
mao3267 Aug 17, 2024
ec7ae5e
Merge branch 'master' of https://github.com/mao3267/flytekit into han…
mao3267 Aug 17, 2024
6013ad1
Merge remote-tracking branch 'origin' into handle-dataclass-default-v…
mao3267 Aug 25, 2024
b0162cd
fix: update tests regarding input through file and pipe
mao3267 Aug 25, 2024
d2ecd1f
Merge branch 'master' of https://github.com/mao3267/flytekit into han…
mao3267 Aug 25, 2024
7db88f3
Make JsonParamType convert faster
Future-Outlier Aug 26, 2024
27fc2d1
make has_nested_dataclass func more clean and add tests for dataclass…
Future-Outlier Aug 26, 2024
9f09010
make logic more backward compatible
Future-Outlier Aug 26, 2024
cd8877f
fix: handle indexing errors in dict/list while checking nested datacl…
mao3267 Aug 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T):

expected_type = get_underlying_type(expected_type)
expected_fields_dict = {}

for f in dataclasses.fields(expected_type):
expected_fields_dict[f.name] = f.type

Expand Down Expand Up @@ -539,11 +540,13 @@ def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]:
field.type = self._get_origin_type_in_annotation(field.type)
return python_type

def _fix_structured_dataset_type(self, python_type: Type[T], python_val: typing.Any) -> T:
def _fix_structured_dataset_type(self, python_type: Type[T], python_val: typing.Any) -> T | None:
# In python 3.7, 3.8, DataclassJson will deserialize Annotated[StructuredDataset, kwtypes(..)] to a dict,
# so here we convert it back to the Structured Dataset.
from flytekit.types.structured import StructuredDataset

if python_val is None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when will python_val be None here? could we add a small test for it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When dataclass members of type typing.Dict or typing.List default to None, this can lead to exceptions when iterating the list/dict from L552-588. We do need to add a test here.

return python_val
if python_type == StructuredDataset and type(python_val) == dict:
return StructuredDataset(**python_val)
elif get_origin(python_type) is list:
Expand Down Expand Up @@ -575,9 +578,13 @@ def _make_dataclass_serializable(self, python_val: T, python_type: Type[T]) -> t
return self._make_dataclass_serializable(python_val, get_args(python_type)[0])

if hasattr(python_type, "__origin__") and get_origin(python_type) is list:
if python_val is None:
return None
return [self._make_dataclass_serializable(v, get_args(python_type)[0]) for v in cast(list, python_val)]

if hasattr(python_type, "__origin__") and get_origin(python_type) is dict:
if python_val is None:
return None
Comment on lines +581 to +587
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is amazing, thank you for fixing this.

return {
k: self._make_dataclass_serializable(v, get_args(python_type)[1])
for k, v in cast(dict, python_val).items()
Expand Down
41 changes: 40 additions & 1 deletion flytekit/interaction/click_types.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import dataclasses
import datetime
import enum
import json
import logging
import os
import pathlib
import typing
from typing import cast
from typing import cast, get_args

import rich_click as click
import yaml
Expand All @@ -22,6 +23,7 @@
from flytekit.types.file import FlyteFile
from flytekit.types.iterator.json_iterator import JSONIteratorTransformer
from flytekit.types.pickle.pickle import FlytePickleTransformer
from flytekit.types.schema.types import FlyteSchema


def is_pydantic_basemodel(python_type: typing.Type) -> bool:
Expand Down Expand Up @@ -305,11 +307,48 @@ def convert(
if value is None:
raise click.BadParameter("None value cannot be converted to a Json type.")

FLYTE_TYPES = [FlyteFile, FlyteDirectory, StructuredDataset, FlyteSchema]

def has_nested_dataclass(t: typing.Type) -> bool:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add some comment for this function. If I pass a dataclass (not flyte type) to this function, it will return True here as well. That seems like not correct.

"""
Recursively checks whether the given type or its nested types contain any dataclass.

This function is typically called with a dictionary or list type and will return True if
any of the nested types within the dictionary or list is a dataclass.

Note:
- A single dataclass will return True.
- The function specifically excludes certain Flyte types like FlyteFile, FlyteDirectory,
StructuredDataset, and FlyteSchema from being considered as dataclasses. This is because
these types are handled separately by Flyte and do not need to be converted to dataclasses.

Args:
t (typing.Type): The type to check for nested dataclasses.

Returns:
bool: True if the type or its nested types contain a dataclass, False otherwise.
"""

if dataclasses.is_dataclass(t):
# FlyteTypes is not supported now, we can support it in the future.
return t not in FLYTE_TYPES

return any(has_nested_dataclass(arg) for arg in get_args(t))

parsed_value = self._parse(value, param)

# We compare the origin type because the json parsed value for list or dict is always a list or dict without
# the covariant type information.
if type(parsed_value) == typing.get_origin(self._python_type) or type(parsed_value) == self._python_type:
if get_args(self._python_type) == ():
return parsed_value
Comment on lines +345 to +346
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is this for?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comment added by Vincent.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is for native list and dict types, which don't have covariant types. Therefore, we can not index the return value from get_args.

elif isinstance(parsed_value, list) and has_nested_dataclass(get_args(self._python_type)[0]):
j = JsonParamType(get_args(self._python_type)[0])
return [j.convert(v, param, ctx) for v in parsed_value]
elif isinstance(parsed_value, dict) and has_nested_dataclass(get_args(self._python_type)[1]):
j = JsonParamType(get_args(self._python_type)[1])
return {k: j.convert(v, param, ctx) for k, v in parsed_value.items()}

return parsed_value

if is_pydantic_basemodel(self._python_type):
Expand Down
3 changes: 3 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/my_wf_input.json
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@
},
"p": "None",
"q": "tests/flytekit/unit/cli/pyflyte/testdata",
"r": [{"i": 1, "a": ["h", "e"]}],
"s": {"x": {"i": 1, "a": ["h", "e"]}},
"t": {"i": [{"i":1,"a":["h","e"]}]},
"remote": "tests/flytekit/unit/cli/pyflyte/testdata",
"image": "tests/flytekit/unit/cli/pyflyte/testdata"
}
17 changes: 17 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/my_wf_input.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,22 @@ o:
- tests/flytekit/unit/cli/pyflyte/testdata/df.parquet
p: 'None'
q: tests/flytekit/unit/cli/pyflyte/testdata
r:
- i: 1
a:
- h
- e
s:
x:
i: 1
a:
- h
- e
t:
i:
- i: 1
a:
- h
- e
remote: tests/flytekit/unit/cli/pyflyte/testdata
image: tests/flytekit/unit/cli/pyflyte/testdata
6 changes: 6 additions & 0 deletions tests/flytekit/unit/cli/pyflyte/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,12 @@ def test_pyflyte_run_cli(workflow_file):
"Any",
"--q",
DIR_NAME,
"--r",
json.dumps([{"i": 1, "a": ["h", "e"]}]),
"--s",
json.dumps({"x": {"i": 1, "a": ["h", "e"]}}),
"--t",
json.dumps({"i": [{"i":1,"a":["h","e"]}]}),
],
catch_exceptions=False,
)
Expand Down
13 changes: 11 additions & 2 deletions tests/flytekit/unit/cli/pyflyte/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ class MyDataclass(DataClassJsonMixin):
i: int
a: typing.List[str]

@dataclass
class NestedDataclass(DataClassJsonMixin):
i: typing.List[MyDataclass]

class Color(enum.Enum):
RED = "RED"
Expand All @@ -61,8 +64,11 @@ def print_all(
o: typing.Dict[str, typing.List[FlyteFile]],
p: typing.Any,
q: FlyteDirectory,
r: typing.List[MyDataclass],
s: typing.Dict[str, MyDataclass],
t: NestedDataclass,
):
print(f"{a}, {b}, {c}, {d}, {e}, {f}, {g}, {h}, {i}, {j}, {k}, {l}, {m}, {n}, {o}, {p}, {q}")
print(f"{a}, {b}, {c}, {d}, {e}, {f}, {g}, {h}, {i}, {j}, {k}, {l}, {m}, {n}, {o}, {p}, {q}, {r}, {s}, {t}")


@task
Expand Down Expand Up @@ -93,14 +99,17 @@ def my_wf(
o: typing.Dict[str, typing.List[FlyteFile]],
p: typing.Any,
q: FlyteDirectory,
r: typing.List[MyDataclass],
s: typing.Dict[str, MyDataclass],
t: NestedDataclass,
remote: pd.DataFrame,
image: StructuredDataset,
m: dict = {"hello": "world"},
) -> Annotated[StructuredDataset, subset_cols]:
x = get_subset_df(df=remote) # noqa: shown for demonstration; users should use the same types between tasks
show_sd(in_sd=x)
show_sd(in_sd=image)
print_all(a=a, b=b, c=c, d=d, e=e, f=f, g=g, h=h, i=i, j=j, k=k, l=l, m=m, n=n, o=o, p=p, q=q)
print_all(a=a, b=b, c=c, d=d, e=e, f=f, g=g, h=h, i=i, j=j, k=k, l=l, m=m, n=n, o=o, p=p, q=q, r=r, s=s, t=t)
return x


Expand Down
Loading
Loading