Skip to content

Commit b57ed72

Browse files
committed
added a BatchedCodecPipeline
1 parent 8b882b7 commit b57ed72

File tree

6 files changed

+484
-156
lines changed

6 files changed

+484
-156
lines changed

src/zarr/v3/abc/codec.py

Lines changed: 132 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from __future__ import annotations
22

33
from abc import abstractmethod
4-
from typing import TYPE_CHECKING, Optional
4+
from typing import TYPE_CHECKING, Awaitable, Callable, Iterable, Optional, Tuple, TypeVar
55

66
import numpy as np
77
from zarr.v3.abc.metadata import Metadata
88

9-
from zarr.v3.common import ArraySpec
9+
from zarr.v3.common import ArraySpec, concurrent_map
1010
from zarr.v3.store import StorePath
1111

1212

@@ -18,6 +18,22 @@
1818
RuntimeConfiguration,
1919
)
2020

21+
T = TypeVar("T")
22+
U = TypeVar("U")
23+
24+
25+
def noop_for_none(
26+
func: Callable[[Optional[T], ArraySpec, RuntimeConfiguration], Awaitable[U]]
27+
) -> Callable[[T, ArraySpec, RuntimeConfiguration], Awaitable[U]]:
28+
async def wrap(
29+
chunk: Optional[T], chunk_spec: ArraySpec, runtime_configuration: RuntimeConfiguration
30+
) -> U:
31+
if chunk is None:
32+
return None
33+
return await func(chunk, chunk_spec, runtime_configuration)
34+
35+
return wrap
36+
2137

2238
class Codec(Metadata):
2339
is_fixed_size: bool
@@ -46,6 +62,20 @@ async def decode(
4662
) -> np.ndarray:
4763
pass
4864

