Skip to content

Commit

Permalink
[V3] Expand store tests (zarr-developers#1900)
Browse files Browse the repository at this point in the history
* Fill in some test methods with NotImplementedError to force implementations to implement them; make StoreTests generic w.r.t. the store class being tested; update store.get abc to match actual type signature

* remove auto_mkdir from LocalStore; add set and get methods to StoreTests class

* fix: use from_bytes method on buffer

* fix: use Buffer instead of bytes for store tests

* docstrings, add some Nones to test_get_partial_values; normalize function signatures
  • Loading branch information
d-v-b authored May 28, 2024
1 parent b1f4c50 commit fc7fa4f
Show file tree
Hide file tree
Showing 7 changed files with 179 additions and 818 deletions.
4 changes: 2 additions & 2 deletions src/zarr/abc/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
class Store(ABC):
@abstractmethod
async def get(
self, key: str, byte_range: tuple[int, int | None] | None = None
self, key: str, byte_range: tuple[int | None, int | None] | None = None
) -> Buffer | None:
"""Retrieve the value associated with a given key.
Expand All @@ -26,7 +26,7 @@ async def get(

@abstractmethod
async def get_partial_values(
self, key_ranges: list[tuple[str, tuple[int, int]]]
self, key_ranges: list[tuple[str, tuple[int | None, int | None]]]
) -> list[Buffer | None]:
"""Retrieve possibly partial values from given key_ranges.
Expand Down
24 changes: 24 additions & 0 deletions src/zarr/store/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,27 @@ def make_store_path(store_like: StoreLike) -> StorePath:
elif isinstance(store_like, str):
return StorePath(LocalStore(Path(store_like)))
raise TypeError


def _normalize_interval_index(
data: Buffer, interval: None | tuple[int | None, int | None]
) -> tuple[int, int]:
"""
Convert an implicit interval into an explicit start and length
"""
if interval is None:
start = 0
length = len(data)
else:
maybe_start, maybe_len = interval
if maybe_start is None:
start = 0
else:
start = maybe_start

if maybe_len is None:
length = len(data) - start
else:
length = maybe_len

return (start, length)
16 changes: 6 additions & 10 deletions src/zarr/store/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from zarr.common import concurrent_map, to_thread


def _get(path: Path, byte_range: tuple[int, int | None] | None) -> Buffer:
def _get(path: Path, byte_range: tuple[int | None, int | None] | None) -> Buffer:
"""
Fetch a contiguous region of bytes from a file.
Expand Down Expand Up @@ -51,10 +51,8 @@ def _put(
path: Path,
value: Buffer,
start: int | None = None,
auto_mkdir: bool = True,
) -> int | None:
if auto_mkdir:
path.parent.mkdir(parents=True, exist_ok=True)
path.parent.mkdir(parents=True, exist_ok=True)
if start is not None:
with path.open("r+b") as f:
f.seek(start)
Expand All @@ -70,15 +68,13 @@ class LocalStore(Store):
supports_listing: bool = True

root: Path
auto_mkdir: bool

def __init__(self, root: Path | str, auto_mkdir: bool = True):
def __init__(self, root: Path | str):
if isinstance(root, str):
root = Path(root)
assert isinstance(root, Path)

self.root = root
self.auto_mkdir = auto_mkdir

def __str__(self) -> str:
return f"file://{self.root}"
Expand All @@ -90,7 +86,7 @@ def __eq__(self, other: object) -> bool:
return isinstance(other, type(self)) and self.root == other.root

async def get(
self, key: str, byte_range: tuple[int, int | None] | None = None
self, key: str, byte_range: tuple[int | None, int | None] | None = None
) -> Buffer | None:
assert isinstance(key, str)
path = self.root / key
Expand All @@ -101,7 +97,7 @@ async def get(
return None

async def get_partial_values(
self, key_ranges: list[tuple[str, tuple[int, int]]]
self, key_ranges: list[tuple[str, tuple[int | None, int | None]]]
) -> list[Buffer | None]:
"""
Read byte ranges from multiple keys.
Expand All @@ -128,7 +124,7 @@ async def set(self, key: str, value: Buffer) -> None:
if not isinstance(value, Buffer):
raise TypeError("LocalStore.set(): `value` must a Buffer instance")
path = self.root / key
await to_thread(_put, path, value, auto_mkdir=self.auto_mkdir)
await to_thread(_put, path, value)

async def set_partial_values(self, key_start_values: list[tuple[str, int, bytes]]) -> None:
args = []
Expand Down
10 changes: 5 additions & 5 deletions src/zarr/store/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from zarr.abc.store import Store
from zarr.buffer import Buffer
from zarr.common import concurrent_map
from zarr.store.core import _normalize_interval_index


# TODO: this store could easily be extended to wrap any MutableMapping store from v2
Expand All @@ -26,19 +27,18 @@ def __repr__(self) -> str:
return f"MemoryStore({str(self)!r})"

async def get(
self, key: str, byte_range: tuple[int, int | None] | None = None
self, key: str, byte_range: tuple[int | None, int | None] | None = None
) -> Buffer | None:
assert isinstance(key, str)
try:
value = self._store_dict[key]
if byte_range is not None:
value = value[byte_range[0] : byte_range[1]]
return value
start, length = _normalize_interval_index(value, byte_range)
return value[start : start + length]
except KeyError:
return None

async def get_partial_values(
self, key_ranges: list[tuple[str, tuple[int, int]]]
self, key_ranges: list[tuple[str, tuple[int | None, int | None]]]
) -> list[Buffer | None]:
vals = await concurrent_map(key_ranges, self.get, limit=None)
return vals
Expand Down
2 changes: 1 addition & 1 deletion src/zarr/store/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def _make_fs(self) -> tuple[AsyncFileSystem, str]:
return fs, root

async def get(
self, key: str, byte_range: tuple[int, int | None] | None = None
self, key: str, byte_range: tuple[int | None, int | None] | None = None
) -> Buffer | None:
assert isinstance(key, str)
fs, root = self._make_fs()
Expand Down
125 changes: 97 additions & 28 deletions src/zarr/testing/store.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,130 @@
from typing import Generic, TypeVar

import pytest

from zarr.abc.store import Store
from zarr.buffer import Buffer
from zarr.store.core import _normalize_interval_index
from zarr.testing.utils import assert_bytes_equal

S = TypeVar("S", bound=Store)


class StoreTests(Generic[S]):
store_cls: type[S]

class StoreTests:
store_cls: type[Store]
def set(self, store: S, key: str, value: Buffer) -> None:
"""
Insert a value into a storage backend, with a specific key.
This should not not use any store methods. Bypassing the store methods allows them to be
tested.
"""
raise NotImplementedError

def get(self, store: S, key: str) -> Buffer:
"""
Retrieve a value from a storage backend, by key.
This should not not use any store methods. Bypassing the store methods allows them to be
tested.
"""

raise NotImplementedError

@pytest.fixture(scope="function")
def store(self) -> Store:
return self.store_cls()

def test_store_type(self, store: Store) -> None:
def test_store_type(self, store: S) -> None:
assert isinstance(store, Store)
assert isinstance(store, self.store_cls)

def test_store_repr(self, store: Store) -> None:
assert repr(store)
def test_store_repr(self, store: S) -> None:
raise NotImplementedError

def test_store_supports_writes(self, store: S) -> None:
raise NotImplementedError

def test_store_capabilities(self, store: Store) -> None:
assert store.supports_writes
assert store.supports_partial_writes
assert store.supports_listing
def test_store_supports_partial_writes(self, store: S) -> None:
raise NotImplementedError

def test_store_supports_listing(self, store: S) -> None:
raise NotImplementedError

@pytest.mark.parametrize("key", ["c/0", "foo/c/0.0", "foo/0/0"])
@pytest.mark.parametrize("data", [b"\x01\x02\x03\x04", b""])
async def test_set_get_bytes_roundtrip(self, store: Store, key: str, data: bytes) -> None:
await store.set(key, Buffer.from_bytes(data))
assert_bytes_equal(await store.get(key), data)

@pytest.mark.parametrize("key", ["foo/c/0"])
@pytest.mark.parametrize("byte_range", (None, (0, None), (1, None), (1, 2), (None, 1)))
async def test_get(
self, store: S, key: str, data: bytes, byte_range: None | tuple[int | None, int | None]
) -> None:
"""
Ensure that data can be read from the store using the store.get method.
"""
data_buf = Buffer.from_bytes(data)
self.set(store, key, data_buf)
observed = await store.get(key, byte_range=byte_range)
start, length = _normalize_interval_index(data_buf, interval=byte_range)
expected = data_buf[start : start + length]
assert_bytes_equal(observed, expected)

@pytest.mark.parametrize("key", ["zarr.json", "c/0", "foo/c/0.0", "foo/0/0"])
@pytest.mark.parametrize("data", [b"\x01\x02\x03\x04", b""])
async def test_get_partial_values(self, store: Store, key: str, data: bytes) -> None:
async def test_set(self, store: S, key: str, data: bytes) -> None:
"""
Ensure that data can be written to the store using the store.set method.
"""
data_buf = Buffer.from_bytes(data)
await store.set(key, data_buf)
observed = self.get(store, key)
assert_bytes_equal(observed, data_buf)

@pytest.mark.parametrize(
"key_ranges",
(
[],
[("zarr.json", (0, 1))],
[("c/0", (0, 1)), ("zarr.json", (0, None))],
[("c/0/0", (0, 1)), ("c/0/1", (None, 2)), ("c/0/2", (0, 3))],
),
)
async def test_get_partial_values(
self, store: S, key_ranges: list[tuple[str, tuple[int | None, int | None]]]
) -> None:
# put all of the data
await store.set(key, Buffer.from_bytes(data))
for key, _ in key_ranges:
self.set(store, key, Buffer.from_bytes(bytes(key, encoding="utf-8")))

# read back just part of it
vals = await store.get_partial_values([(key, (0, 2))])
assert_bytes_equal(vals[0], data[0:2])
observed_maybe = await store.get_partial_values(key_ranges=key_ranges)

observed: list[Buffer] = []
expected: list[Buffer] = []

for obs in observed_maybe:
assert obs is not None
observed.append(obs)

for idx in range(len(observed)):
key, byte_range = key_ranges[idx]
result = await store.get(key, byte_range=byte_range)
assert result is not None
expected.append(result)

# read back multiple parts of it at once
vals = await store.get_partial_values([(key, (0, 2)), (key, (2, 4))])
assert_bytes_equal(vals[0], data[0:2])
assert_bytes_equal(vals[1], data[2:4])
assert all(
obs.to_bytes() == exp.to_bytes() for obs, exp in zip(observed, expected, strict=True)
)

async def test_exists(self, store: Store) -> None:
async def test_exists(self, store: S) -> None:
assert not await store.exists("foo")
await store.set("foo/zarr.json", Buffer.from_bytes(b"bar"))
assert await store.exists("foo/zarr.json")

async def test_delete(self, store: Store) -> None:
async def test_delete(self, store: S) -> None:
await store.set("foo/zarr.json", Buffer.from_bytes(b"bar"))
assert await store.exists("foo/zarr.json")
await store.delete("foo/zarr.json")
assert not await store.exists("foo/zarr.json")

async def test_list(self, store: Store) -> None:
async def test_list(self, store: S) -> None:
assert [k async for k in store.list()] == []
await store.set("foo/zarr.json", Buffer.from_bytes(b"bar"))
keys = [k async for k in store.list()]
Expand All @@ -69,11 +138,11 @@ async def test_list(self, store: Store) -> None:
f"foo/c/{i}", Buffer.from_bytes(i.to_bytes(length=3, byteorder="little"))
)

async def test_list_prefix(self, store: Store) -> None:
async def test_list_prefix(self, store: S) -> None:
# TODO: we currently don't use list_prefix anywhere
pass
raise NotImplementedError

async def test_list_dir(self, store: Store) -> None:
async def test_list_dir(self, store: S) -> None:
assert [k async for k in store.list_dir("")] == []
assert [k async for k in store.list_dir("foo")] == []
await store.set("foo/zarr.json", Buffer.from_bytes(b"bar"))
Expand Down
Loading

0 comments on commit fc7fa4f

Please sign in to comment.