Skip to content

Commit e84057a

Browse files
authored
make shardingcodec pickleable (#2011)
* use tmpdir for test
* type annotations
* refactor morton decode and remove destructuring in call to max
* parametrize sharding codec test by data shape
* refactor codec tests
* add test for pickling sharding codec, and make it pass
* Revert "use tmpdir for test" (this reverts commit 6ad2ca6)
* move fixtures into conftest.py
* Update tests/v3/test_codecs/test_endian.py
1 parent 22e3fc5 commit e84057a

File tree

12 files changed

+1189
-1076
lines changed

12 files changed

+1189
-1076
lines changed

src/zarr/codecs/sharding.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,22 @@ def __init__(
324324
object.__setattr__(self, "_get_index_chunk_spec", lru_cache()(self._get_index_chunk_spec))
325325
object.__setattr__(self, "_get_chunks_per_shard", lru_cache()(self._get_chunks_per_shard))
326326

327+
# TODO: annotate the return type with a TypedDict instead of dict[str, Any]
328+
def __getstate__(self) -> dict[str, Any]:
    """Return the state to pickle: the dictionary produced by ``to_dict()``.

    ``__setstate__`` rebuilds the instance from this same dictionary.
    """
    state = self.to_dict()
    return state
330+
331+
def __setstate__(self, state: dict[str, Any]) -> None:
    """Restore the instance from the dictionary returned by ``__getstate__``.

    Parameters
    ----------
    state
        Mapping with a ``"configuration"`` entry holding ``chunk_shape``,
        ``codecs``, ``index_codecs`` and ``index_location``.
    """
    config = state["configuration"]

    # Each field is parsed and assigned in turn. ``object.__setattr__`` is
    # used for the assignment, mirroring ``__init__``.
    field_parsers = (
        ("chunk_shape", parse_shapelike),
        ("codecs", parse_codecs),
        ("index_codecs", parse_codecs),
        ("index_location", parse_index_location),
    )
    for field_name, parse in field_parsers:
        object.__setattr__(self, field_name, parse(config[field_name]))

    # Use instance-local lru_cache to avoid memory leaks
    cached_methods = (
        "_get_chunk_spec",
        "_get_index_chunk_spec",
        "_get_chunks_per_shard",
    )
    for method_name in cached_methods:
        bound_method = getattr(self, method_name)
        object.__setattr__(self, method_name, lru_cache()(bound_method))
342+
327343
@classmethod
328344
def from_dict(cls, data: dict[str, JSON]) -> Self:
329345
_, configuration_parsed = parse_named_configuration(data, "sharding_indexed")

src/zarr/indexing.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1220,24 +1220,25 @@ def make_slice_selection(selection: Any) -> list[slice]:
12201220
return ls
12211221

12221222

1223-
def morton_order_iter(chunk_shape: ChunkCoords) -> Iterator[ChunkCoords]:
1224-
def decode_morton(z: int, chunk_shape: ChunkCoords) -> ChunkCoords:
1225-
# Inspired by compressed morton code as implemented in Neuroglancer
1226-
# https://github.com/google/neuroglancer/blob/master/src/neuroglancer/datasource/precomputed/volume.md#compressed-morton-code
1227-
bits = tuple(math.ceil(math.log2(c)) for c in chunk_shape)
1228-
max_coords_bits = max(*bits)
1229-
input_bit = 0
1230-
input_value = z
1231-
out = [0 for _ in range(len(chunk_shape))]
1232-
1233-
for coord_bit in range(max_coords_bits):
1234-
for dim in range(len(chunk_shape)):
1235-
if coord_bit < bits[dim]:
1236-
bit = (input_value >> input_bit) & 1
1237-
out[dim] |= bit << coord_bit
1238-
input_bit += 1
1239-
return tuple(out)
1223+
def decode_morton(z: int, chunk_shape: ChunkCoords) -> ChunkCoords:
    """Decode the compressed Morton code ``z`` into one coordinate per axis.

    The bits of ``z`` are consumed from least significant upwards and dealt
    out to the axes in round-robin order. An axis stops receiving bits once
    it has the number of bits needed to address its extent.

    Parameters
    ----------
    z
        The Morton code to decode.
    chunk_shape
        Extent of each axis. Every extent must be at least 1.

    Returns
    -------
    tuple of int
        One decoded coordinate per axis; ``()`` for an empty ``chunk_shape``.

    Raises
    ------
    ValueError
        If any extent in ``chunk_shape`` is smaller than 1.
    """
    # Inspired by compressed morton code as implemented in Neuroglancer
    # https://github.com/google/neuroglancer/blob/master/src/neuroglancer/datasource/precomputed/volume.md#compressed-morton-code
    if any(c < 1 for c in chunk_shape):
        raise ValueError(f"chunk_shape extents must be >= 1, got {chunk_shape!r}")

    # (c - 1).bit_length() equals ceil(log2(c)) for c >= 1, but is computed
    # with exact integer arithmetic, so it stays correct for extents that a
    # float cannot represent exactly.
    bits = tuple((c - 1).bit_length() for c in chunk_shape)
    # default=0 keeps a zero-dimensional shape from raising on an empty max().
    max_coords_bits = max(bits, default=0)
    input_bit = 0
    out = [0] * len(chunk_shape)

    for coord_bit in range(max_coords_bits):
        for dim, dim_bits in enumerate(bits):
            if coord_bit < dim_bits:
                out[dim] |= ((z >> input_bit) & 1) << coord_bit
                input_bit += 1
    return tuple(out)
1239+
12401240

1241+
def morton_order_iter(chunk_shape: ChunkCoords) -> Iterator[ChunkCoords]:
    """Yield every coordinate inside ``chunk_shape`` in increasing Morton-code order.

    Codes are decoded with ``decode_morton`` starting from 0. When an extent
    is not a power of two, some codes decode to coordinates outside
    ``chunk_shape``; those are skipped. Simply taking the first
    ``product(chunk_shape)`` codes would emit such out-of-bounds coordinates
    and omit valid ones (for shape ``(3, 3)``, code 5 decodes to ``(3, 0)``).

    Parameters
    ----------
    chunk_shape
        Extent of each axis.

    Yields
    ------
    tuple of int
        Each in-bounds coordinate exactly once; ``product(chunk_shape)``
        coordinates in total.
    """
    remaining = product(chunk_shape)
    code = 0
    while remaining > 0:
        coords = decode_morton(code, chunk_shape)
        code += 1
        # Every in-bounds coordinate has exactly one code, so filtering the
        # code sequence visits each of them once and then stops.
        if all(c < extent for c, extent in zip(coords, chunk_shape)):
            remaining -= 1
            yield coords
12431244

tests/v3/conftest.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,18 @@
44
from types import ModuleType
55
from typing import TYPE_CHECKING
66

7-
from zarr.common import ZarrFormat
7+
from _pytest.compat import LEGACY_PATH
8+
9+
from zarr.abc.store import Store
10+
from zarr.common import ChunkCoords, MemoryOrder, ZarrFormat
811
from zarr.group import AsyncGroup
912

1013
if TYPE_CHECKING:
1114
from typing import Any, Literal
1215
import pathlib
1316
from dataclasses import dataclass, field
1417

18+
import numpy as np
1519
import pytest
1620

1721
from zarr.store import LocalStore, MemoryStore, StorePath
@@ -26,40 +30,40 @@ def parse_store(
2630
if store == "memory":
2731
return MemoryStore(mode="w")
2832
if store == "remote":
29-
return RemoteStore(mode="w")
33+
return RemoteStore(url=path, mode="w")
3034
raise AssertionError
3135

3236

3337
@pytest.fixture(params=[str, pathlib.Path])
34-
def path_type(request):
38+
def path_type(request: pytest.FixtureRequest) -> Any:
3539
return request.param
3640

3741

3842
# todo: harmonize this with local_store fixture
3943
@pytest.fixture
40-
def store_path(tmpdir):
44+
def store_path(tmpdir: LEGACY_PATH) -> StorePath:
4145
store = LocalStore(str(tmpdir), mode="w")
4246
p = StorePath(store)
4347
return p
4448

4549

4650
@pytest.fixture(scope="function")
47-
def local_store(tmpdir):
51+
def local_store(tmpdir: LEGACY_PATH) -> LocalStore:
4852
return LocalStore(str(tmpdir), mode="w")
4953

5054

5155
@pytest.fixture(scope="function")
52-
def remote_store():
53-
return RemoteStore(mode="w")
56+
def remote_store(url: str) -> RemoteStore:
57+
return RemoteStore(url, mode="w")
5458

5559

5660
@pytest.fixture(scope="function")
57-
def memory_store():
61+
def memory_store() -> MemoryStore:
5862
return MemoryStore(mode="w")
5963

6064

6165
@pytest.fixture(scope="function")
62-
def store(request: str, tmpdir):
66+
def store(request: pytest.FixtureRequest, tmpdir: LEGACY_PATH) -> Store:
6367
param = request.param
6468
return parse_store(param, str(tmpdir))
6569

@@ -72,7 +76,7 @@ class AsyncGroupRequest:
7276

7377

7478
@pytest.fixture(scope="function")
75-
async def async_group(request: pytest.FixtureRequest, tmpdir) -> AsyncGroup:
79+
async def async_group(request: pytest.FixtureRequest, tmpdir: LEGACY_PATH) -> AsyncGroup:
7680
param: AsyncGroupRequest = request.param
7781

7882
store = parse_store(param.store, str(tmpdir))
@@ -90,3 +94,20 @@ def xp(request: pytest.FixtureRequest) -> Iterator[ModuleType]:
9094
"""Fixture to parametrize over numpy-like libraries"""
9195

9296
yield pytest.importorskip(request.param)
97+
98+
99+
@dataclass
class ArrayRequest:
    """Parameters describing the array that ``array_fixture`` should build."""

    # Shape of the requested array.
    shape: ChunkCoords
    # Data type the array is cast to.
    dtype: str
    # Memory order passed to ``reshape`` when shaping the array.
    order: MemoryOrder
104+
105+
106+
@pytest.fixture
def array_fixture(request: pytest.FixtureRequest) -> np.ndarray:
    """Build the array described by the ``ArrayRequest`` in ``request.param``.

    The array holds consecutive integers starting at 0, one per element,
    reshaped to the requested shape and memory order, then cast to the
    requested dtype.
    """
    array_request: ArrayRequest = request.param
    flat = np.arange(np.prod(array_request.shape))
    shaped = flat.reshape(array_request.shape, order=array_request.order)
    return shaped.astype(array_request.dtype)

0 commit comments

Comments
 (0)