Skip to content

Commit fc7fa4f

Browse files
authored
[V3] Expand store tests (#1900)
* Fill in some test methods with NotImplementedError to force implementations to implement them; make StoreTests generic w.r.t. the store class being tested; update store.get abc to match actual type signature
* remove auto_mkdir from LocalStore; add set and get methods to StoreTests class
* fix: use from_bytes method on buffer
* fix: use Buffer instead of bytes for store tests
* docstrings, add some Nones to test_get_partial_values; normalize function signatures
1 parent b1f4c50 commit fc7fa4f

File tree

7 files changed

+179
-818
lines changed

7 files changed

+179
-818
lines changed

src/zarr/abc/store.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
class Store(ABC):
1010
@abstractmethod
1111
async def get(
12-
self, key: str, byte_range: tuple[int, int | None] | None = None
12+
self, key: str, byte_range: tuple[int | None, int | None] | None = None
1313
) -> Buffer | None:
1414
"""Retrieve the value associated with a given key.
1515
@@ -26,7 +26,7 @@ async def get(
2626

2727
@abstractmethod
2828
async def get_partial_values(
29-
self, key_ranges: list[tuple[str, tuple[int, int]]]
29+
self, key_ranges: list[tuple[str, tuple[int | None, int | None]]]
3030
) -> list[Buffer | None]:
3131
"""Retrieve possibly partial values from given key_ranges.
3232

src/zarr/store/core.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,27 @@ def make_store_path(store_like: StoreLike) -> StorePath:
6868
elif isinstance(store_like, str):
6969
return StorePath(LocalStore(Path(store_like)))
7070
raise TypeError
71+
72+
73+
def _normalize_interval_index(
74+
data: Buffer, interval: None | tuple[int | None, int | None]
75+
) -> tuple[int, int]:
76+
"""
77+
Convert an implicit interval into an explicit start and length
78+
"""
79+
if interval is None:
80+
start = 0
81+
length = len(data)
82+
else:
83+
maybe_start, maybe_len = interval
84+
if maybe_start is None:
85+
start = 0
86+
else:
87+
start = maybe_start
88+
89+
if maybe_len is None:
90+
length = len(data) - start
91+
else:
92+
length = maybe_len
93+
94+
return (start, length)

src/zarr/store/local.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from zarr.common import concurrent_map, to_thread
1111

1212

13-
def _get(path: Path, byte_range: tuple[int, int | None] | None) -> Buffer:
13+
def _get(path: Path, byte_range: tuple[int | None, int | None] | None) -> Buffer:
1414
"""
1515
Fetch a contiguous region of bytes from a file.
1616
@@ -51,10 +51,8 @@ def _put(
5151
path: Path,
5252
value: Buffer,
5353
start: int | None = None,
54-
auto_mkdir: bool = True,
5554
) -> int | None:
56-
if auto_mkdir:
57-
path.parent.mkdir(parents=True, exist_ok=True)
55+
path.parent.mkdir(parents=True, exist_ok=True)
5856
if start is not None:
5957
with path.open("r+b") as f:
6058
f.seek(start)
@@ -70,15 +68,13 @@ class LocalStore(Store):
7068
supports_listing: bool = True
7169

7270
root: Path
73-
auto_mkdir: bool
7471

75-
def __init__(self, root: Path | str, auto_mkdir: bool = True):
72+
def __init__(self, root: Path | str):
7673
if isinstance(root, str):
7774
root = Path(root)
7875
assert isinstance(root, Path)
7976

8077
self.root = root
81-
self.auto_mkdir = auto_mkdir
8278

8379
def __str__(self) -> str:
8480
return f"file://{self.root}"
@@ -90,7 +86,7 @@ def __eq__(self, other: object) -> bool:
9086
return isinstance(other, type(self)) and self.root == other.root
9187

9288
async def get(
93-
self, key: str, byte_range: tuple[int, int | None] | None = None
89+
self, key: str, byte_range: tuple[int | None, int | None] | None = None
9490
) -> Buffer | None:
9591
assert isinstance(key, str)
9692
path = self.root / key
@@ -101,7 +97,7 @@ async def get(
10197
return None
10298

10399
async def get_partial_values(
104-
self, key_ranges: list[tuple[str, tuple[int, int]]]
100+
self, key_ranges: list[tuple[str, tuple[int | None, int | None]]]
105101
) -> list[Buffer | None]:
106102
"""
107103
Read byte ranges from multiple keys.
@@ -128,7 +124,7 @@ async def set(self, key: str, value: Buffer) -> None:
128124
if not isinstance(value, Buffer):
129125
raise TypeError("LocalStore.set(): `value` must a Buffer instance")
130126
path = self.root / key
131-
await to_thread(_put, path, value, auto_mkdir=self.auto_mkdir)
127+
await to_thread(_put, path, value)
132128

133129
async def set_partial_values(self, key_start_values: list[tuple[str, int, bytes]]) -> None:
134130
args = []

src/zarr/store/memory.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from zarr.abc.store import Store
66
from zarr.buffer import Buffer
77
from zarr.common import concurrent_map
8+
from zarr.store.core import _normalize_interval_index
89

910

1011
# TODO: this store could easily be extended to wrap any MutableMapping store from v2
@@ -26,19 +27,18 @@ def __repr__(self) -> str:
2627
return f"MemoryStore({str(self)!r})"
2728

2829
async def get(
29-
self, key: str, byte_range: tuple[int, int | None] | None = None
30+
self, key: str, byte_range: tuple[int | None, int | None] | None = None
3031
) -> Buffer | None:
3132
assert isinstance(key, str)
3233
try:
3334
value = self._store_dict[key]
34-
if byte_range is not None:
35-
value = value[byte_range[0] : byte_range[1]]
36-
return value
35+
start, length = _normalize_interval_index(value, byte_range)
36+
return value[start : start + length]
3737
except KeyError:
3838
return None
3939

4040
async def get_partial_values(
41-
self, key_ranges: list[tuple[str, tuple[int, int]]]
41+
self, key_ranges: list[tuple[str, tuple[int | None, int | None]]]
4242
) -> list[Buffer | None]:
4343
vals = await concurrent_map(key_ranges, self.get, limit=None)
4444
return vals

src/zarr/store/remote.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def _make_fs(self) -> tuple[AsyncFileSystem, str]:
4949
return fs, root
5050

5151
async def get(
52-
self, key: str, byte_range: tuple[int, int | None] | None = None
52+
self, key: str, byte_range: tuple[int | None, int | None] | None = None
5353
) -> Buffer | None:
5454
assert isinstance(key, str)
5555
fs, root = self._make_fs()

src/zarr/testing/store.py

Lines changed: 97 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,130 @@
1+
from typing import Generic, TypeVar
2+
13
import pytest
24

35
from zarr.abc.store import Store
46
from zarr.buffer import Buffer
7+
from zarr.store.core import _normalize_interval_index
58
from zarr.testing.utils import assert_bytes_equal
69

10+
S = TypeVar("S", bound=Store)
11+
12+
13+
class StoreTests(Generic[S]):
14+
store_cls: type[S]
715

8-
class StoreTests:
9-
store_cls: type[Store]
16+
def set(self, store: S, key: str, value: Buffer) -> None:
17+
"""
18+
Insert a value into a storage backend, with a specific key.
19+
This should not use any store methods. Bypassing the store methods allows them to be
20+
tested.
21+
"""
22+
raise NotImplementedError
23+
24+
def get(self, store: S, key: str) -> Buffer:
25+
"""
26+
Retrieve a value from a storage backend, by key.
27+
This should not use any store methods. Bypassing the store methods allows them to be
28+
tested.
29+
"""
30+
31+
raise NotImplementedError
1032

1133
@pytest.fixture(scope="function")
1234
def store(self) -> Store:
1335
return self.store_cls()
1436

15-
def test_store_type(self, store: Store) -> None:
37+
def test_store_type(self, store: S) -> None:
1638
assert isinstance(store, Store)
1739
assert isinstance(store, self.store_cls)
1840

19-
def test_store_repr(self, store: Store) -> None:
20-
assert repr(store)
41+
def test_store_repr(self, store: S) -> None:
42+
raise NotImplementedError
43+
44+
def test_store_supports_writes(self, store: S) -> None:
45+
raise NotImplementedError
2146

22-
def test_store_capabilities(self, store: Store) -> None:
23-
assert store.supports_writes
24-
assert store.supports_partial_writes
25-
assert store.supports_listing
47+
def test_store_supports_partial_writes(self, store: S) -> None:
48+
raise NotImplementedError
49+
50+
def test_store_supports_listing(self, store: S) -> None:
51+
raise NotImplementedError
2652

2753
@pytest.mark.parametrize("key", ["c/0", "foo/c/0.0", "foo/0/0"])
2854
@pytest.mark.parametrize("data", [b"\x01\x02\x03\x04", b""])
29-
async def test_set_get_bytes_roundtrip(self, store: Store, key: str, data: bytes) -> None:
30-
await store.set(key, Buffer.from_bytes(data))
31-
assert_bytes_equal(await store.get(key), data)
32-
33-
@pytest.mark.parametrize("key", ["foo/c/0"])
55+
@pytest.mark.parametrize("byte_range", (None, (0, None), (1, None), (1, 2), (None, 1)))
56+
async def test_get(
57+
self, store: S, key: str, data: bytes, byte_range: None | tuple[int | None, int | None]
58+
) -> None:
59+
"""
60+
Ensure that data can be read from the store using the store.get method.
61+
"""
62+
data_buf = Buffer.from_bytes(data)
63+
self.set(store, key, data_buf)
64+
observed = await store.get(key, byte_range=byte_range)
65+
start, length = _normalize_interval_index(data_buf, interval=byte_range)
66+
expected = data_buf[start : start + length]
67+
assert_bytes_equal(observed, expected)
68+
69+
@pytest.mark.parametrize("key", ["zarr.json", "c/0", "foo/c/0.0", "foo/0/0"])
3470
@pytest.mark.parametrize("data", [b"\x01\x02\x03\x04", b""])
35-
async def test_get_partial_values(self, store: Store, key: str, data: bytes) -> None:
71+
async def test_set(self, store: S, key: str, data: bytes) -> None:
72+
"""
73+
Ensure that data can be written to the store using the store.set method.
74+
"""
75+
data_buf = Buffer.from_bytes(data)
76+
await store.set(key, data_buf)
77+
observed = self.get(store, key)
78+
assert_bytes_equal(observed, data_buf)
79+
80+
@pytest.mark.parametrize(
81+
"key_ranges",
82+
(
83+
[],
84+
[("zarr.json", (0, 1))],
85+
[("c/0", (0, 1)), ("zarr.json", (0, None))],
86+
[("c/0/0", (0, 1)), ("c/0/1", (None, 2)), ("c/0/2", (0, 3))],
87+
),
88+
)
89+
async def test_get_partial_values(
90+
self, store: S, key_ranges: list[tuple[str, tuple[int | None, int | None]]]
91+
) -> None:
3692
# put all of the data
37-
await store.set(key, Buffer.from_bytes(data))
93+
for key, _ in key_ranges:
94+
self.set(store, key, Buffer.from_bytes(bytes(key, encoding="utf-8")))
95+
3896
# read back just part of it
39-
vals = await store.get_partial_values([(key, (0, 2))])
40-
assert_bytes_equal(vals[0], data[0:2])
97+
observed_maybe = await store.get_partial_values(key_ranges=key_ranges)
98+
99+
observed: list[Buffer] = []
100+
expected: list[Buffer] = []
101+
102+
for obs in observed_maybe:
103+
assert obs is not None
104+
observed.append(obs)
105+
106+
for idx in range(len(observed)):
107+
key, byte_range = key_ranges[idx]
108+
result = await store.get(key, byte_range=byte_range)
109+
assert result is not None
110+
expected.append(result)
41111

42-
# read back multiple parts of it at once
43-
vals = await store.get_partial_values([(key, (0, 2)), (key, (2, 4))])
44-
assert_bytes_equal(vals[0], data[0:2])
45-
assert_bytes_equal(vals[1], data[2:4])
112+
assert all(
113+
obs.to_bytes() == exp.to_bytes() for obs, exp in zip(observed, expected, strict=True)
114+
)
46115

47-
async def test_exists(self, store: Store) -> None:
116+
async def test_exists(self, store: S) -> None:
48117
assert not await store.exists("foo")
49118
await store.set("foo/zarr.json", Buffer.from_bytes(b"bar"))
50119
assert await store.exists("foo/zarr.json")
51120

52-
async def test_delete(self, store: Store) -> None:
121+
async def test_delete(self, store: S) -> None:
53122
await store.set("foo/zarr.json", Buffer.from_bytes(b"bar"))
54123
assert await store.exists("foo/zarr.json")
55124
await store.delete("foo/zarr.json")
56125
assert not await store.exists("foo/zarr.json")
57126

58-
async def test_list(self, store: Store) -> None:
127+
async def test_list(self, store: S) -> None:
59128
assert [k async for k in store.list()] == []
60129
await store.set("foo/zarr.json", Buffer.from_bytes(b"bar"))
61130
keys = [k async for k in store.list()]
@@ -69,11 +138,11 @@ async def test_list(self, store: Store) -> None:
69138
f"foo/c/{i}", Buffer.from_bytes(i.to_bytes(length=3, byteorder="little"))
70139
)
71140

72-
async def test_list_prefix(self, store: Store) -> None:
141+
async def test_list_prefix(self, store: S) -> None:
73142
# TODO: we currently don't use list_prefix anywhere
74-
pass
143+
raise NotImplementedError
75144

76-
async def test_list_dir(self, store: Store) -> None:
145+
async def test_list_dir(self, store: S) -> None:
77146
assert [k async for k in store.list_dir("")] == []
78147
assert [k async for k in store.list_dir("foo")] == []
79148
await store.set("foo/zarr.json", Buffer.from_bytes(b"bar"))

0 commit comments

Comments
 (0)