Skip to content

Commit 26fa37e

Browse files
committed
Use new ByteRequest syntax
1 parent 6da7976 commit 26fa37e

File tree

1 file changed

+52
-36
lines changed

1 file changed

+52
-36
lines changed

src/zarr/storage/object_store.py

Lines changed: 52 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,13 @@
88

99
import obstore as obs
1010

11-
from zarr.abc.store import ByteRangeRequest, Store
11+
from zarr.abc.store import (
12+
ByteRequest,
13+
OffsetByteRequest,
14+
RangeByteRequest,
15+
Store,
16+
SuffixByteRequest,
17+
)
1218
from zarr.core.buffer import Buffer
1319
from zarr.core.buffer.core import BufferPrototype
1420

@@ -64,36 +70,33 @@ def __repr__(self) -> str:
6470
return f"ObjectStore({self})"
6571

6672
async def get(
67-
self, key: str, prototype: BufferPrototype, byte_range: ByteRangeRequest | None = None
73+
self, key: str, prototype: BufferPrototype, byte_range: ByteRequest | None = None
6874
) -> Buffer:
6975
if byte_range is None:
7076
resp = await obs.get_async(self.store, key)
7177
return prototype.buffer.from_bytes(await resp.bytes_async())
72-
73-
start, end = byte_range
74-
if (start is None or start == 0) and end is None:
75-
resp = await obs.get_async(self.store, key)
76-
return prototype.buffer.from_bytes(await resp.bytes_async())
77-
if start is not None and end is not None:
78-
resp = await obs.get_range_async(self.store, key, start=start, end=end)
78+
elif isinstance(byte_range, RangeByteRequest):
79+
resp = await obs.get_range_async(
80+
self.store, key, start=byte_range.start, end=byte_range.end
81+
)
7982
return prototype.buffer.from_bytes(memoryview(resp))
80-
elif start is not None:
81-
if start > 0:
82-
# Offset request
83-
resp = await obs.get_async(self.store, key, options={"range": {"offset": start}})
84-
else:
85-
resp = await obs.get_async(self.store, key, options={"range": {"suffix": start}})
83+
elif isinstance(byte_range, OffsetByteRequest):
84+
resp = await obs.get_async(
85+
self.store, key, options={"range": {"offset": byte_range.offset}}
86+
)
87+
return prototype.buffer.from_bytes(await resp.bytes_async())
88+
elif isinstance(byte_range, SuffixByteRequest):
89+
resp = await obs.get_async(
90+
self.store, key, options={"range": {"suffix": byte_range.suffix}}
91+
)
8692
return prototype.buffer.from_bytes(await resp.bytes_async())
87-
elif end is not None:
88-
resp = await obs.get_range_async(self.store, key, start=0, end=end)
89-
return prototype.buffer.from_bytes(memoryview(resp))
9093
else:
91-
raise ValueError(f"Unexpected input to `get`: {start=}, {end=}")
94+
raise ValueError(f"Unexpected input to `get`: {byte_range}")
9295

9396
async def get_partial_values(
9497
self,
9598
prototype: BufferPrototype,
96-
key_ranges: Iterable[tuple[str, ByteRangeRequest]],
99+
key_ranges: Iterable[tuple[str, ByteRequest | None]],
97100
) -> list[Buffer | None]:
98101
return await _get_partial_values(self.store, prototype=prototype, key_ranges=key_ranges)
99102

@@ -260,7 +263,10 @@ async def _make_other_request(
260263
We return a `list[_Response]` for symmetry with `_make_bounded_requests` so that all
261264
futures can be gathered together.
262265
"""
263-
resp = await obs.get_async(store, request["path"], options={"range": request["range"]})
266+
if request["range"] is None:
267+
resp = await obs.get_async(store, request["path"])
268+
else:
269+
resp = await obs.get_async(store, request["path"], options={"range": request["range"]})
264270
buffer = await resp.bytes_async()
265271
return [
266272
{
@@ -273,7 +279,7 @@ async def _make_other_request(
273279
async def _get_partial_values(
274280
store: obs.store.ObjectStore,
275281
prototype: BufferPrototype,
276-
key_ranges: Iterable[tuple[str, ByteRangeRequest]],
282+
key_ranges: Iterable[tuple[str, ByteRequest | None]],
277283
) -> list[Buffer | None]:
278284
"""Make multiple range requests.
279285
@@ -290,27 +296,37 @@ async def _get_partial_values(
290296
per_file_bounded_requests: dict[str, list[_BoundedRequest]] = defaultdict(list)
291297
other_requests: list[_OtherRequest] = []
292298

293-
for idx, (path, (start, end)) in enumerate(key_ranges):
294-
if start is None:
295-
raise ValueError("Cannot pass `None` for the start of the range request.")
296-
297-
if end is not None:
298-
# This is a bounded request with known start and end byte.
299+
for idx, (path, byte_range) in enumerate(key_ranges):
300+
if byte_range is None:
301+
other_requests.append(
302+
{
303+
"original_request_index": idx,
304+
"path": path,
305+
"range": None,
306+
}
307+
)
308+
elif isinstance(byte_range, RangeByteRequest):
299309
per_file_bounded_requests[path].append(
300-
{"original_request_index": idx, "start": start, "end": end}
310+
{"original_request_index": idx, "start": byte_range.start, "end": byte_range.end}
301311
)
302-
elif start < 0:
303-
# Suffix request from the end
312+
elif isinstance(byte_range, OffsetByteRequest):
304313
other_requests.append(
305-
{"original_request_index": idx, "path": path, "range": {"suffix": abs(start)}}
314+
{
315+
"original_request_index": idx,
316+
"path": path,
317+
"range": {"offset": byte_range.offset},
318+
}
306319
)
307-
elif start >= 0:
308-
# Offset request to the end
320+
elif isinstance(byte_range, SuffixByteRequest):
309321
other_requests.append(
310-
{"original_request_index": idx, "path": path, "range": {"offset": start}}
322+
{
323+
"original_request_index": idx,
324+
"path": path,
325+
"range": {"suffix": byte_range.suffix},
326+
}
311327
)
312328
else:
313-
raise ValueError(f"Unsupported range input: {start=}, {end=}")
329+
raise ValueError(f"Unsupported range input: {byte_range}")
314330

315331
futs: list[Coroutine[Any, Any, list[_Response]]] = []
316332
for path, bounded_ranges in per_file_bounded_requests.items():

0 commit comments

Comments (0)