From ae1832d623499472afb1871bf9f73e2e3cd60502 Mon Sep 17 00:00:00 2001 From: Davis Vann Bennett Date: Sun, 22 Dec 2024 18:34:30 +0100 Subject: [PATCH] handle singleton compressor / filters input --- src/zarr/core/array.py | 17 +++++++++++++---- src/zarr/core/group.py | 4 ++-- tests/test_store/test_zip.py | 2 +- tests/test_v2.py | 2 +- 4 files changed, 17 insertions(+), 8 deletions(-) diff --git a/src/zarr/core/array.py b/src/zarr/core/array.py index d5e01c1acc..fd3886a603 100644 --- a/src/zarr/core/array.py +++ b/src/zarr/core/array.py @@ -3,7 +3,7 @@ import json import warnings from asyncio import gather -from collections.abc import Iterable +from collections.abc import Iterable, Mapping from dataclasses import dataclass, field from itertools import starmap from logging import getLogger @@ -3594,7 +3594,10 @@ async def create_array( config=config_parsed, ) else: - sub_codecs = _parse_chunk_encoding_v3(compression=compression, filters=filters, dtype=dtype) + array_array, array_bytes, bytes_bytes = _parse_chunk_encoding_v3( + compression=compression, filters=filters, dtype=dtype_parsed + ) + sub_codecs = (*array_array, array_bytes, *bytes_bytes) codecs_out: tuple[Codec, ...] if shard_shape_parsed is not None: sharding_codec = ShardingCodec(chunk_shape=chunk_shape_parsed, codecs=sub_codecs) @@ -3750,10 +3753,16 @@ def _parse_chunk_encoding_v3( if compression == "auto": out_bytes_bytes = default_bytes_bytes else: - out_bytes_bytes = tuple(compression) + if isinstance(compression, Mapping | Codec): + out_bytes_bytes = (compression,) + else: + out_bytes_bytes = tuple(compression) if filters == "auto": out_array_array = default_array_array else: - out_array_array = tuple(filters) + if isinstance(filters, Mapping | Codec): + out_array_array = (filters,) + else: + out_array_array = tuple(filters) return out_array_array, default_array_bytes, out_bytes_bytes diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index b8cc56c206..f3bc3f3eec 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -1064,8 +1064,8 @@ async def create_array( name=name, shape=shape, dtype=dtype, - chunk_shape=chunk_shape, - shard_shape=shard_shape, + chunks=chunk_shape, + shards=shard_shape, filters=filters, compression=compression, fill_value=fill_value, diff --git a/tests/test_store/test_zip.py b/tests/test_store/test_zip.py index df22b76e1e..c207adebe1 100644 --- a/tests/test_store/test_zip.py +++ b/tests/test_store/test_zip.py @@ -69,7 +69,7 @@ def test_api_integration(self, store: ZipStore) -> None: data = np.arange(10000, dtype=np.uint16).reshape(100, 100) z = root.create_array( - shape=data.shape, chunks=(10, 10), name="foo", dtype=np.uint16, fill_value=99 + shape=data.shape, chunk_shape=(10, 10), name="foo", dtype=np.uint16, fill_value=99 ) z[:] = data diff --git a/tests/test_v2.py b/tests/test_v2.py index 80897db8e5..e77edf56cc 100644 --- a/tests/test_v2.py +++ b/tests/test_v2.py @@ -88,7 +88,7 @@ async def test_v2_encode_decode(dtype): g.create_array( name="foo", shape=(3,), - chunks=(3,), + chunk_shape=(3,), dtype=dtype, fill_value=b"X", )