Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Zarr.save argument check #2446

Merged
merged 15 commits into from
Nov 5, 2024
14 changes: 13 additions & 1 deletion src/zarr/api/asynchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,8 @@ async def save_array(
_handle_zarr_version_or_format(zarr_version=zarr_version, zarr_format=zarr_format)
or _default_zarr_version()
)
if not isinstance(arr, np.ndarray):
brokkoli71 marked this conversation as resolved.
Show resolved Hide resolved
raise TypeError("arr argument must be numpy array")

mode = kwargs.pop("mode", None)
store_path = await make_store_path(store, path=path, mode=mode, storage_options=storage_options)
Expand Down Expand Up @@ -447,16 +449,26 @@ async def save_group(
or _default_zarr_version()
)

for arg in args:
if not isinstance(arg, np.ndarray):
raise TypeError(
"All arguments must be numpy arrays (except store, path, storage_options, and zarr_format)"
)
for k, v in kwargs.items():
if not isinstance(v, np.ndarray):
raise TypeError(f"Keyword argument '{k}' must be a numpy array")

if len(args) == 0 and len(kwargs) == 0:
raise ValueError("at least one array must be provided")
aws = []
for i, arr in enumerate(args):
_path = f"{path}/arr_{i}" if path is not None else f"arr_{i}"
aws.append(
save_array(
store,
arr,
zarr_format=zarr_format,
path=f"{path}/arr_{i}",
path=_path,
storage_options=storage_options,
)
)
Expand Down
31 changes: 31 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,33 @@ async def test_open_group_unspecified_version(
assert g2.metadata.zarr_format == zarr_format


@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"])
@pytest.mark.parametrize("n_args", [10, 1, 0])
@pytest.mark.parametrize("n_kwargs", [10, 1, 0])
def test_save(store: Store, n_args: int, n_kwargs: int) -> None:
data = np.arange(10)

if n_kwargs == 0 and n_args == 0:
with pytest.raises(ValueError):
save(store)
elif n_args == 1 and n_kwargs == 0:
save(store, data)
array = open(store)
assert isinstance(array, Array)
assert_array_equal(array, data)
else:
args = [np.arange(10) for _ in range(n_args)]
kwargs = {f"arg_{i}": data for i in range(n_kwargs)}
save(store, *args, **kwargs)
group = open(store)
assert isinstance(group, Group)
for array in group.array_values():
assert_array_equal(array[:], data)
for k in kwargs.keys():
assert k in group
assert group.nmembers() == n_args + n_kwargs


def test_save_errors() -> None:
with pytest.raises(ValueError):
# no arrays provided
Expand All @@ -142,6 +169,10 @@ def test_save_errors() -> None:
with pytest.raises(ValueError):
# no arrays provided
save("data/group.zarr")
with pytest.raises(TypeError):
# mode is no valid argument and would get handled as an array
a = np.arange(10)
zarr.save("data/example.zarr", a, mode="w")


def test_open_with_mode_r(tmp_path: pathlib.Path) -> None:
Expand Down