Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow a hash method to be present for numpy arrays #2649

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 38 additions & 11 deletions flytekit/types/numpy/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,37 @@
from typing_extensions import Annotated, get_args, get_origin

from flytekit.core.context_manager import FlyteContext
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.core.hash import HashMethod
from flytekit.core.type_engine import (
TypeEngine,
TypeTransformer,
TypeTransformerFailedError,
)
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType


def extract_metadata(t: Type[np.ndarray]) -> Tuple[Type[np.ndarray], Dict[str, bool]]:
metadata = {}
metadata: dict = {}
metadata_set = False

if get_origin(t) is Annotated:
base_type, metadata = get_args(t)
if isinstance(metadata, OrderedDict):
return base_type, metadata
else:
raise TypeTransformerFailedError(f"{t}'s metadata needs to be of type kwtypes.")
base_type, *annotate_args = get_args(t)

Check warning on line 26 in flytekit/types/numpy/ndarray.py

View check run for this annotation

Codecov / codecov/patch

flytekit/types/numpy/ndarray.py#L26

Added line #L26 was not covered by tests

for aa in annotate_args:
if isinstance(aa, OrderedDict):
if metadata_set:
raise TypeTransformerFailedError(f"Metadata {metadata} is already specified, cannot use {aa}.")
metadata = aa
metadata_set = True

Check warning on line 33 in flytekit/types/numpy/ndarray.py

View check run for this annotation

Codecov / codecov/patch

flytekit/types/numpy/ndarray.py#L31-L33

Added lines #L31 - L33 were not covered by tests
elif isinstance(aa, HashMethod):
continue

Check warning on line 35 in flytekit/types/numpy/ndarray.py

View check run for this annotation

Codecov / codecov/patch

flytekit/types/numpy/ndarray.py#L35

Added line #L35 was not covered by tests
else:
raise TypeTransformerFailedError(f"The metadata for {t} must be of type kwtypes or HashMethod.")
return base_type, metadata

Check warning on line 38 in flytekit/types/numpy/ndarray.py

View check run for this annotation

Codecov / codecov/patch

flytekit/types/numpy/ndarray.py#L37-L38

Added lines #L37 - L38 were not covered by tests

# Return the type itself if no metadata was found.
return t, metadata


Expand All @@ -37,26 +54,36 @@
def get_literal_type(self, t: Type[np.ndarray]) -> LiteralType:
return LiteralType(
blob=_core_types.BlobType(
format=self.NUMPY_ARRAY_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
format=self.NUMPY_ARRAY_FORMAT,
dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
)
)

def to_literal(
self, ctx: FlyteContext, python_val: np.ndarray, python_type: Type[np.ndarray], expected: LiteralType
self,
ctx: FlyteContext,
python_val: np.ndarray,
python_type: Type[np.ndarray],
expected: LiteralType,
) -> Literal:
python_type, metadata = extract_metadata(python_type)

meta = BlobMetadata(
type=_core_types.BlobType(
format=self.NUMPY_ARRAY_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
format=self.NUMPY_ARRAY_FORMAT,
dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
)
)

local_path = ctx.file_access.get_random_local_path() + ".npy"
pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)

# save numpy array to file
np.save(file=local_path, arr=python_val, allow_pickle=metadata.get("allow_pickle", False))
np.save(
file=local_path,
arr=python_val,
allow_pickle=metadata.get("allow_pickle", False),
)
remote_path = ctx.file_access.put_raw_data(local_path)
return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))

Expand Down
46 changes: 41 additions & 5 deletions tests/flytekit/unit/types/numpy/test_ndarray.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import pytest
import numpy as np
from typing_extensions import Annotated

from flytekit import kwtypes, task, workflow
from flytekit import HashMethod, kwtypes, task, workflow
from flytekit.core.type_engine import TypeTransformerFailedError


@task
Expand Down Expand Up @@ -63,6 +65,35 @@ def t4(array: Annotated[np.ndarray, kwtypes(allow_pickle=True)]) -> int:
return array.size


def dummy_hash_array(arr: np.ndarray) -> str:
return "dummy"


@task
def t5_annotate_kwtypes_and_hash(
array: Annotated[
np.ndarray, kwtypes(allow_pickle=True), HashMethod(dummy_hash_array)
],
):
pass


@task
def t6_annotate_kwtypes_twice(
array: Annotated[
np.ndarray, kwtypes(allow_pickle=True), kwtypes(allow_pickle=False)
],
):
pass


@task
def t7_annotate_with_sth_strange(
array: Annotated[np.ndarray, (1, 2, 3)],
):
pass


@workflow
def wf():
array_1d = generate_numpy_1d()
Expand All @@ -72,10 +103,15 @@ def wf():
t2(array=array_2d)
t3(array=array_1d)
t4(array=array_dtype_object)
try:
generate_numpy_fails()
except Exception as e:
assert isinstance(e, TypeError)
t5_annotate_kwtypes_and_hash(array=array_1d)

if array_1d.is_ready:
with pytest.raises(TypeTransformerFailedError, match=r"Metadata OrderedDict.*'allow_pickle'.*True.* is already specified, cannot use OrderedDict.*'allow_pickle'.*False.*\."):
t6_annotate_kwtypes_twice(array=array_1d)
with pytest.raises(TypeTransformerFailedError, match=r"The metadata for typing.Annotated.*numpy\.ndarray.*1, 2, 3.* must be of type kwtypes or HashMethod\."):
t7_annotate_with_sth_strange(array=array_1d)
with pytest.raises(TypeError, match=r"The metadata for typing.Annotated.*numpy\.ndarray.*'allow_pickle'.*True.* must be of type kwtypes or HashMethod\."):
generate_numpy_fails()


@workflow
Expand Down
Loading