Skip to content

Commit b825b4f

Browse files
committed
split codecs in 3 groups in CodecPipeline
1 parent c8ec169 commit b825b4f

File tree

3 files changed

+147
-70
lines changed

3 files changed

+147
-70
lines changed

zarr/v3/array.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from attr import evolve, frozen
1919

2020
from zarr.v3.abc.array import SynchronousArray, AsynchronousArray
21-
from zarr.v3.abc.codec import ArrayBytesCodecPartialDecodeMixin, ArrayBytesCodecPartialEncodeMixin
2221

2322
# from zarr.v3.array_v2 import ArrayV2
2423
from zarr.v3.codecs import CodecMetadata, CodecPipeline, bytes_codec
@@ -120,7 +119,7 @@ async def create(
120119
metadata=metadata,
121120
store_path=store_path,
122121
runtime_configuration=runtime_configuration,
123-
codec_pipeline=CodecPipeline(
122+
codec_pipeline=CodecPipeline.create(
124123
[
125124
get_codec_from_metadata(codec).evolve(ndim=len(shape), data_type=data_type)
126125
for codec in codecs
@@ -149,7 +148,7 @@ def from_json(
149148
metadata=metadata,
150149
store_path=store_path,
151150
runtime_configuration=runtime_configuration,
152-
codec_pipeline=CodecPipeline(codecs),
151+
codec_pipeline=CodecPipeline.create(codecs),
153152
)
154153
async_array._validate_metadata()
155154
return async_array
@@ -263,10 +262,8 @@ async def _read_chunk(
263262
chunk_key = chunk_key_encoding.encode_chunk_key(chunk_coords)
264263
store_path = self.store_path / chunk_key
265264

266-
if len(self.codec_pipeline.codecs) == 1 and isinstance(
267-
self.codec_pipeline.codecs[0], ArrayBytesCodecPartialDecodeMixin
268-
):
269-
chunk_array = await self.codec_pipeline.codecs[0].decode_partial(
265+
if self.codec_pipeline.supports_partial_decode:
266+
chunk_array = await self.codec_pipeline.decode_partial(
270267
store_path, chunk_selection, chunk_metadata, self.runtime_configuration
271268
)
272269
if chunk_array is not None:
@@ -346,12 +343,9 @@ async def _write_chunk(
346343
chunk_array = value[out_selection]
347344
await self._write_chunk_to_store(store_path, chunk_array, chunk_metadata)
348345

349-
elif len(self.codec_pipeline.codecs) == 1 and isinstance(
350-
self.codec_pipeline.codecs[0], ArrayBytesCodecPartialEncodeMixin
351-
):
352-
codec_with_partial_encode = self.codec_pipeline.codecs[0]
346+
elif self.codec_pipeline.supports_partial_encode:
353347
# print("encode_partial", chunk_coords, chunk_selection, repr(self))
354-
await codec_with_partial_encode.encode_partial(
348+
await self.codec_pipeline.encode_partial(
355349
store_path,
356350
value[out_selection],
357351
chunk_selection,

zarr/v3/codecs/__init__.py

Lines changed: 137 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,17 @@
1515

1616
import numpy as np
1717

18-
from zarr.v3.abc.codec import Codec, ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec
19-
from zarr.v3.common import BytesLike
18+
from zarr.v3.abc.codec import (
19+
ArrayBytesCodecPartialDecodeMixin,
20+
ArrayBytesCodecPartialEncodeMixin,
21+
Codec,
22+
ArrayArrayCodec,
23+
ArrayBytesCodec,
24+
BytesBytesCodec,
25+
)
26+
from zarr.v3.common import BytesLike, SliceSelection
2027
from zarr.v3.metadata import CodecMetadata, ShardingCodecIndexLocation, RuntimeConfiguration
28+
from zarr.v3.store import StorePath
2129

2230
if TYPE_CHECKING:
2331
from zarr.v3.metadata import ArrayMetadata, ArraySpec
@@ -33,26 +41,28 @@
3341
def _find_array_bytes_codec(
3442
codecs: Iterable[Tuple[Codec, ArraySpec]]
3543
) -> Tuple[ArrayBytesCodec, ArraySpec]:
36-
for codec, chunk_metadata in codecs:
44+
for codec, array_spec in codecs:
3745
if isinstance(codec, ArrayBytesCodec):
38-
return (codec, chunk_metadata)
46+
return (codec, array_spec)
3947
raise KeyError
4048

4149

4250
@frozen
4351
class CodecPipeline:
44-
codecs: List[Codec]
52+
array_array_codecs: List[ArrayArrayCodec]
53+
array_bytes_codec: ArrayBytesCodec
54+
bytes_bytes_codecs: List[BytesBytesCodec]
4555

46-
def validate(self, array_metadata: ArrayMetadata) -> None:
56+
@classmethod
57+
def create(cls, codecs: List[Codec]) -> CodecPipeline:
4758
from zarr.v3.codecs.sharding import ShardingCodec
4859

4960
assert any(
50-
isinstance(codec, ArrayBytesCodec) for codec in self.codecs
61+
isinstance(codec, ArrayBytesCodec) for codec in codecs
5162
), "Exactly one array-to-bytes codec is required."
5263

5364
prev_codec: Optional[Codec] = None
54-
for codec in self.codecs:
55-
codec.validate(array_metadata)
65+
for codec in codecs:
5666
if prev_codec is not None:
5767
assert not isinstance(codec, ArrayBytesCodec) or not isinstance(
5868
prev_codec, ArrayBytesCodec
@@ -81,84 +91,155 @@ def validate(self, array_metadata: ArrayMetadata) -> None:
8191
)
8292
prev_codec = codec
8393

84-
if any(isinstance(codec, ShardingCodec) for codec in self.codecs) and len(self.codecs) > 1:
94+
if any(isinstance(codec, ShardingCodec) for codec in codecs) and len(codecs) > 1:
8595
warn(
8696
"Combining a `sharding_indexed` codec disables partial reads and "
8797
+ "writes, which may lead to inefficient performance."
8898
)
8999

100+
return CodecPipeline(
101+
array_array_codecs=[codec for codec in codecs if isinstance(codec, ArrayArrayCodec)],
102+
array_bytes_codec=[codec for codec in codecs if isinstance(codec, ArrayBytesCodec)][0],
103+
bytes_bytes_codecs=[codec for codec in codecs if isinstance(codec, BytesBytesCodec)],
104+
)
105+
106+
@property
107+
def supports_partial_decode(self) -> bool:
108+
return (len(self.array_array_codecs) + len(self.bytes_bytes_codecs)) == 0 and isinstance(
109+
self.array_bytes_codec, ArrayBytesCodecPartialDecodeMixin
110+
)
111+
112+
@property
113+
def supports_partial_encode(self) -> bool:
114+
return (len(self.array_array_codecs) + len(self.bytes_bytes_codecs)) == 0 and isinstance(
115+
self.array_bytes_codec, ArrayBytesCodecPartialEncodeMixin
116+
)
117+
118+
def all_codecs(self) -> Iterator[Codec]:
119+
for aa_codec in self.array_array_codecs:
120+
yield aa_codec
121+
122+
yield self.array_bytes_codec
123+
124+
for bb_codec in self.bytes_bytes_codecs:
125+
yield bb_codec
126+
127+
def validate(self, array_metadata: ArrayMetadata) -> None:
128+
for codec in self.all_codecs():
129+
codec.validate(array_metadata)
130+
90131
def _codecs_with_resolved_metadata(
91-
self, chunk_metadata: ArraySpec
92-
) -> Iterator[Tuple[Codec, ArraySpec]]:
93-
for codec in self.codecs:
94-
yield (codec, chunk_metadata)
95-
chunk_metadata = codec.resolve_metadata(chunk_metadata)
132+
self, array_spec: ArraySpec
133+
) -> Tuple[
134+
List[Tuple[ArrayArrayCodec, ArraySpec]],
135+
Tuple[ArrayBytesCodec, ArraySpec],
136+
List[Tuple[BytesBytesCodec, ArraySpec]],
137+
]:
138+
aa_codecs_with_spec: List[Tuple[ArrayArrayCodec, ArraySpec]] = []
139+
for aa_codec in self.array_array_codecs:
140+
aa_codecs_with_spec.append((aa_codec, array_spec))
141+
array_spec = aa_codec.resolve_metadata(array_spec)
142+
143+
ab_codec_with_spec = (self.array_bytes_codec, array_spec)
144+
array_spec = self.array_bytes_codec.resolve_metadata(array_spec)
145+
146+
bb_codecs_with_spec: List[Tuple[BytesBytesCodec, ArraySpec]] = []
147+
for bb_codec in self.bytes_bytes_codecs:
148+
bb_codecs_with_spec.append((bb_codec, array_spec))
149+
array_spec = bb_codec.resolve_metadata(array_spec)
150+
151+
return (aa_codecs_with_spec, ab_codec_with_spec, bb_codecs_with_spec)
96152

97153
async def decode(
98154
self,
99155
chunk_bytes: BytesLike,
100-
chunk_metadata: ArraySpec,
156+
array_spec: ArraySpec,
101157
runtime_configuration: RuntimeConfiguration,
102158
) -> np.ndarray:
103-
codecs = list(self._codecs_with_resolved_metadata(chunk_metadata))[::-1]
159+
(
160+
aa_codecs_with_spec,
161+
ab_codec_with_spec,
162+
bb_codecs_with_spec,
163+
) = self._codecs_with_resolved_metadata(array_spec)
104164

105-
for bb_codec, chunk_metadata in codecs:
106-
if isinstance(bb_codec, BytesBytesCodec):
107-
chunk_bytes = await bb_codec.decode(
108-
chunk_bytes, chunk_metadata, runtime_configuration
109-
)
165+
for bb_codec, array_spec in bb_codecs_with_spec[::-1]:
166+
chunk_bytes = await bb_codec.decode(chunk_bytes, array_spec, runtime_configuration)
110167

111-
ab_codec, chunk_metadata = _find_array_bytes_codec(codecs)
112-
chunk_array = await ab_codec.decode(chunk_bytes, chunk_metadata, runtime_configuration)
168+
ab_codec, array_spec = ab_codec_with_spec
169+
chunk_array = await ab_codec.decode(chunk_bytes, array_spec, runtime_configuration)
113170

114-
for aa_codec, chunk_metadata in codecs:
115-
if isinstance(aa_codec, ArrayArrayCodec):
116-
chunk_array = await aa_codec.decode(
117-
chunk_array, chunk_metadata, runtime_configuration
118-
)
171+
for aa_codec, array_spec in aa_codecs_with_spec[::-1]:
172+
chunk_array = await aa_codec.decode(chunk_array, array_spec, runtime_configuration)
119173

120174
return chunk_array
121175

176+
async def decode_partial(
177+
self,
178+
store_path: StorePath,
179+
selection: SliceSelection,
180+
chunk_metadata: ArraySpec,
181+
runtime_configuration: RuntimeConfiguration,
182+
) -> Optional[np.ndarray]:
183+
assert self.supports_partial_decode
184+
assert isinstance(self.array_bytes_codec, ArrayBytesCodecPartialDecodeMixin)
185+
return await self.array_bytes_codec.decode_partial(
186+
store_path, selection, chunk_metadata, runtime_configuration
187+
)
188+
122189
async def encode(
123190
self,
124191
chunk_array: np.ndarray,
125-
chunk_metadata: ArraySpec,
192+
array_spec: ArraySpec,
126193
runtime_configuration: RuntimeConfiguration,
127194
) -> Optional[BytesLike]:
128-
codecs = list(self._codecs_with_resolved_metadata(chunk_metadata))
129-
130-
for aa_codec, chunk_metadata in codecs:
131-
if isinstance(aa_codec, ArrayArrayCodec):
132-
chunk_array_maybe = await aa_codec.encode(
133-
chunk_array, chunk_metadata, runtime_configuration
134-
)
135-
if chunk_array_maybe is None:
136-
return None
137-
chunk_array = chunk_array_maybe
195+
(
196+
aa_codecs_with_spec,
197+
ab_codec_with_spec,
198+
bb_codecs_with_spec,
199+
) = self._codecs_with_resolved_metadata(array_spec)
200+
201+
for aa_codec, array_spec in aa_codecs_with_spec:
202+
chunk_array_maybe = await aa_codec.encode(
203+
chunk_array, array_spec, runtime_configuration
204+
)
205+
if chunk_array_maybe is None:
206+
return None
207+
chunk_array = chunk_array_maybe
138208

139-
ab_codec, chunk_metadata = _find_array_bytes_codec(codecs)
140-
chunk_bytes_maybe = await ab_codec.encode(
141-
chunk_array, chunk_metadata, runtime_configuration
142-
)
209+
ab_codec, array_spec = ab_codec_with_spec
210+
chunk_bytes_maybe = await ab_codec.encode(chunk_array, array_spec, runtime_configuration)
143211
if chunk_bytes_maybe is None:
144212
return None
145213
chunk_bytes = chunk_bytes_maybe
146214

147-
for bb_codec, chunk_metadata in codecs:
148-
if isinstance(bb_codec, BytesBytesCodec):
149-
chunk_bytes_maybe = await bb_codec.encode(
150-
chunk_bytes, chunk_metadata, runtime_configuration
151-
)
152-
if chunk_bytes_maybe is None:
153-
return None
154-
chunk_bytes = chunk_bytes_maybe
215+
for bb_codec, array_spec in bb_codecs_with_spec:
216+
chunk_bytes_maybe = await bb_codec.encode(
217+
chunk_bytes, array_spec, runtime_configuration
218+
)
219+
if chunk_bytes_maybe is None:
220+
return None
221+
chunk_bytes = chunk_bytes_maybe
155222

156223
return chunk_bytes
157224

158-
def compute_encoded_size(self, byte_length: int, chunk_metadata: ArraySpec) -> int:
159-
for codec in self.codecs:
160-
byte_length = codec.compute_encoded_size(byte_length, chunk_metadata)
161-
chunk_metadata = codec.resolve_metadata(chunk_metadata)
225+
async def encode_partial(
226+
self,
227+
store_path: StorePath,
228+
chunk_array: np.ndarray,
229+
selection: SliceSelection,
230+
chunk_metadata: ArraySpec,
231+
runtime_configuration: RuntimeConfiguration,
232+
) -> None:
233+
assert self.supports_partial_encode
234+
assert isinstance(self.array_bytes_codec, ArrayBytesCodecPartialEncodeMixin)
235+
await self.array_bytes_codec.encode_partial(
236+
store_path, chunk_array, selection, chunk_metadata, runtime_configuration
237+
)
238+
239+
def compute_encoded_size(self, byte_length: int, array_spec: ArraySpec) -> int:
240+
for codec in self.all_codecs():
241+
byte_length = codec.compute_encoded_size(byte_length, array_spec)
242+
array_spec = codec.resolve_metadata(array_spec)
162243
return byte_length
163244

164245

zarr/v3/codecs/sharding.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -624,11 +624,13 @@ def _get_chunks_per_shard(self, shard_metadata: ArraySpec) -> ChunkCoords:
624624

625625
@cached_property
626626
def _index_codec_pipeline(self) -> CodecPipeline:
627-
return CodecPipeline([get_codec_from_metadata(c) for c in self.configuration.index_codecs])
627+
return CodecPipeline.create(
628+
[get_codec_from_metadata(c) for c in self.configuration.index_codecs]
629+
)
628630

629631
@cached_property
630632
def _codec_pipeline(self) -> CodecPipeline:
631-
return CodecPipeline([get_codec_from_metadata(c) for c in self.configuration.codecs])
633+
return CodecPipeline.create([get_codec_from_metadata(c) for c in self.configuration.codecs])
632634

633635
async def _load_shard_index_maybe(
634636
self, store_path: StorePath, chunks_per_shard: ChunkCoords

0 commit comments

Comments
 (0)