65+
async def decode_batch(
66+
self,
67+
chunk_arrays_and_specs: Iterable[Tuple[np.ndarray, ArraySpec]],
68+
runtime_configuration: RuntimeConfiguration,
69+
) -> Iterable[np.ndarray]:
70+
return await concurrent_map(
71+
[
72+
(chunk_array, chunk_spec, runtime_configuration)
73+
for chunk_array, chunk_spec in chunk_arrays_and_specs
74+
],
75+
noop_for_none(self.decode),
76+
runtime_configuration.concurrency,
77+
)
78+
4979
@abstractmethod
5080
async def encode(
5181
self,
@@ -55,17 +85,45 @@ async def encode(
5585
) -> Optional[np.ndarray]:
5686
pass
5787

88+
async def encode_batch(
89+
self,
90+
chunk_arrays_and_specs: Iterable[Tuple[Optional[np.ndarray], ArraySpec]],
91+
runtime_configuration: RuntimeConfiguration,
92+
) -> Iterable[Optional[np.ndarray]]:
93+
return await concurrent_map(
94+
[
95+
(chunk_array, chunk_spec, runtime_configuration)
96+
for chunk_array, chunk_spec in chunk_arrays_and_specs
97+
],
98+
noop_for_none(self.encode),
99+
runtime_configuration.concurrency,
100+
)
101+
58102

59103
class ArrayBytesCodec(Codec):
60104
@abstractmethod
61105
async def decode(
62106
self,
63-
chunk_array: BytesLike,
107+
chunk_bytes: BytesLike,
64108
chunk_spec: ArraySpec,
65109
runtime_configuration: RuntimeConfiguration,
66110
) -> np.ndarray:
67111
pass
68112

113+
async def decode_batch(
114+
self,
115+
chunk_bytes_and_specs: Iterable[Tuple[BytesLike, ArraySpec]],
116+
runtime_configuration: RuntimeConfiguration,
117+
) -> Iterable[np.ndarray]:
118+
return await concurrent_map(
119+
[
120+
(chunk_bytes, chunk_spec, runtime_configuration)
121+
for chunk_bytes, chunk_spec in chunk_bytes_and_specs
122+
],
123+
noop_for_none(self.decode),
124+
runtime_configuration.concurrency,
125+
)
126+
69127
@abstractmethod
70128
async def encode(
71129
self,
@@ -75,6 +133,20 @@ async def encode(
75133
) -> Optional[BytesLike]:
76134
pass
77135

136+
async def encode_batch(
137+
self,
138+
chunk_arrays_and_specs: Iterable[Tuple[Optional[np.ndarray], ArraySpec]],
139+
runtime_configuration: RuntimeConfiguration,
140+
) -> Iterable[Optional[BytesLike]]:
141+
return await concurrent_map(
142+
[
143+
(chunk_array, chunk_spec, runtime_configuration)
144+
for chunk_array, chunk_spec in chunk_arrays_and_specs
145+
],
146+
noop_for_none(self.encode),
147+
runtime_configuration.concurrency,
148+
)
149+
78150

79151
class ArrayBytesCodecPartialDecodeMixin:
80152
@abstractmethod
@@ -87,6 +159,20 @@ async def decode_partial(
87159
) -> Optional[np.ndarray]:
88160
pass
89161

162+
async def decode_partial_batched(
163+
self,
164+
batch_info: Iterable[Tuple[StorePath, SliceSelection, ArraySpec]],
165+
runtime_configuration: RuntimeConfiguration,
166+
) -> Iterable[Optional[np.ndarray]]:
167+
return await concurrent_map(
168+
[
169+
(store_path, selection, chunk_spec, runtime_configuration)
170+
for store_path, selection, chunk_spec in batch_info
171+
],
172+
self.decode_partial,
173+
runtime_configuration.concurrency,
174+
)
175+
90176

91177
class ArrayBytesCodecPartialEncodeMixin:
92178
@abstractmethod
@@ -100,17 +186,45 @@ async def encode_partial(
100186
) -> None:
101187
pass
102188

189+
async def encode_partial_batched(
190+
self,
191+
batch_info: Iterable[Tuple[StorePath, np.ndarray, SliceSelection, ArraySpec]],
192+
runtime_configuration: RuntimeConfiguration,
193+
) -> None:
194+
await concurrent_map(
195+
[
196+
(store_path, chunk_array, selection, chunk_spec, runtime_configuration)
197+
for store_path, chunk_array, selection, chunk_spec in batch_info
198+
],
199+
self.encode_partial,
200+
runtime_configuration.concurrency,
201+
)
202+
103203

104204
class BytesBytesCodec(Codec):
105205
@abstractmethod
106206
async def decode(
107207
self,
108-
chunk_array: BytesLike,
208+
chunk_bytes: BytesLike,
109209
chunk_spec: ArraySpec,
110210
runtime_configuration: RuntimeConfiguration,
111211
) -> BytesLike:
112212
pass
113213

214+
async def decode_batch(
215+
self,
216+
chunk_bytes_and_specs: Iterable[Tuple[BytesLike, ArraySpec]],
217+
runtime_configuration: RuntimeConfiguration,
218+
) -> Iterable[BytesLike]:
219+
return await concurrent_map(
220+
[
221+
(chunk_bytes, chunk_spec, runtime_configuration)
222+
for chunk_bytes, chunk_spec in chunk_bytes_and_specs
223+
],
224+
noop_for_none(self.decode),
225+
runtime_configuration.concurrency,
226+
)
227+
114228
@abstractmethod
115229
async def encode(
116230
self,
@@ -119,3 +233,17 @@ async def encode(
119233
runtime_configuration: RuntimeConfiguration,
120234
) -> Optional[BytesLike]:
121235
pass
236+
237+
async def encode_batch(
238+
self,
239+
chunk_bytes_and_specs: Iterable[Tuple[Optional[BytesLike], ArraySpec]],
240+
runtime_configuration: RuntimeConfiguration,
241+
) -> Iterable[Optional[BytesLike]]:
242+
return await concurrent_map(
243+
[
244+
(chunk_bytes, chunk_spec, runtime_configuration)
245+
for chunk_bytes, chunk_spec in chunk_bytes_and_specs
246+
],
247+
noop_for_none(self.encode),
248+
runtime_configuration.concurrency,
249+
)

0 commit comments

Comments
 (0)