diff --git a/src/zarr/abc/codec.py b/src/zarr/abc/codec.py
index 0836d878ae..1f452159ed 100644
--- a/src/zarr/abc/codec.py
+++ b/src/zarr/abc/codec.py
@@ -13,11 +13,10 @@
 if TYPE_CHECKING:
     from typing_extensions import Self

-    from zarr.common import ArraySpec
+    from zarr.array_spec import ArraySpec
     from zarr.indexing import SelectorTuple
     from zarr.metadata import ArrayMetadata

-
 CodecInput = TypeVar("CodecInput", bound=NDBuffer | Buffer)
 CodecOutput = TypeVar("CodecOutput", bound=NDBuffer | Buffer)
diff --git a/src/zarr/abc/store.py b/src/zarr/abc/store.py
index e86fe5d07a..14566dfed2 100644
--- a/src/zarr/abc/store.py
+++ b/src/zarr/abc/store.py
@@ -2,7 +2,7 @@
 from collections.abc import AsyncGenerator
 from typing import Protocol, runtime_checkable

-from zarr.buffer import Buffer
+from zarr.buffer import Buffer, BufferPrototype
 from zarr.common import BytesLike, OpenMode

@@ -30,7 +30,10 @@ def _check_writable(self) -> None:

     @abstractmethod
     async def get(
-        self, key: str, byte_range: tuple[int | None, int | None] | None = None
+        self,
+        key: str,
+        prototype: BufferPrototype,
+        byte_range: tuple[int | None, int | None] | None = None,
     ) -> Buffer | None:
         """Retrieve the value associated with a given key.
@@ -47,7 +50,9 @@

     @abstractmethod
     async def get_partial_values(
-        self, key_ranges: list[tuple[str, tuple[int | None, int | None]]]
+        self,
+        prototype: BufferPrototype,
+        key_ranges: list[tuple[str, tuple[int | None, int | None]]],
     ) -> list[Buffer | None]:
         """Retrieve possibly partial values from given key_ranges.
@@ -175,12 +180,16 @@ def close(self) -> None:  # noqa: B027

 @runtime_checkable
 class ByteGetter(Protocol):
-    async def get(self, byte_range: tuple[int, int | None] | None = None) -> Buffer | None: ...
+    async def get(
+        self, prototype: BufferPrototype, byte_range: tuple[int, int | None] | None = None
+    ) -> Buffer | None: ...

 @runtime_checkable
 class ByteSetter(Protocol):
-    async def get(self, byte_range: tuple[int, int | None] | None = None) -> Buffer | None: ...
+    async def get(
+        self, prototype: BufferPrototype, byte_range: tuple[int, int | None] | None = None
+    ) -> Buffer | None: ...

     async def set(self, value: Buffer, byte_range: tuple[int, int] | None = None) -> None: ...
diff --git a/src/zarr/array.py b/src/zarr/array.py
index 698894ba0c..28b19f44f0 100644
--- a/src/zarr/array.py
+++ b/src/zarr/array.py
@@ -20,7 +20,7 @@
 from zarr.abc.codec import Codec
 from zarr.abc.store import set_or_delete
 from zarr.attributes import Attributes
-from zarr.buffer import Factory, NDArrayLike, NDBuffer
+from zarr.buffer import BufferPrototype, NDArrayLike, NDBuffer, default_buffer_prototype
 from zarr.chunk_grids import RegularChunkGrid
 from zarr.chunk_key_encodings import ChunkKeyEncoding, DefaultChunkKeyEncoding, V2ChunkKeyEncoding
 from zarr.codecs import BytesCodec
@@ -414,8 +414,8 @@ async def _get_selection(
         self,
         indexer: Indexer,
         *,
+        prototype: BufferPrototype,
         out: NDBuffer | None = None,
-        factory: Factory.Create = NDBuffer.create,
         fields: Fields | None = None,
     ) -> NDArrayLike:
         # check fields are sensible
@@ -432,7 +432,7 @@ async def _get_selection(
                     f"shape of out argument doesn't match. Expected {indexer.shape}, got {out.shape}"
                 )
         else:
-            out_buffer = factory(
+            out_buffer = prototype.nd_buffer.create(
                 shape=indexer.shape,
                 dtype=out_dtype,
                 order=self.order,
@@ -444,7 +444,7 @@ async def _get_selection(
             [
                 (
                     self.store_path / self.metadata.encode_chunk_key(chunk_coords),
-                    self.metadata.get_chunk_spec(chunk_coords, self.order),
+                    self.metadata.get_chunk_spec(chunk_coords, self.order, prototype=prototype),
                     chunk_selection,
                     out_selection,
                 )
@@ -456,14 +456,14 @@ async def _get_selection(
         return out_buffer.as_ndarray_like()

     async def getitem(
-        self, selection: Selection, *, factory: Factory.Create = NDBuffer.create
+        self, selection: Selection, *, prototype: BufferPrototype = default_buffer_prototype
     ) -> NDArrayLike:
         indexer = BasicIndexer(
             selection,
             shape=self.metadata.shape,
             chunk_grid=self.metadata.chunk_grid,
         )
-        return await self._get_selection(indexer, factory=factory)
+        return await self._get_selection(indexer, prototype=prototype)

     async def _save_metadata(self, metadata: ArrayMetadata) -> None:
         to_save = metadata.to_buffer_dict()
@@ -475,7 +475,7 @@ async def _set_selection(
         indexer: Indexer,
         value: NDArrayLike,
         *,
-        factory: Factory.NDArrayLike = NDBuffer.from_ndarray_like,
+        prototype: BufferPrototype,
         fields: Fields | None = None,
     ) -> None:
         # check fields are sensible
@@ -497,14 +497,14 @@ async def _set_selection(
         # We accept any ndarray like object from the user and convert it
         # to a NDBuffer (or subclass). From this point onwards, we only pass
         # Buffer and NDBuffer between components.
-        value_buffer = factory(value)
+        value_buffer = prototype.nd_buffer.from_ndarray_like(value)

         # merging with existing data and encoding chunks
         await self.metadata.codec_pipeline.write(
             [
                 (
                     self.store_path / self.metadata.encode_chunk_key(chunk_coords),
-                    self.metadata.get_chunk_spec(chunk_coords, self.order),
+                    self.metadata.get_chunk_spec(chunk_coords, self.order, prototype),
                     chunk_selection,
                     out_selection,
                 )
@@ -518,14 +518,14 @@ async def setitem(
         self,
         selection: Selection,
         value: NDArrayLike,
-        factory: Factory.NDArrayLike = NDBuffer.from_ndarray_like,
+        prototype: BufferPrototype = default_buffer_prototype,
     ) -> None:
         indexer = BasicIndexer(
             selection,
             shape=self.metadata.shape,
             chunk_grid=self.metadata.chunk_grid,
         )
-        return await self._set_selection(indexer, value, factory=factory)
+        return await self._set_selection(indexer, value, prototype=prototype)

     async def resize(
         self, new_shape: ChunkCoords, delete_outside_chunks: bool = True
@@ -714,7 +714,9 @@ def __setitem__(self, selection: Selection, value: NDArrayLike) -> None:
     def get_basic_selection(
         self,
         selection: BasicSelection = Ellipsis,
+        *,
         out: NDBuffer | None = None,
+        prototype: BufferPrototype = default_buffer_prototype,
         fields: Fields | None = None,
     ) -> NDArrayLike:
         if self.shape == ():
@@ -725,57 +727,101 @@ def get_basic_selection(
                 BasicIndexer(selection, self.shape, self.metadata.chunk_grid),
                 out=out,
                 fields=fields,
+                prototype=prototype,
             )
         )

     def set_basic_selection(
-        self, selection: BasicSelection, value: NDArrayLike, fields: Fields | None = None
+        self,
+        selection: BasicSelection,
+        value: NDArrayLike,
+        *,
+        fields: Fields | None = None,
+        prototype: BufferPrototype = default_buffer_prototype,
     ) -> None:
         indexer = BasicIndexer(selection, self.shape, self.metadata.chunk_grid)
-        sync(self._async_array._set_selection(indexer, value, fields=fields))
+        sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype))

     def get_orthogonal_selection(
         self,
         selection: OrthogonalSelection,
+        *,
         out: NDBuffer | None = None,
         fields: Fields | None = None,
+        prototype: BufferPrototype = default_buffer_prototype,
     ) -> NDArrayLike:
         indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid)
-        return sync(self._async_array._get_selection(indexer=indexer, out=out, fields=fields))
+        return sync(
+            self._async_array._get_selection(
+                indexer=indexer, out=out, fields=fields, prototype=prototype
+            )
+        )

     def set_orthogonal_selection(
-        self, selection: OrthogonalSelection, value: NDArrayLike, fields: Fields | None = None
+        self,
+        selection: OrthogonalSelection,
+        value: NDArrayLike,
+        *,
+        fields: Fields | None = None,
+        prototype: BufferPrototype = default_buffer_prototype,
     ) -> None:
         indexer = OrthogonalIndexer(selection, self.shape, self.metadata.chunk_grid)
-        return sync(self._async_array._set_selection(indexer, value, fields=fields))
+        return sync(
+            self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype)
+        )

     def get_mask_selection(
-        self, mask: MaskSelection, out: NDBuffer | None = None, fields: Fields | None = None
+        self,
+        mask: MaskSelection,
+        *,
+        out: NDBuffer | None = None,
+        fields: Fields | None = None,
+        prototype: BufferPrototype = default_buffer_prototype,
     ) -> NDArrayLike:
         indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid)
-        return sync(self._async_array._get_selection(indexer=indexer, out=out, fields=fields))
+        return sync(
+            self._async_array._get_selection(
+                indexer=indexer, out=out, fields=fields, prototype=prototype
+            )
+        )

     def set_mask_selection(
-        self, mask: MaskSelection, value: NDArrayLike, fields: Fields | None = None
+        self,
+        mask: MaskSelection,
+        value: NDArrayLike,
+        *,
+        fields: Fields | None = None,
+        prototype: BufferPrototype = default_buffer_prototype,
     ) -> None:
         indexer = MaskIndexer(mask, self.shape, self.metadata.chunk_grid)
-        sync(self._async_array._set_selection(indexer, value, fields=fields))
+        sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype))

     def get_coordinate_selection(
         self,
         selection: CoordinateSelection,
+        *,
         out: NDBuffer | None = None,
         fields: Fields | None = None,
+        prototype: BufferPrototype = default_buffer_prototype,
     ) -> NDArrayLike:
         indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid)
-        out_array = sync(self._async_array._get_selection(indexer=indexer, out=out, fields=fields))
+        out_array = sync(
+            self._async_array._get_selection(
+                indexer=indexer, out=out, fields=fields, prototype=prototype
+            )
+        )

         # restore shape
         out_array = out_array.reshape(indexer.sel_shape)
         return out_array

     def set_coordinate_selection(
-        self, selection: CoordinateSelection, value: NDArrayLike, fields: Fields | None = None
+        self,
+        selection: CoordinateSelection,
+        value: NDArrayLike,
+        *,
+        fields: Fields | None = None,
+        prototype: BufferPrototype = default_buffer_prototype,
     ) -> None:
         # setup indexer
         indexer = CoordinateIndexer(selection, self.shape, self.metadata.chunk_grid)
@@ -792,25 +838,33 @@ def set_coordinate_selection(
         if hasattr(value, "shape") and len(value.shape) > 1:
             value = value.reshape(-1)

-        sync(self._async_array._set_selection(indexer, value, fields=fields))
+        sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype))

     def get_block_selection(
         self,
         selection: BlockSelection,
+        *,
         out: NDBuffer | None = None,
         fields: Fields | None = None,
+        prototype: BufferPrototype = default_buffer_prototype,
     ) -> NDArrayLike:
         indexer = BlockIndexer(selection, self.shape, self.metadata.chunk_grid)
-        return sync(self._async_array._get_selection(indexer=indexer, out=out, fields=fields))
+        return sync(
+            self._async_array._get_selection(
+                indexer=indexer, out=out, fields=fields, prototype=prototype
+            )
+        )

     def set_block_selection(
         self,
         selection: BlockSelection,
         value: NDArrayLike,
+        *,
         fields: Fields | None = None,
+        prototype: BufferPrototype = default_buffer_prototype,
     ) -> None:
         indexer = BlockIndexer(selection, self.shape, self.metadata.chunk_grid)
-        sync(self._async_array._set_selection(indexer, value, fields=fields))
+        sync(self._async_array._set_selection(indexer, value, fields=fields, prototype=prototype))

     @property
     def vindex(self) -> VIndex:
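Reviewer note: the `factory=` keyword on `getitem`/`setitem` is gone; callers now hand in a `BufferPrototype`. A minimal migration sketch, assuming hypothetical `MyBuffer`/`MyNDBuffer` subclasses (stand-ins for any custom memory type, not part of this diff) and an existing `AsyncArray` named `arr`:

```python
import numpy as np

from zarr.array import AsyncArray
from zarr.buffer import Buffer, BufferPrototype, NDBuffer


class MyBuffer(Buffer):
    """Hypothetical Buffer subclass (e.g. pinned or device memory)."""


class MyNDBuffer(NDBuffer):
    """Hypothetical NDBuffer subclass."""


my_prototype = BufferPrototype(buffer=MyBuffer, nd_buffer=MyNDBuffer)


async def roundtrip(arr: AsyncArray) -> None:
    # before this PR: await arr.setitem(..., factory=MyNDBuffer.from_ndarray_like)
    await arr.setitem((slice(0, 3), slice(0, 3)), np.ones((3, 3)), prototype=my_prototype)
    # before this PR: await arr.getitem(..., factory=MyNDBuffer.create)
    got = await arr.getitem((slice(0, 3), slice(0, 3)), prototype=my_prototype)
    assert got.shape == (3, 3)
```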
diff --git a/src/zarr/array_spec.py b/src/zarr/array_spec.py
new file mode 100644
index 0000000000..d5717944b4
--- /dev/null
+++ b/src/zarr/array_spec.py
@@ -0,0 +1,41 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, Literal
+
+import numpy as np
+
+from zarr.buffer import BufferPrototype
+from zarr.common import ChunkCoords, parse_dtype, parse_fill_value, parse_order, parse_shapelike
+
+
+@dataclass(frozen=True)
+class ArraySpec:
+    shape: ChunkCoords
+    dtype: np.dtype[Any]
+    fill_value: Any
+    order: Literal["C", "F"]
+    prototype: BufferPrototype
+
+    def __init__(
+        self,
+        shape: ChunkCoords,
+        dtype: np.dtype[Any],
+        fill_value: Any,
+        order: Literal["C", "F"],
+        prototype: BufferPrototype,
+    ) -> None:
+        shape_parsed = parse_shapelike(shape)
+        dtype_parsed = parse_dtype(dtype)
+        fill_value_parsed = parse_fill_value(fill_value)
+        order_parsed = parse_order(order)
+
+        object.__setattr__(self, "shape", shape_parsed)
+        object.__setattr__(self, "dtype", dtype_parsed)
+        object.__setattr__(self, "fill_value", fill_value_parsed)
+        object.__setattr__(self, "order", order_parsed)
+        object.__setattr__(self, "prototype", prototype)
+
+    @property
+    def ndim(self) -> int:
+        return len(self.shape)
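Reviewer note: `ArraySpec` moved out of `zarr.common` into its own module and grew a `prototype` field, so every chunk spec now records which buffer classes downstream code should allocate. A quick sanity sketch against the new module:

```python
import numpy as np

from zarr.array_spec import ArraySpec
from zarr.buffer import default_buffer_prototype

spec = ArraySpec(
    shape=(5, 5),
    dtype=np.dtype("uint16"),
    fill_value=0,
    order="C",
    prototype=default_buffer_prototype,
)
assert spec.ndim == 2
# Codecs read the prototype back off the spec when allocating output buffers.
assert spec.prototype.nd_buffer is default_buffer_prototype.nd_buffer
```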
""" return self._data def as_numpy_array(self) -> npt.NDArray[Any]: - """Return the buffer as a NumPy array (host memory). + """Returns the buffer as a NumPy array (host memory). Warning ------- Might have to copy data, consider using `.as_ndarray_like()` instead. - Return - ------ + Returns + ------- NumPy array of this buffer (might be a data copy) """ return np.asanyarray(self._data) @@ -457,7 +408,9 @@ def transpose(self, axes: SupportsIndex | Sequence[SupportsIndex] | None) -> Sel return self.__class__(self._data.transpose(axes)) -def as_numpy_array_wrapper(func: Callable[[npt.NDArray[Any]], bytes], buf: Buffer) -> Buffer: +def as_numpy_array_wrapper( + func: Callable[[npt.NDArray[Any]], bytes], buf: Buffer, prototype: BufferPrototype +) -> Buffer: """Converts the input of `func` to a numpy array and the output back to `Buffer`. This function is useful when calling a `func` that only support host memory such @@ -473,9 +426,32 @@ def as_numpy_array_wrapper(func: Callable[[npt.NDArray[Any]], bytes], buf: Buffe buf The buffer that will be converted to a Numpy array before given as input to `func`. + prototype + The prototype of the output buffer. + + Returns + ------- + The result of `func` converted to a `prototype.buffer` + """ + return prototype.buffer.from_bytes(func(buf.as_numpy_array())) + + +class BufferPrototype(NamedTuple): + """Prototype of the Buffer and NDBuffer class + + The protocol must be pickable. - Return - ------ - The result of `func` converted to a `Buffer` + Attributes + ---------- + buffer + The Buffer class to use when Zarr needs to create new Buffer. + nd_buffer + The NDBuffer class to use when Zarr needs to create new NDBuffer. """ - return Buffer.from_bytes(func(buf.as_numpy_array())) + + buffer: type[Buffer] + nd_buffer: type[NDBuffer] + + +# The default buffer prototype used throughout the Zarr codebase. 
diff --git a/src/zarr/codecs/_v2.py b/src/zarr/codecs/_v2.py
index c4e4756094..c43a087a94 100644
--- a/src/zarr/codecs/_v2.py
+++ b/src/zarr/codecs/_v2.py
@@ -6,8 +6,9 @@
 from numcodecs.compat import ensure_bytes, ensure_ndarray

 from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec
+from zarr.array_spec import ArraySpec
 from zarr.buffer import Buffer, NDBuffer
-from zarr.common import JSON, ArraySpec, to_thread
+from zarr.common import JSON, to_thread

 @dataclass(frozen=True)
diff --git a/src/zarr/codecs/blosc.py b/src/zarr/codecs/blosc.py
index acba698d94..e577d18fb2 100644
--- a/src/zarr/codecs/blosc.py
+++ b/src/zarr/codecs/blosc.py
@@ -9,15 +9,14 @@
 from numcodecs.blosc import Blosc

 from zarr.abc.codec import BytesBytesCodec
+from zarr.array_spec import ArraySpec
 from zarr.buffer import Buffer, as_numpy_array_wrapper
 from zarr.codecs.registry import register_codec
-from zarr.common import parse_enum, parse_named_configuration, to_thread
+from zarr.common import JSON, parse_enum, parse_named_configuration, to_thread

 if TYPE_CHECKING:
     from typing_extensions import Self

-    from zarr.common import JSON, ArraySpec
-

 class BloscShuffle(Enum):
     noshuffle = "noshuffle"
@@ -161,19 +160,23 @@ def _blosc_codec(self) -> Blosc:
     async def _decode_single(
         self,
         chunk_bytes: Buffer,
-        _chunk_spec: ArraySpec,
+        chunk_spec: ArraySpec,
     ) -> Buffer:
-        return await to_thread(as_numpy_array_wrapper, self._blosc_codec.decode, chunk_bytes)
+        return await to_thread(
+            as_numpy_array_wrapper, self._blosc_codec.decode, chunk_bytes, chunk_spec.prototype
+        )

     async def _encode_single(
         self,
         chunk_bytes: Buffer,
         chunk_spec: ArraySpec,
     ) -> Buffer | None:
-        # Since blosc only takes bytes, we convert the input and output of the encoding
-        # between bytes and Buffer
+        # Since blosc only supports host memory, we convert the input and output of the encoding
+        # between numpy array and buffer
         return await to_thread(
-            lambda chunk: Buffer.from_bytes(self._blosc_codec.encode(chunk.as_array_like())),
+            lambda chunk: chunk_spec.prototype.buffer.from_bytes(
+                self._blosc_codec.encode(chunk.as_numpy_array())
+            ),
             chunk_bytes,
         )
diff --git a/src/zarr/codecs/bytes.py b/src/zarr/codecs/bytes.py
index f275ae37d1..0b9a5c089e 100644
--- a/src/zarr/codecs/bytes.py
+++ b/src/zarr/codecs/bytes.py
@@ -8,15 +8,14 @@
 import numpy as np

 from zarr.abc.codec import ArrayBytesCodec
+from zarr.array_spec import ArraySpec
 from zarr.buffer import Buffer, NDArrayLike, NDBuffer
 from zarr.codecs.registry import register_codec
-from zarr.common import parse_enum, parse_named_configuration
+from zarr.common import JSON, parse_enum, parse_named_configuration

 if TYPE_CHECKING:
     from typing_extensions import Self

-    from zarr.common import JSON, ArraySpec
-

 class Endian(Enum):
     big = "big"
@@ -81,7 +80,9 @@ async def _decode_single(
             as_nd_array_like = as_array_like
         else:
             as_nd_array_like = np.asanyarray(as_array_like)
-        chunk_array = NDBuffer.from_ndarray_like(as_nd_array_like.view(dtype=dtype))
+        chunk_array = chunk_spec.prototype.nd_buffer.from_ndarray_like(
+            as_nd_array_like.view(dtype=dtype)
+        )

         # ensure correct chunk shape
         if chunk_array.shape != chunk_spec.shape:
@@ -93,7 +94,7 @@ async def _decode_single(
     async def _encode_single(
         self,
         chunk_array: NDBuffer,
-        _chunk_spec: ArraySpec,
+        chunk_spec: ArraySpec,
     ) -> Buffer | None:
         assert isinstance(chunk_array, NDBuffer)
         if chunk_array.dtype.itemsize > 1:
@@ -103,10 +104,10 @@ async def _encode_single(
                 new_dtype = chunk_array.dtype.newbyteorder(self.endian.name)  # type: ignore[arg-type]
                 chunk_array = chunk_array.astype(new_dtype)

-        as_nd_array_like = chunk_array.as_ndarray_like()
-        # Flatten the nd-array (only copy if needed)
-        as_nd_array_like = as_nd_array_like.ravel().view(dtype="b")
-        return Buffer.from_array_like(as_nd_array_like)
+        nd_array = chunk_array.as_ndarray_like()
+        # Flatten the nd-array (only copy if needed) and reinterpret as bytes
+        nd_array = nd_array.ravel().view(dtype="b")
+        return chunk_spec.prototype.buffer.from_array_like(nd_array)

     def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int:
         return input_byte_length
diff --git a/src/zarr/codecs/crc32c_.py b/src/zarr/codecs/crc32c_.py
index 724b785d67..b670b25429 100644
--- a/src/zarr/codecs/crc32c_.py
+++ b/src/zarr/codecs/crc32c_.py
@@ -7,15 +7,14 @@
 from crc32c import crc32c

 from zarr.abc.codec import BytesBytesCodec
+from zarr.array_spec import ArraySpec
 from zarr.buffer import Buffer
 from zarr.codecs.registry import register_codec
-from zarr.common import parse_named_configuration
+from zarr.common import JSON, parse_named_configuration

 if TYPE_CHECKING:
     from typing_extensions import Self

-    from zarr.common import JSON, ArraySpec
-

 @dataclass(frozen=True)
 class Crc32cCodec(BytesBytesCodec):
@@ -32,7 +31,7 @@ def to_dict(self) -> dict[str, JSON]:
     async def _decode_single(
         self,
         chunk_bytes: Buffer,
-        _chunk_spec: ArraySpec,
+        chunk_spec: ArraySpec,
     ) -> Buffer:
         data = chunk_bytes.as_numpy_array()
         crc32_bytes = data[-4:]
@@ -44,18 +43,18 @@ async def _decode_single(
             raise ValueError(
                 f"Stored and computed checksum do not match. Stored: {stored_checksum!r}. Computed: {computed_checksum!r}."
             )
-        return Buffer.from_array_like(inner_bytes)
+        return chunk_spec.prototype.buffer.from_array_like(inner_bytes)

     async def _encode_single(
         self,
         chunk_bytes: Buffer,
-        _chunk_spec: ArraySpec,
+        chunk_spec: ArraySpec,
     ) -> Buffer | None:
         data = chunk_bytes.as_numpy_array()
         # Calculate the checksum and "cast" it to a numpy array
         checksum = np.array([crc32c(data)], dtype=np.uint32)
         # Append the checksum (as bytes) to the data
-        return Buffer.from_array_like(np.append(data, checksum.view("b")))
+        return chunk_spec.prototype.buffer.from_array_like(np.append(data, checksum.view("b")))

     def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int:
         return input_byte_length + 4
diff --git a/src/zarr/codecs/gzip.py b/src/zarr/codecs/gzip.py
index 6a8aaf08bb..0ad97c1207 100644
--- a/src/zarr/codecs/gzip.py
+++ b/src/zarr/codecs/gzip.py
@@ -6,15 +6,14 @@
 from numcodecs.gzip import GZip

 from zarr.abc.codec import BytesBytesCodec
+from zarr.array_spec import ArraySpec
 from zarr.buffer import Buffer, as_numpy_array_wrapper
 from zarr.codecs.registry import register_codec
-from zarr.common import parse_named_configuration, to_thread
+from zarr.common import JSON, parse_named_configuration, to_thread

 if TYPE_CHECKING:
     from typing_extensions import Self

-    from zarr.common import JSON, ArraySpec
-

 def parse_gzip_level(data: JSON) -> int:
     if not isinstance(data, (int)):
@@ -48,16 +47,20 @@ def to_dict(self) -> dict[str, JSON]:
     async def _decode_single(
         self,
         chunk_bytes: Buffer,
-        _chunk_spec: ArraySpec,
+        chunk_spec: ArraySpec,
     ) -> Buffer:
-        return await to_thread(as_numpy_array_wrapper, GZip(self.level).decode, chunk_bytes)
+        return await to_thread(
+            as_numpy_array_wrapper, GZip(self.level).decode, chunk_bytes, chunk_spec.prototype
+        )

     async def _encode_single(
         self,
         chunk_bytes: Buffer,
-        _chunk_spec: ArraySpec,
+        chunk_spec: ArraySpec,
     ) -> Buffer | None:
-        return await to_thread(as_numpy_array_wrapper, GZip(self.level).encode, chunk_bytes)
+        return await to_thread(
+            as_numpy_array_wrapper, GZip(self.level).encode, chunk_bytes, chunk_spec.prototype
+        )

     def compute_encoded_size(
         self,
diff --git a/src/zarr/codecs/pipeline.py b/src/zarr/codecs/pipeline.py
index ada4ae23f9..acef311a8c 100644
--- a/src/zarr/codecs/pipeline.py
+++ b/src/zarr/codecs/pipeline.py
@@ -16,7 +16,7 @@
     CodecPipeline,
 )
 from zarr.abc.store import ByteGetter, ByteSetter
-from zarr.buffer import Buffer, NDBuffer
+from zarr.buffer import Buffer, BufferPrototype, NDBuffer
 from zarr.codecs.registry import get_codec_class
 from zarr.common import JSON, concurrent_map, parse_named_configuration
 from zarr.config import config
@@ -26,7 +26,7 @@
 if TYPE_CHECKING:
     from typing_extensions import Self

-    from zarr.common import ArraySpec
+    from zarr.array_spec import ArraySpec

 T = TypeVar("T")
 U = TypeVar("U")
@@ -310,8 +310,11 @@ async def read_batch(
                     out[out_selection] = chunk_spec.fill_value
         else:
             chunk_bytes_batch = await concurrent_map(
-                [(byte_getter,) for byte_getter, _, _, _ in batch_info],
-                lambda byte_getter: byte_getter.get(),
+                [
+                    (byte_getter, array_spec.prototype)
+                    for byte_getter, array_spec, _, _ in batch_info
+                ],
+                lambda byte_getter, prototype: byte_getter.get(prototype),
                 config.get("async.concurrency"),
             )
             chunk_array_batch = await self.decode_batch(
@@ -345,7 +348,7 @@ def _merge_chunk_array(
         if is_total_slice(chunk_selection, chunk_spec.shape) and value.shape == chunk_spec.shape:
             return value
         if existing_chunk_array is None:
-            chunk_array = NDBuffer.create(
+            chunk_array = chunk_spec.prototype.nd_buffer.create(
                 shape=chunk_spec.shape,
                 dtype=chunk_spec.dtype,
                 order=chunk_spec.order,
@@ -387,15 +390,20 @@ async def write_batch(
         else:
             # Read existing bytes if not total slice
-            async def _read_key(byte_setter: ByteSetter | None) -> Buffer | None:
+            async def _read_key(
+                byte_setter: ByteSetter | None, prototype: BufferPrototype
+            ) -> Buffer | None:
                 if byte_setter is None:
                     return None
-                return await byte_setter.get()
+                return await byte_setter.get(prototype=prototype)

             chunk_bytes_batch: Iterable[Buffer | None]
             chunk_bytes_batch = await concurrent_map(
                 [
-                    (None if is_total_slice(chunk_selection, chunk_spec.shape) else byte_setter,)
+                    (
+                        None if is_total_slice(chunk_selection, chunk_spec.shape) else byte_setter,
+                        chunk_spec.prototype,
+                    )
                     for byte_setter, chunk_spec, chunk_selection, _ in batch_info
                 ],
                 _read_key,
diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py
index dab2810f35..74ad5ac44f 100644
--- a/src/zarr/codecs/sharding.py
+++ b/src/zarr/codecs/sharding.py
@@ -18,14 +18,14 @@
     CodecPipeline,
 )
 from zarr.abc.store import ByteGetter, ByteSetter
-from zarr.buffer import Buffer, NDBuffer
+from zarr.array_spec import ArraySpec
+from zarr.buffer import Buffer, BufferPrototype, NDBuffer, default_buffer_prototype
 from zarr.chunk_grids import RegularChunkGrid
 from zarr.codecs.bytes import BytesCodec
 from zarr.codecs.crc32c_ import Crc32cCodec
 from zarr.codecs.pipeline import BatchedCodecPipeline
 from zarr.codecs.registry import register_codec
 from zarr.common import (
-    ArraySpec,
     ChunkCoords,
     ChunkCoordsLike,
     parse_enum,
@@ -62,8 +62,13 @@ class _ShardingByteGetter(ByteGetter):
     shard_dict: ShardMapping
     chunk_coords: ChunkCoords

-    async def get(self, byte_range: tuple[int, int | None] | None = None) -> Buffer | None:
+    async def get(
+        self, prototype: BufferPrototype, byte_range: tuple[int, int | None] | None = None
+    ) -> Buffer | None:
         assert byte_range is None, "byte_range is not supported within shards"
+        assert (
+            prototype is default_buffer_prototype
+        ), "prototype is not supported within shards currently"
         return self.shard_dict.get(self.chunk_coords)

@@ -391,7 +396,7 @@ async def _decode_single(
         )

         # setup output array
-        out = NDBuffer.create(
+        out = chunk_spec.prototype.nd_buffer.create(
             shape=shard_shape, dtype=shard_spec.dtype, order=shard_spec.order, fill_value=0
         )
         shard_dict = await _ShardReader.from_bytes(shard_bytes, self, chunks_per_shard)
@@ -434,7 +439,7 @@ async def _decode_partial_single(
         )

         # setup output array
-        out = NDBuffer.create(
+        out = shard_spec.prototype.nd_buffer.create(
             shape=indexer.shape, dtype=shard_spec.dtype, order=shard_spec.order, fill_value=0
         )

@@ -445,7 +450,11 @@ async def _decode_partial_single(
         shard_dict: ShardMapping = {}
         if self._is_total_shard(all_chunk_coords, chunks_per_shard):
             # read entire shard
-            shard_dict_maybe = await self._load_full_shard_maybe(byte_getter, chunks_per_shard)
+            shard_dict_maybe = await self._load_full_shard_maybe(
+                byte_getter=byte_getter,
+                prototype=chunk_spec.prototype,
+                chunks_per_shard=chunks_per_shard,
+            )
             if shard_dict_maybe is None:
                 return None
             shard_dict = shard_dict_maybe
@@ -458,7 +467,9 @@ async def _decode_partial_single(
             for chunk_coords in all_chunk_coords:
                 chunk_byte_slice = shard_index.get_chunk_slice(chunk_coords)
                 if chunk_byte_slice:
-                    chunk_bytes = await byte_getter.get(chunk_byte_slice)
+                    chunk_bytes = await byte_getter.get(
+                        prototype=chunk_spec.prototype, byte_range=chunk_byte_slice
+                    )
                     if chunk_bytes:
                         shard_dict[chunk_coords] = chunk_bytes

@@ -525,7 +536,11 @@ async def _encode_partial_single(
         chunk_spec = self._get_chunk_spec(shard_spec)

         shard_dict = _MergingShardBuilder(
-            await self._load_full_shard_maybe(byte_setter, chunks_per_shard)
+            await self._load_full_shard_maybe(
+                byte_getter=byte_setter,
+                prototype=chunk_spec.prototype,
+                chunks_per_shard=chunks_per_shard,
+            )
             or _ShardReader.create_empty(chunks_per_shard),
             _ShardBuilder.create_empty(chunks_per_shard),
         )
@@ -607,6 +622,7 @@ def _get_index_chunk_spec(self, chunks_per_shard: ChunkCoords) -> ArraySpec:
             dtype=np.dtype("<u8"),
             fill_value=MAX_UINT_64,
             order="C",
+            prototype=default_buffer_prototype,
         )

     def _get_chunk_spec(self, shard_spec: ArraySpec) -> ArraySpec:
@@ -615,6 +631,7 @@ def _get_chunk_spec(self, shard_spec: ArraySpec) -> ArraySpec:
             dtype=shard_spec.dtype,
             fill_value=shard_spec.fill_value,
             order=shard_spec.order,
+            prototype=shard_spec.prototype,
         )

     def _get_chunks_per_shard(self, shard_spec: ArraySpec) -> ChunkCoords:
@@ -632,9 +649,13 @@ async def _load_shard_index_maybe(
     ) -> _ShardIndex | None:
         shard_index_size = self._shard_index_size(chunks_per_shard)
         if self.index_location == ShardingCodecIndexLocation.start:
-            index_bytes = await byte_getter.get((0, shard_index_size))
+            index_bytes = await byte_getter.get(
+                prototype=default_buffer_prototype, byte_range=(0, shard_index_size)
+            )
         else:
-            index_bytes = await byte_getter.get((-shard_index_size, None))
+            index_bytes = await byte_getter.get(
+                prototype=default_buffer_prototype, byte_range=(-shard_index_size, None)
+            )
         if index_bytes is not None:
             return await self._decode_shard_index(index_bytes, chunks_per_shard)
         return None
@@ -647,9 +668,9 @@ async def _load_shard_index(
         ) or _ShardIndex.create_empty(chunks_per_shard)

     async def _load_full_shard_maybe(
-        self, byte_getter: ByteGetter, chunks_per_shard: ChunkCoords
+        self, byte_getter: ByteGetter, prototype: BufferPrototype, chunks_per_shard: ChunkCoords
     ) -> _ShardReader | None:
-        shard_bytes = await byte_getter.get()
+        shard_bytes = await byte_getter.get(prototype=prototype)
         return (
             await _ShardReader.from_bytes(shard_bytes, self, chunks_per_shard)
diff --git a/src/zarr/codecs/transpose.py b/src/zarr/codecs/transpose.py
index 9fcee4e66b..33dab21fb6 100644
--- a/src/zarr/codecs/transpose.py
+++ b/src/zarr/codecs/transpose.py
@@ -7,9 +7,10 @@
 import numpy as np

 from zarr.abc.codec import ArrayArrayCodec
+from zarr.array_spec import ArraySpec
 from zarr.buffer import NDBuffer
 from zarr.codecs.registry import register_codec
-from zarr.common import JSON, ArraySpec, ChunkCoordsLike, parse_named_configuration
+from zarr.common import JSON, ChunkCoordsLike, parse_named_configuration

 if TYPE_CHECKING:
     from typing import TYPE_CHECKING
@@ -64,13 +65,12 @@ def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
         return self

     def resolve_metadata(self, chunk_spec: ArraySpec) -> ArraySpec:
-        from zarr.common import ArraySpec
-
         return ArraySpec(
             shape=tuple(chunk_spec.shape[self.order[i]] for i in range(chunk_spec.ndim)),
             dtype=chunk_spec.dtype,
             fill_value=chunk_spec.fill_value,
             order=chunk_spec.order,
+            prototype=chunk_spec.prototype,
         )

     async def _decode_single(
@@ -85,7 +85,7 @@ async def _decode_single(
     async def _encode_single(
         self,
         chunk_array: NDBuffer,
-        chunk_spec: ArraySpec,
+        _chunk_spec: ArraySpec,
     ) -> NDBuffer | None:
         chunk_array = chunk_array.transpose(self.order)
         return chunk_array
diff --git a/src/zarr/codecs/zstd.py b/src/zarr/codecs/zstd.py
index 451fae8b37..4c5afba00b 100644
--- a/src/zarr/codecs/zstd.py
+++ b/src/zarr/codecs/zstd.py
@@ -7,15 +7,14 @@
 from zstandard import ZstdCompressor, ZstdDecompressor

 from zarr.abc.codec import BytesBytesCodec
+from zarr.array_spec import ArraySpec
 from zarr.buffer import Buffer, as_numpy_array_wrapper
 from zarr.codecs.registry import register_codec
-from zarr.common import parse_named_configuration, to_thread
+from zarr.common import JSON, parse_named_configuration, to_thread

 if TYPE_CHECKING:
     from typing_extensions import Self

-    from zarr.common import JSON, ArraySpec
-

 def parse_zstd_level(data: JSON) -> int:
     if isinstance(data, int):
@@ -64,16 +63,20 @@ def _decompress(self, data: npt.NDArray[Any]) -> bytes:
     async def _decode_single(
         self,
         chunk_bytes: Buffer,
-        _chunk_spec: ArraySpec,
+        chunk_spec: ArraySpec,
     ) -> Buffer:
-        return await to_thread(as_numpy_array_wrapper, self._decompress, chunk_bytes)
+        return await to_thread(
+            as_numpy_array_wrapper, self._decompress, chunk_bytes, chunk_spec.prototype
+        )

     async def _encode_single(
         self,
         chunk_bytes: Buffer,
-        _chunk_spec: ArraySpec,
+        chunk_spec: ArraySpec,
     ) -> Buffer | None:
-        return await to_thread(as_numpy_array_wrapper, self._compress, chunk_bytes)
+        return await to_thread(
+            as_numpy_array_wrapper, self._compress, chunk_bytes, chunk_spec.prototype
+        )

     def compute_encoded_size(self, _input_byte_length: int, _chunk_spec: ArraySpec) -> int:
         raise NotImplementedError
diff --git a/src/zarr/common.py b/src/zarr/common.py
index ec5d870f92..bca9f171af 100644
--- a/src/zarr/common.py
+++ b/src/zarr/common.py
@@ -5,7 +5,6 @@
 import functools
 import operator
 from collections.abc import Iterable
-from dataclasses import dataclass
 from enum import Enum
 from typing import (
     TYPE_CHECKING,
@@ -91,31 +90,6 @@ def parse_enum(data: JSON, cls: type[E]) -> E:
     raise ValueError(f"Value must be one of {list(enum_names(cls))!r}. Got {data} instead.")

-@dataclass(frozen=True)
-class ArraySpec:
-    shape: ChunkCoords
-    dtype: np.dtype[Any]
-    fill_value: Any
-    order: Literal["C", "F"]
-
-    def __init__(
-        self, shape: ChunkCoords, dtype: np.dtype[Any], fill_value: Any, order: Literal["C", "F"]
-    ) -> None:
-        shape_parsed = parse_shapelike(shape)
-        dtype_parsed = parse_dtype(dtype)
-        fill_value_parsed = parse_fill_value(fill_value)
-        order_parsed = parse_order(order)
-
-        object.__setattr__(self, "shape", shape_parsed)
-        object.__setattr__(self, "dtype", dtype_parsed)
-        object.__setattr__(self, "fill_value", fill_value_parsed)
-        object.__setattr__(self, "order", order_parsed)
-
-    @property
-    def ndim(self) -> int:
-        return len(self.shape)
-
-
 def parse_name(data: JSON, expected: str | None = None) -> str:
     if isinstance(data, str):
         if expected is None or data == expected:
diff --git a/src/zarr/metadata.py b/src/zarr/metadata.py
index ca8cf1cdd2..bcb70bd4b2 100644
--- a/src/zarr/metadata.py
+++ b/src/zarr/metadata.py
@@ -5,31 +5,29 @@
 from collections.abc import Iterable
 from dataclasses import dataclass, field, replace
 from enum import Enum
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Literal

 import numpy as np
 import numpy.typing as npt

 from zarr.abc.codec import Codec, CodecPipeline
 from zarr.abc.metadata import Metadata
-from zarr.buffer import Buffer
+from zarr.buffer import Buffer, BufferPrototype, default_buffer_prototype
 from zarr.chunk_grids import ChunkGrid, RegularChunkGrid
 from zarr.chunk_key_encodings import ChunkKeyEncoding, parse_separator
 from zarr.codecs._v2 import V2Compressor, V2Filters

 if TYPE_CHECKING:
-    from typing import Literal
-
     from typing_extensions import Self

 import numcodecs.abc

+from zarr.array_spec import ArraySpec
 from zarr.common import (
     JSON,
     ZARR_JSON,
     ZARRAY_JSON,
     ZATTRS_JSON,
-    ArraySpec,
     ChunkCoords,
     parse_dtype,
     parse_fill_value,
@@ -137,7 +135,9 @@ def codec_pipeline(self) -> CodecPipeline:
         pass

     @abstractmethod
-    def get_chunk_spec(self, _chunk_coords: ChunkCoords, order: Literal["C", "F"]) -> ArraySpec:
+    def get_chunk_spec(
+        self, _chunk_coords: ChunkCoords, order: Literal["C", "F"], prototype: BufferPrototype
+    ) -> ArraySpec:
         pass

     @abstractmethod
@@ -198,6 +198,7 @@ def __init__(
             dtype=data_type_parsed,
             fill_value=fill_value_parsed,
             order="C",  # TODO: order is not needed here.
+            prototype=default_buffer_prototype,  # TODO: prototype is not needed here.
         )
         codecs_parsed = parse_codecs(codecs).evolve_from_array_spec(array_spec)

@@ -239,7 +240,9 @@ def ndim(self) -> int:
     def codec_pipeline(self) -> CodecPipeline:
         return self.codecs

-    def get_chunk_spec(self, _chunk_coords: ChunkCoords, order: Literal["C", "F"]) -> ArraySpec:
+    def get_chunk_spec(
+        self, _chunk_coords: ChunkCoords, order: Literal["C", "F"], prototype: BufferPrototype
+    ) -> ArraySpec:
         assert isinstance(
             self.chunk_grid, RegularChunkGrid
         ), "Currently, only regular chunk grid is supported"
@@ -248,6 +251,7 @@ def get_chunk_spec(self, _chunk_coords: ChunkCoords, order: Literal["C", "F"]) -
             dtype=self.dtype,
             fill_value=self.fill_value,
             order=order,
+            prototype=prototype,
         )

     def encode_chunk_key(self, chunk_coords: ChunkCoords) -> str:
@@ -412,12 +416,15 @@ def to_dict(self) -> JSON:

         return zarray_dict

-    def get_chunk_spec(self, _chunk_coords: ChunkCoords, order: Literal["C", "F"]) -> ArraySpec:
+    def get_chunk_spec(
+        self, _chunk_coords: ChunkCoords, order: Literal["C", "F"], prototype: BufferPrototype
+    ) -> ArraySpec:
         return ArraySpec(
             shape=self.chunk_grid.chunk_shape,
             dtype=self.dtype,
             fill_value=self.fill_value,
             order=order,
+            prototype=prototype,
         )

     def encode_chunk_key(self, chunk_coords: ChunkCoords) -> str:
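Reviewer note: `get_chunk_spec` is how the caller's prototype is threaded from `Array.getitem`/`setitem` down into the codec pipeline; both the v2 and v3 metadata classes simply forward it into `ArraySpec`. A sketch of the new call shape, assuming `metadata` is any `ArrayMetadata` instance:

```python
from zarr.buffer import BufferPrototype, default_buffer_prototype
from zarr.common import ChunkCoords
from zarr.metadata import ArrayMetadata


def chunk_spec_for(
    metadata: ArrayMetadata,
    chunk_coords: ChunkCoords,
    prototype: BufferPrototype = default_buffer_prototype,
):
    # Same chunk coords and order arguments as before; the prototype is the only new input.
    return metadata.get_chunk_spec(chunk_coords, order="C", prototype=prototype)
```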
@@ -32,7 +34,7 @@ def _get(path: Path, byte_range: tuple[int | None, int | None] | None) -> Buffer
         end = (start + byte_range[1]) if byte_range[1] is not None else None
     else:
-        return Buffer.from_bytes(path.read_bytes())
+        return prototype.buffer.from_bytes(path.read_bytes())
     with path.open("rb") as f:
         size = f.seek(0, io.SEEK_END)
         if start is not None:
@@ -43,8 +45,8 @@ def _get(path: Path, byte_range: tuple[int | None, int | None] | None) -> Buffer
             if end is not None:
                 if end < 0:
                     end = size + end
-                return Buffer.from_bytes(f.read(end - f.tell()))
-        return Buffer.from_bytes(f.read())
+                return prototype.buffer.from_bytes(f.read(end - f.tell()))
+        return prototype.buffer.from_bytes(f.read())

 def _put(
@@ -87,18 +89,23 @@ def __eq__(self, other: object) -> bool:
         return isinstance(other, type(self)) and self.root == other.root

     async def get(
-        self, key: str, byte_range: tuple[int | None, int | None] | None = None
+        self,
+        key: str,
+        prototype: BufferPrototype,
+        byte_range: tuple[int | None, int | None] | None = None,
     ) -> Buffer | None:
         assert isinstance(key, str)
         path = self.root / key

         try:
-            return await to_thread(_get, path, byte_range)
+            return await to_thread(_get, path, prototype, byte_range)
         except (FileNotFoundError, IsADirectoryError, NotADirectoryError):
             return None

     async def get_partial_values(
-        self, key_ranges: list[tuple[str, tuple[int | None, int | None]]]
+        self,
+        prototype: BufferPrototype,
+        key_ranges: list[tuple[str, tuple[int | None, int | None]]],
     ) -> list[Buffer | None]:
         """
         Read byte ranges from multiple keys.
@@ -114,15 +121,12 @@ async def get_partial_values(
         for key, byte_range in key_ranges:
             assert isinstance(key, str)
             path = self.root / key
-            args.append((_get, path, byte_range))
+            args.append((_get, path, prototype, byte_range))
         return await concurrent_map(args, to_thread, limit=None)  # TODO: fix limit

     async def set(self, key: str, value: Buffer) -> None:
         self._check_writable()
         assert isinstance(key, str)
-        if isinstance(value, bytes | bytearray):  # type:ignore[unreachable]
-            # TODO: to support the v2 tests, we convert bytes to Buffer here
-            value = Buffer.from_bytes(value)  # type:ignore[unreachable]
         if not isinstance(value, Buffer):
             raise TypeError("LocalStore.set(): `value` must a Buffer instance")
         path = self.root / key
diff --git a/src/zarr/store/memory.py b/src/zarr/store/memory.py
index fd6fadd3ee..d75e8c348c 100644
--- a/src/zarr/store/memory.py
+++ b/src/zarr/store/memory.py
@@ -3,7 +3,7 @@
 from collections.abc import AsyncGenerator, MutableMapping

 from zarr.abc.store import Store
-from zarr.buffer import Buffer
+from zarr.buffer import Buffer, BufferPrototype
 from zarr.common import OpenMode, concurrent_map
 from zarr.store.core import _normalize_interval_index

@@ -30,7 +30,10 @@ def __repr__(self) -> str:
         return f"MemoryStore({str(self)!r})"

     async def get(
-        self, key: str, byte_range: tuple[int | None, int | None] | None = None
+        self,
+        key: str,
+        prototype: BufferPrototype,
+        byte_range: tuple[int | None, int | None] | None = None,
     ) -> Buffer | None:
         assert isinstance(key, str)
         try:
@@ -41,9 +44,15 @@ async def get(
             return None

     async def get_partial_values(
-        self, key_ranges: list[tuple[str, tuple[int | None, int | None]]]
+        self,
+        prototype: BufferPrototype,
+        key_ranges: list[tuple[str, tuple[int | None, int | None]]],
     ) -> list[Buffer | None]:
-        vals = await concurrent_map(key_ranges, self.get, limit=None)
+        # All the key ranges use the same prototype
+        async def _get(key: str, byte_range: tuple[int, int | None]) -> Buffer | None:
+            return await self.get(key, prototype=prototype, byte_range=byte_range)
+
+        vals = await concurrent_map(key_ranges, _get, limit=None)
         return vals

     async def exists(self, key: str) -> bool:
@@ -52,9 +61,6 @@ async def exists(self, key: str) -> bool:

     async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None:
         self._check_writable()
         assert isinstance(key, str)
-        if isinstance(value, bytes | bytearray):  # type:ignore[unreachable]
-            # TODO: to support the v2 tests, we convert bytes to Buffer here
-            value = Buffer.from_bytes(value)  # type:ignore[unreachable]
         if not isinstance(value, Buffer):
             raise TypeError(f"Expected Buffer. Got {type(value)}.")
diff --git a/src/zarr/store/remote.py b/src/zarr/store/remote.py
index 60217fb72c..3eb057f9b8 100644
--- a/src/zarr/store/remote.py
+++ b/src/zarr/store/remote.py
@@ -3,7 +3,7 @@
 from typing import TYPE_CHECKING, Any

 from zarr.abc.store import Store
-from zarr.buffer import Buffer
+from zarr.buffer import Buffer, BufferPrototype
 from zarr.common import OpenMode
 from zarr.store.core import _dereference_path

@@ -55,7 +55,10 @@ def _make_fs(self) -> tuple[AsyncFileSystem, str]:
         return fs, root

     async def get(
-        self, key: str, byte_range: tuple[int | None, int | None] | None = None
+        self,
+        key: str,
+        prototype: BufferPrototype,
+        byte_range: tuple[int | None, int | None] | None = None,
     ) -> Buffer | None:
         assert isinstance(key, str)
         fs, root = self._make_fs()
diff --git a/src/zarr/testing/store.py b/src/zarr/testing/store.py
index cb4dc9f7b5..5929f47049 100644
--- a/src/zarr/testing/store.py
+++ b/src/zarr/testing/store.py
@@ -3,7 +3,7 @@
 import pytest

 from zarr.abc.store import Store
-from zarr.buffer import Buffer
+from zarr.buffer import Buffer, default_buffer_prototype
 from zarr.store.core import _normalize_interval_index
 from zarr.testing.utils import assert_bytes_equal

@@ -91,7 +91,7 @@ async def test_get(
         """
         data_buf = Buffer.from_bytes(data)
         self.set(store, key, data_buf)
-        observed = await store.get(key, byte_range=byte_range)
+        observed = await store.get(key, prototype=default_buffer_prototype, byte_range=byte_range)
         start, length = _normalize_interval_index(data_buf, interval=byte_range)
         expected = data_buf[start : start + length]
         assert_bytes_equal(observed, expected)
@@ -125,7 +125,9 @@ async def test_get_partial_values(
             self.set(store, key, Buffer.from_bytes(bytes(key, encoding="utf-8")))

         # read back just part of it
-        observed_maybe = await store.get_partial_values(key_ranges=key_ranges)
+        observed_maybe = await store.get_partial_values(
+            prototype=default_buffer_prototype, key_ranges=key_ranges
+        )

         observed: list[Buffer] = []
         expected: list[Buffer] = []
@@ -136,7 +138,7 @@ async def test_get_partial_values(

         for idx in range(len(observed)):
             key, byte_range = key_ranges[idx]
-            result = await store.get(key, byte_range=byte_range)
+            result = await store.get(key, prototype=default_buffer_prototype, byte_range=byte_range)
             assert result is not None
             expected.append(result)
diff --git a/tests/v3/package_with_entrypoint/__init__.py b/tests/v3/package_with_entrypoint/__init__.py
index b8bf903c01..6368e5b236 100644
--- a/tests/v3/package_with_entrypoint/__init__.py
+++ b/tests/v3/package_with_entrypoint/__init__.py
@@ -1,7 +1,8 @@
 from numpy import ndarray

 from zarr.abc.codec import ArrayBytesCodec
-from zarr.common import ArraySpec, BytesLike
+from zarr.array_spec import ArraySpec
+from zarr.common import BytesLike

 class TestCodec(ArrayBytesCodec):
diff --git a/tests/v3/test_buffer.py b/tests/v3/test_buffer.py
index 2f58d116fe..e814afef15 100644
--- a/tests/v3/test_buffer.py
+++ b/tests/v3/test_buffer.py
@@ -8,7 +8,15 @@
 import pytest

 from zarr.array import AsyncArray
-from zarr.buffer import ArrayLike, NDArrayLike, NDBuffer
+from zarr.buffer import ArrayLike, Buffer, BufferPrototype, NDArrayLike, NDBuffer
+from zarr.codecs.blosc import BloscCodec
+from zarr.codecs.bytes import BytesCodec
+from zarr.codecs.crc32c_ import Crc32cCodec
+from zarr.codecs.gzip import GzipCodec
+from zarr.codecs.transpose import TransposeCodec
+from zarr.codecs.zstd import ZstdCodec
+from zarr.store.core import StorePath
+from zarr.store.memory import MemoryStore

 if TYPE_CHECKING:
     from typing_extensions import Self
@@ -17,7 +25,9 @@
 class MyNDArrayLike(np.ndarray):
     """An example of a ndarray-like class"""

-    pass
+
+class MyBuffer(Buffer):
+    """Example of a custom Buffer that handles ArrayLike"""

 class MyNDBuffer(NDBuffer):
@@ -39,6 +49,28 @@ def create(
         return ret

+class MyStore(MemoryStore):
+    """Example of a custom Store that expects MyBuffer for all its non-metadata
+
+    We assume that keys containing "json" are metadata
+    """
+
+    async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None:
+        if "json" not in key:
+            assert isinstance(value, MyBuffer)
+        await super().set(key, value, byte_range)
+
+    async def get(
+        self,
+        key: str,
+        prototype: BufferPrototype,
+        byte_range: tuple[int, int | None] | None = None,
+    ) -> Buffer | None:
+        if "json" not in key:
+            assert prototype.buffer is MyBuffer
+        return await super().get(key, prototype, byte_range)
+
+
 def test_nd_array_like(xp):
     ary = xp.arange(10)
     assert isinstance(ary, ArrayLike)
@@ -46,10 +78,12 @@ def test_nd_array_like(xp):

 @pytest.mark.asyncio
-async def test_async_array_factory(store_path):
+async def test_async_array_prototype():
+    """Test the use of a custom buffer prototype"""
+
     expect = np.zeros((9, 9), dtype="uint16", order="F")
     a = await AsyncArray.create(
-        store_path,
+        StorePath(MyStore(mode="w")) / "test_async_array_prototype",
         shape=expect.shape,
         chunk_shape=(5, 5),
         dtype=expect.dtype,
@@ -57,11 +91,45 @@ async def test_async_array_factory(store_path):
     )
     expect[1:4, 3:6] = np.ones((3, 3))

+    my_prototype = BufferPrototype(buffer=MyBuffer, nd_buffer=MyNDBuffer)
+
     await a.setitem(
         selection=(slice(1, 4), slice(3, 6)),
         value=np.ones((3, 3)),
-        factory=MyNDBuffer.from_ndarray_like,
+        prototype=my_prototype,
+    )
+    got = await a.getitem(selection=(slice(0, 9), slice(0, 9)), prototype=my_prototype)
+    assert isinstance(got, MyNDArrayLike)
+    assert np.array_equal(expect, got)
+
+
+@pytest.mark.asyncio
+async def test_codecs_use_of_prototype():
+    expect = np.zeros((10, 10), dtype="uint16", order="F")
+    a = await AsyncArray.create(
+        StorePath(MyStore(mode="w")) / "test_codecs_use_of_prototype",
+        shape=expect.shape,
+        chunk_shape=(5, 5),
+        dtype=expect.dtype,
+        fill_value=0,
+        codecs=[
+            TransposeCodec(order=(1, 0)),
+            BytesCodec(),
+            BloscCodec(),
+            Crc32cCodec(),
+            GzipCodec(),
+            ZstdCodec(),
+        ],
+    )
+    expect[:] = np.arange(100).reshape(10, 10)
+
+    my_prototype = BufferPrototype(buffer=MyBuffer, nd_buffer=MyNDBuffer)
+
+    await a.setitem(
+        selection=(slice(0, 10), slice(0, 10)),
+        value=expect[:],
+        prototype=my_prototype,
     )
-    got = await a.getitem(selection=(slice(0, 9), slice(0, 9)), factory=MyNDBuffer.create)
+    got = await a.getitem(selection=(slice(0, 10), slice(0, 10)), prototype=my_prototype)
     assert isinstance(got, MyNDArrayLike)
     assert np.array_equal(expect, got)
diff --git a/tests/v3/test_group.py b/tests/v3/test_group.py
index c529e2491f..e11af748b3 100644
--- a/tests/v3/test_group.py
+++ b/tests/v3/test_group.py
@@ -49,7 +49,7 @@ def test_group_children(store: MemoryStore | LocalStore) -> None:
     # add an extra object under a directory-like prefix in the domain of the group.
     # this creates a directory with a random key in it
     # this should not show up as a member
-    sync(store.set(f"{path}/extra_directory/extra_object-2", b"000000"))
+    sync(store.set(f"{path}/extra_directory/extra_object-2", Buffer.from_bytes(b"000000")))
     members_observed = group.members
     # members are not guaranteed to be ordered, so sort before comparing
     assert sorted(dict(members_observed)) == sorted(members_expected)
diff --git a/tests/v3/test_indexing.py b/tests/v3/test_indexing.py
index 9ce485945b..00ea947b49 100644
--- a/tests/v3/test_indexing.py
+++ b/tests/v3/test_indexing.py
@@ -12,7 +12,7 @@
 import zarr
 from zarr.abc.store import Store
-from zarr.buffer import NDBuffer
+from zarr.buffer import BufferPrototype, NDBuffer
 from zarr.common import ChunkCoords
 from zarr.indexing import (
     make_slice_selection,
@@ -51,10 +51,10 @@ def __init__(self):
         super().__init__(mode="w")
         self.counter = Counter()

-    async def get(self, key, byte_range=None):
+    async def get(self, key, prototype: BufferPrototype, byte_range=None):
         key_suffix = "/".join(key.split("/")[1:])
         self.counter["__getitem__", key_suffix] += 1
-        return await super().get(key, byte_range)
+        return await super().get(key, prototype, byte_range)

     async def set(self, key, value, byte_range=None):
         key_suffix = "/".join(key.split("/")[1:])
@@ -225,7 +225,6 @@ def test_get_basic_selection_0d(store: StorePath):

 def _test_get_basic_selection(a, z, selection):
-    print(a, z, selection)
     expect = a[selection]
     actual = z.get_basic_selection(selection)
     assert_array_equal(expect, actual)
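Reviewer note: end-to-end, the store API now looks like the sketch below — a small smoke test using only pieces this diff touches (`MemoryStore`, `Buffer`, `default_buffer_prototype`). Byte ranges remain `(start, length)` tuples, as in `_normalize_interval_index`:

```python
import asyncio

from zarr.buffer import Buffer, default_buffer_prototype
from zarr.store.memory import MemoryStore


async def smoke() -> None:
    store = MemoryStore(mode="w")
    await store.set("foo", Buffer.from_bytes(b"abcdefgh"))
    # get() and get_partial_values() now require an explicit prototype.
    whole = await store.get("foo", prototype=default_buffer_prototype)
    parts = await store.get_partial_values(
        prototype=default_buffer_prototype, key_ranges=[("foo", (0, 4))]
    )
    assert whole is not None and whole.to_bytes() == b"abcdefgh"
    assert parts[0] is not None and parts[0].to_bytes() == b"abcd"


asyncio.run(smoke())
```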