Skip to content

Commit b1f4c50

Browse files
authored
Protocols for Buffer and NDBuffer (#1899)
1 parent 4da9505 commit b1f4c50

File tree

9 files changed

+151
-82
lines changed

9 files changed

+151
-82
lines changed

src/zarr/array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -582,12 +582,12 @@ def store_path(self) -> StorePath:
582582
def order(self) -> Literal["C", "F"]:
583583
return self._async_array.order
584584

585-
def __getitem__(self, selection: Selection) -> npt.NDArray[Any]:
585+
def __getitem__(self, selection: Selection) -> NDArrayLike:
586586
return sync(
587587
self._async_array.getitem(selection),
588588
)
589589

590-
def __setitem__(self, selection: Selection, value: npt.NDArray[Any]) -> None:
590+
def __setitem__(self, selection: Selection, value: NDArrayLike) -> None:
591591
sync(
592592
self._async_array.setitem(selection, value),
593593
)

src/zarr/buffer.py

Lines changed: 80 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,94 @@
11
from __future__ import annotations
22

33
import sys
4-
from collections.abc import Callable, Iterable
4+
from collections.abc import Callable, Iterable, Sequence
55
from typing import (
66
TYPE_CHECKING,
77
Any,
88
Literal,
99
Protocol,
10-
TypeAlias,
10+
SupportsIndex,
11+
runtime_checkable,
1112
)
1213

1314
import numpy as np
1415
import numpy.typing as npt
1516

17+
from zarr.common import ChunkCoords
18+
1619
if TYPE_CHECKING:
1720
from typing_extensions import Self
1821

1922
from zarr.codecs.bytes import Endian
2023
from zarr.common import BytesLike
2124

22-
# TODO: create a protocol for the attributes we need, for now we alias Numpy's ndarray
23-
# both for the array-like and ndarray-like
24-
ArrayLike: TypeAlias = npt.NDArray[Any]
25-
NDArrayLike: TypeAlias = npt.NDArray[Any]
25+
26+
@runtime_checkable
27+
class ArrayLike(Protocol):
28+
"""Protocol for the array-like type that underlies Buffer"""
29+
30+
@property
31+
def dtype(self) -> np.dtype[Any]: ...
32+
33+
@property
34+
def ndim(self) -> int: ...
35+
36+
@property
37+
def size(self) -> int: ...
38+
39+
def __getitem__(self, key: slice) -> Self: ...
40+
41+
def __setitem__(self, key: slice, value: Any) -> None: ...
42+
43+
44+
@runtime_checkable
45+
class NDArrayLike(Protocol):
46+
"""Protocol for the nd-array-like type that underlies NDBuffer"""
47+
48+
@property
49+
def dtype(self) -> np.dtype[Any]: ...
50+
51+
@property
52+
def ndim(self) -> int: ...
53+
54+
@property
55+
def size(self) -> int: ...
56+
57+
@property
58+
def shape(self) -> ChunkCoords: ...
59+
60+
def __len__(self) -> int: ...
61+
62+
def __getitem__(self, key: slice) -> Self: ...
63+
64+
def __setitem__(self, key: slice, value: Any) -> None: ...
65+
66+
def reshape(self, shape: ChunkCoords, *, order: Literal["A", "C", "F"] = ...) -> Self: ...
67+
68+
def view(self, dtype: npt.DTypeLike) -> Self: ...
69+
70+
def astype(self, dtype: npt.DTypeLike, order: Literal["K", "A", "C", "F"] = ...) -> Self: ...
71+
72+
def fill(self, value: Any) -> None: ...
73+
74+
def copy(self) -> Self: ...
75+
76+
def transpose(self, axes: SupportsIndex | Sequence[SupportsIndex] | None) -> Self: ...
77+
78+
def ravel(self, order: Literal["K", "A", "C", "F"] = "C") -> Self: ...
79+
80+
def all(self) -> bool: ...
81+
82+
def __eq__(self, other: Any) -> Self: # type: ignore
83+
"""Element-wise equal
84+
85+
Notice
86+
------
87+
Type checkers such as mypy complain because the return type isn't a bool like
88+
its supertype "object", which violates the Liskov substitution principle.
89+
This is true, but since NumPy's ndarray defines equality as an element-wise comparison,
90+
our hands are tied.
91+
"""
2692

2793

2894
def check_item_key_is_1d_contiguous(key: Any) -> None:
@@ -124,7 +190,7 @@ def create_zero_length(cls) -> Self:
124190
return cls(np.array([], dtype="b"))
125191

126192
@classmethod
127-
def from_array_like(cls, array_like: NDArrayLike) -> Self:
193+
def from_array_like(cls, array_like: ArrayLike) -> Self:
128194
"""Create a new buffer from an array-like object
129195
130196
Parameters
@@ -153,7 +219,7 @@ def from_bytes(cls, bytes_like: BytesLike) -> Self:
153219
"""
154220
return cls.from_array_like(np.frombuffer(bytes_like, dtype="b"))
155221

156-
def as_array_like(self) -> NDArrayLike:
222+
def as_array_like(self) -> ArrayLike:
157223
"""Return the underlying array (host or device memory) of this buffer
158224
159225
This will never copy data.
@@ -164,22 +230,6 @@ def as_array_like(self) -> NDArrayLike:
164230
"""
165231
return self._data
166232

167-
def as_nd_buffer(self, *, dtype: npt.DTypeLike) -> NDBuffer:
168-
"""Create a new NDBuffer from this one.
169-
170-
This will never copy data.
171-
172-
Parameters
173-
----------
174-
dtype
175-
The datatype of the returned buffer (reinterpretation of the bytes)
176-
177-
Return
178-
------
179-
New NDBuffer representing `self.as_array_like()`
180-
"""
181-
return NDBuffer.from_ndarray_like(self._data.view(dtype=dtype))
182-
183233
def as_numpy_array(self) -> npt.NDArray[Any]:
184234
"""Return the buffer as a NumPy array (host memory).
185235
@@ -223,17 +273,8 @@ def __add__(self, other: Buffer) -> Self:
223273

224274
other_array = other.as_array_like()
225275
assert other_array.dtype == np.dtype("b")
226-
return self.__class__(np.concatenate((self._data, other_array)))
227-
228-
def __eq__(self, other: Any) -> bool:
229-
if isinstance(other, bytes | bytearray):
230-
# Many of the tests compare `Buffer` with `bytes` so we
231-
# convert the bytes to a Buffer and try again
232-
return self == self.from_bytes(other)
233-
if isinstance(other, Buffer):
234-
return (self._data == other.as_array_like()).all()
235-
raise ValueError(
236-
f"equal operator not supported between {self.__class__} and {other.__class__}"
276+
return self.__class__(
277+
np.concatenate((np.asanyarray(self._data), np.asanyarray(other_array)))
237278
)
238279

239280

@@ -345,22 +386,6 @@ def as_ndarray_like(self) -> NDArrayLike:
345386
"""
346387
return self._data
347388

348-
def as_buffer(self) -> Buffer:
349-
"""Create a new Buffer from this one.
350-
351-
Warning
352-
-------
353-
Copies data if the buffer is non-contiguous.
354-
355-
Return
356-
------
357-
The new buffer (might be data copy)
358-
"""
359-
data = self._data
360-
if not self._data.flags.contiguous:
361-
data = np.ascontiguousarray(self._data)
362-
return Buffer(data.reshape(-1).view(dtype="b")) # Flatten the array without copy
363-
364389
def as_numpy_array(self) -> npt.NDArray[Any]:
365390
"""Return the buffer as a NumPy array (host memory).
366391
@@ -393,8 +418,8 @@ def byteorder(self) -> Endian:
393418
else:
394419
return Endian(sys.byteorder)
395420

396-
def reshape(self, newshape: Iterable[int]) -> Self:
397-
return self.__class__(self._data.reshape(tuple(newshape)))
421+
def reshape(self, newshape: ChunkCoords) -> Self:
422+
return self.__class__(self._data.reshape(newshape))
398423

399424
def astype(self, dtype: npt.DTypeLike, order: Literal["K", "A", "C", "F"] = "K") -> Self:
400425
return self.__class__(self._data.astype(dtype=dtype, order=order))
@@ -419,8 +444,8 @@ def fill(self, value: Any) -> None:
419444
def copy(self) -> Self:
420445
return self.__class__(self._data.copy())
421446

422-
def transpose(self, *axes: np.SupportsIndex) -> Self: # type: ignore[name-defined]
423-
return self.__class__(self._data.transpose(*axes))
447+
def transpose(self, axes: SupportsIndex | Sequence[SupportsIndex] | None) -> Self:
448+
return self.__class__(self._data.transpose(axes))
424449

425450

426451
def as_numpy_array_wrapper(func: Callable[[npt.NDArray[Any]], bytes], buf: Buffer) -> Buffer:

src/zarr/codecs/bytes.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import numpy as np
99

1010
from zarr.abc.codec import ArrayBytesCodec
11-
from zarr.buffer import Buffer, NDBuffer
11+
from zarr.buffer import Buffer, NDArrayLike, NDBuffer
1212
from zarr.codecs.registry import register_codec
1313
from zarr.common import parse_enum, parse_named_configuration
1414

@@ -75,7 +75,13 @@ async def _decode_single(
7575
dtype = np.dtype(f"{prefix}{chunk_spec.dtype.str[1:]}")
7676
else:
7777
dtype = np.dtype(f"|{chunk_spec.dtype.str[1:]}")
78-
chunk_array = chunk_bytes.as_nd_buffer(dtype=dtype)
78+
79+
as_array_like = chunk_bytes.as_array_like()
80+
if isinstance(as_array_like, NDArrayLike):
81+
as_nd_array_like = as_array_like
82+
else:
83+
as_nd_array_like = np.asanyarray(as_array_like)
84+
chunk_array = NDBuffer.from_ndarray_like(as_nd_array_like.view(dtype=dtype))
7985

8086
# ensure correct chunk shape
8187
if chunk_array.shape != chunk_spec.shape:
@@ -96,7 +102,11 @@ async def _encode_single(
96102
# see https://github.com/numpy/numpy/issues/26473
97103
new_dtype = chunk_array.dtype.newbyteorder(self.endian.name) # type: ignore[arg-type]
98104
chunk_array = chunk_array.astype(new_dtype)
99-
return chunk_array.as_buffer()
105+
106+
as_nd_array_like = chunk_array.as_ndarray_like()
107+
# Flatten the nd-array (only copy if needed)
108+
as_nd_array_like = as_nd_array_like.ravel().view(dtype="b")
109+
return Buffer.from_array_like(as_nd_array_like)
100110

101111
def compute_encoded_size(self, input_byte_length: int, _chunk_spec: ArraySpec) -> int:
102112
return input_byte_length

src/zarr/testing/store.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from zarr.abc.store import Store
44
from zarr.buffer import Buffer
5+
from zarr.testing.utils import assert_bytes_equal
56

67

78
class StoreTests:
@@ -27,7 +28,7 @@ def test_store_capabilities(self, store: Store) -> None:
2728
@pytest.mark.parametrize("data", [b"\x01\x02\x03\x04", b""])
2829
async def test_set_get_bytes_roundtrip(self, store: Store, key: str, data: bytes) -> None:
2930
await store.set(key, Buffer.from_bytes(data))
30-
assert await store.get(key) == data
31+
assert_bytes_equal(await store.get(key), data)
3132

3233
@pytest.mark.parametrize("key", ["foo/c/0"])
3334
@pytest.mark.parametrize("data", [b"\x01\x02\x03\x04", b""])
@@ -36,11 +37,12 @@ async def test_get_partial_values(self, store: Store, key: str, data: bytes) ->
3637
await store.set(key, Buffer.from_bytes(data))
3738
# read back just part of it
3839
vals = await store.get_partial_values([(key, (0, 2))])
39-
assert vals == [data[0:2]]
40+
assert_bytes_equal(vals[0], data[0:2])
4041

4142
# read back multiple parts of it at once
4243
vals = await store.get_partial_values([(key, (0, 2)), (key, (2, 4))])
43-
assert vals == [data[0:2], data[2:4]]
44+
assert_bytes_equal(vals[0], data[0:2])
45+
assert_bytes_equal(vals[1], data[2:4])
4446

4547
async def test_exists(self, store: Store) -> None:
4648
assert not await store.exists("foo")

src/zarr/testing/utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from __future__ import annotations
2+
3+
from zarr.buffer import Buffer
4+
from zarr.common import BytesLike
5+
6+
7+
def assert_bytes_equal(b1: Buffer | BytesLike | None, b2: Buffer | BytesLike | None) -> None:
8+
"""Helper function to assert that two bytes-like objects or Buffers are equal
9+
10+
Warning
11+
-------
12+
Always copies data, only use for testing and debugging
13+
"""
14+
if isinstance(b1, Buffer):
15+
b1 = b1.to_bytes()
16+
if isinstance(b2, Buffer):
17+
b2 = b2.to_bytes()
18+
assert b1 == b2

tests/v3/conftest.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from collections.abc import Iterator
4+
from types import ModuleType
35
from typing import TYPE_CHECKING
46

57
from zarr.common import ZarrFormat
@@ -81,3 +83,10 @@ async def async_group(request: pytest.FixtureRequest, tmpdir) -> AsyncGroup:
8183
exists_ok=False,
8284
)
8385
return agroup
86+
87+
88+
@pytest.fixture(params=["numpy", "cupy"])
89+
def xp(request: pytest.FixtureRequest) -> Iterator[ModuleType]:
90+
"""Fixture to parametrize over numpy-like libraries"""
91+
92+
yield pytest.importorskip(request.param)

tests/v3/test_buffer.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,7 @@
88
import pytest
99

1010
from zarr.array import AsyncArray
11-
from zarr.buffer import NDBuffer
12-
from zarr.store.core import StorePath
13-
from zarr.store.memory import MemoryStore
11+
from zarr.buffer import ArrayLike, NDArrayLike, NDBuffer
1412

1513
if TYPE_CHECKING:
1614
from typing_extensions import Self
@@ -41,12 +39,17 @@ def create(
4139
return ret
4240

4341

42+
def test_nd_array_like(xp):
43+
ary = xp.arange(10)
44+
assert isinstance(ary, ArrayLike)
45+
assert isinstance(ary, NDArrayLike)
46+
47+
4448
@pytest.mark.asyncio
45-
async def test_async_array_factory():
46-
store = StorePath(MemoryStore())
49+
async def test_async_array_factory(store_path):
4750
expect = np.zeros((9, 9), dtype="uint16", order="F")
4851
a = await AsyncArray.create(
49-
store / "test_async_array",
52+
store_path,
5053
shape=expect.shape,
5154
chunk_shape=(5, 5),
5255
dtype=expect.dtype,

0 commit comments

Comments
 (0)