Skip to content

Commit c677da4

Browse files
madsbkd-v-b
andauthored
[v3] Buffer ensure correct subclass based on the BufferPrototype argument (#1974)
* impl. and use Buffer.from_buffer() * Update src/zarr/buffer.py Co-authored-by: Davis Bennett <davis.v.bennett@gmail.com> * Apply suggestions from code review Co-authored-by: Davis Bennett <davis.v.bennett@gmail.com> --------- Co-authored-by: Davis Bennett <davis.v.bennett@gmail.com>
1 parent e3ee09e commit c677da4

File tree

5 files changed

+33
-7
lines changed

5 files changed

+33
-7
lines changed

src/zarr/buffer.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def create_zero_length(cls) -> Self:
146146

147147
@classmethod
148148
def from_array_like(cls, array_like: ArrayLike) -> Self:
149-
"""Create a new buffer of a array-like object
149+
"""Create a new buffer of an array-like object
150150
151151
Parameters
152152
----------
@@ -159,6 +159,29 @@ def from_array_like(cls, array_like: ArrayLike) -> Self:
159159
"""
160160
return cls(array_like)
161161

162+
@classmethod
163+
def from_buffer(cls, buffer: Buffer) -> Self:
164+
"""Create a new buffer of an existing Buffer
165+
166+
This is useful if you want to ensure that an existing buffer is
167+
of the correct subclass of Buffer. E.g., MemoryStore uses this
168+
to return a buffer instance of the subclass specified by its
169+
BufferPrototype argument.
170+
171+
Typically, this only copies data if the data has to be moved between
172+
memory types, such as from host to device memory.
173+
174+
Parameters
175+
----------
176+
buffer
177+
buffer object.
178+
179+
Returns
180+
-------
181+
A new buffer representing the content of the input buffer
182+
"""
183+
return cls.from_array_like(buffer.as_array_like())
184+
162185
@classmethod
163186
def from_bytes(cls, bytes_like: BytesLike) -> Self:
164187
"""Create a new buffer of a bytes-like object (host memory)

src/zarr/store/memory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ async def get(
3939
try:
4040
value = self._store_dict[key]
4141
start, length = _normalize_interval_index(value, byte_range)
42-
return value[start : start + length]
42+
return prototype.buffer.from_buffer(value[start : start + length])
4343
except KeyError:
4444
return None
4545

src/zarr/store/remote.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import fsspec
77

88
from zarr.abc.store import Store
9-
from zarr.buffer import Buffer, BufferPrototype, default_buffer_prototype
9+
from zarr.buffer import Buffer, BufferPrototype
1010
from zarr.common import OpenMode
1111
from zarr.store.core import _dereference_path
1212

@@ -84,7 +84,7 @@ def __repr__(self) -> str:
8484
async def get(
8585
self,
8686
key: str,
87-
prototype: BufferPrototype = default_buffer_prototype,
87+
prototype: BufferPrototype,
8888
byte_range: tuple[int | None, int | None] | None = None,
8989
) -> Buffer | None:
9090
path = _dereference_path(self.path, key)
@@ -99,7 +99,7 @@ async def get(
9999
end = length
100100
else:
101101
end = None
102-
value: Buffer = prototype.buffer.from_bytes(
102+
value = prototype.buffer.from_bytes(
103103
await (
104104
self._fs._cat_file(path, start=byte_range[0], end=end)
105105
if byte_range

tests/v3/test_buffer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,10 @@ async def get(
6868
) -> Buffer | None:
6969
if "json" not in key:
7070
assert prototype.buffer is MyBuffer
71-
return await super().get(key, byte_range)
71+
ret = await super().get(key=key, prototype=prototype, byte_range=byte_range)
72+
if ret is not None:
73+
assert isinstance(ret, prototype.buffer)
74+
return ret
7275

7376

7477
def test_nd_array_like(xp):

tests/v3/test_store/test_remote.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ async def test_basic():
8888
data = b"hello"
8989
await store.set("foo", Buffer.from_bytes(data))
9090
assert await store.exists("foo")
91-
assert (await store.get("foo")).to_bytes() == data
91+
assert (await store.get("foo", prototype=default_buffer_prototype)).to_bytes() == data
9292
out = await store.get_partial_values(
9393
prototype=default_buffer_prototype, key_ranges=[("foo", (1, None))]
9494
)

0 commit comments

Comments
 (0)