Skip to content

Commit

Permalink
Zarr.save argument check (#2446)
Browse files Browse the repository at this point in the history
* add argument check

* test

* test zarr.save with multiple arrays

* fix path for zarr.save args

* improve test readability

* fix typing

* support other array types

Co-authored-by: Tom Augspurger <tom.augspurger88@gmail.com>

* support NDArrayLike

* fix typing

* fix typing

* fix typing

* format

---------

Co-authored-by: Tom Augspurger <tom.augspurger88@gmail.com>
  • Loading branch information
brokkoli71 and TomAugspurger authored Nov 5, 2024
1 parent f092351 commit d1075de
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 2 deletions.
16 changes: 14 additions & 2 deletions src/zarr/api/asynchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from zarr.abc.store import Store
from zarr.core.array import Array, AsyncArray, get_array_metadata
from zarr.core.buffer import NDArrayLike
from zarr.core.common import (
JSON,
AccessModeLiteral,
Expand All @@ -31,7 +32,6 @@
from collections.abc import Iterable

from zarr.abc.codec import Codec
from zarr.core.buffer import NDArrayLike
from zarr.core.chunk_key_encodings import ChunkKeyEncoding

# TODO: this type could use some more thought
Expand Down Expand Up @@ -393,6 +393,8 @@ async def save_array(
_handle_zarr_version_or_format(zarr_version=zarr_version, zarr_format=zarr_format)
or _default_zarr_version()
)
if not isinstance(arr, NDArrayLike):
raise TypeError("arr argument must be numpy or other NDArrayLike array")

mode = kwargs.pop("mode", None)
store_path = await make_store_path(store, path=path, mode=mode, storage_options=storage_options)
Expand Down Expand Up @@ -447,16 +449,26 @@ async def save_group(
or _default_zarr_version()
)

for arg in args:
if not isinstance(arg, NDArrayLike):
raise TypeError(
"All arguments must be numpy or other NDArrayLike arrays (except store, path, storage_options, and zarr_format)"
)
for k, v in kwargs.items():
if not isinstance(v, NDArrayLike):
raise TypeError(f"Keyword argument '{k}' must be a numpy or other NDArrayLike array")

if len(args) == 0 and len(kwargs) == 0:
raise ValueError("at least one array must be provided")
aws = []
for i, arr in enumerate(args):
_path = f"{path}/arr_{i}" if path is not None else f"arr_{i}"
aws.append(
save_array(
store,
arr,
zarr_format=zarr_format,
path=f"{path}/arr_{i}",
path=_path,
storage_options=storage_options,
)
)
Expand Down
31 changes: 31 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,33 @@ async def test_open_group_unspecified_version(
assert g2.metadata.zarr_format == zarr_format


@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"])
@pytest.mark.parametrize("n_args", [10, 1, 0])
@pytest.mark.parametrize("n_kwargs", [10, 1, 0])
def test_save(store: Store, n_args: int, n_kwargs: int) -> None:
data = np.arange(10)
args = [np.arange(10) for _ in range(n_args)]
kwargs = {f"arg_{i}": data for i in range(n_kwargs)}

if n_kwargs == 0 and n_args == 0:
with pytest.raises(ValueError):
save(store)
elif n_args == 1 and n_kwargs == 0:
save(store, *args)
array = open(store)
assert isinstance(array, Array)
assert_array_equal(array[:], data)
else:
save(store, *args, **kwargs) # type: ignore[arg-type]
group = open(store)
assert isinstance(group, Group)
for array in group.array_values():
assert_array_equal(array[:], data)
for k in kwargs.keys():
assert k in group
assert group.nmembers() == n_args + n_kwargs


def test_save_errors() -> None:
with pytest.raises(ValueError):
# no arrays provided
Expand All @@ -142,6 +169,10 @@ def test_save_errors() -> None:
with pytest.raises(ValueError):
# no arrays provided
save("data/group.zarr")
with pytest.raises(TypeError):
# mode is no valid argument and would get handled as an array
a = np.arange(10)
zarr.save("data/example.zarr", a, mode="w")


def test_open_with_mode_r(tmp_path: pathlib.Path) -> None:
Expand Down

0 comments on commit d1075de

Please sign in to comment.