Skip to content

Fix typing in store tests #3097

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,11 @@ module = [
"tests.test_store.test_fsspec",
"tests.test_store.test_memory",
"tests.test_codecs.test_codecs",
"tests.test_store.test_core",
"tests.test_store.test_logging",
"tests.test_store.test_object",
"tests.test_store.test_stateful",
"tests.test_store.test_wrapper",
]
strict = false

Expand All @@ -373,11 +378,6 @@ strict = false
[[tool.mypy.overrides]]
module = [
"tests.test_metadata.*",
"tests.test_store.test_core",
"tests.test_store.test_logging",
"tests.test_store.test_object",
"tests.test_store.test_stateful",
"tests.test_store.test_wrapper",
"tests.test_group",
"tests.test_indexing",
"tests.test_properties",
Expand Down
23 changes: 12 additions & 11 deletions src/zarr/storage/_obstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pickle
from collections import defaultdict
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, TypedDict
from typing import TYPE_CHECKING, Any, Generic, TypedDict, TypeVar

from zarr.abc.store import (
ByteRequest,
Expand All @@ -27,6 +27,7 @@
from zarr.core.buffer import Buffer, BufferPrototype
from zarr.core.common import BytesLike


__all__ = ["ObjectStore"]

_ALLOWED_EXCEPTIONS: tuple[type[Exception], ...] = (
Expand All @@ -35,8 +36,10 @@
NotADirectoryError,
)

T_Store = TypeVar("T_Store", bound="_UpstreamObjectStore")


class ObjectStore(Store):
class ObjectStore(Store, Generic[T_Store]):
"""
Store that uses obstore for fast read/write from AWS, GCP, Azure.

Expand All @@ -53,19 +56,17 @@
raise an issue with any comments/concerns about the store.
"""

store: _UpstreamObjectStore
store: T_Store
"""The underlying obstore instance."""

def __eq__(self, value: object) -> bool:
if not isinstance(value, ObjectStore):
return False

if not self.read_only == value.read_only:
return False

return self.store == value.store
return bool(

Check warning on line 63 in src/zarr/storage/_obstore.py

View check run for this annotation

Codecov / codecov/patch

src/zarr/storage/_obstore.py#L63

Added line #L63 was not covered by tests
isinstance(value, ObjectStore)
and self.read_only == value.read_only
and self.store == value.store
)

def __init__(self, store: _UpstreamObjectStore, *, read_only: bool = False) -> None:
def __init__(self, store: T_Store, *, read_only: bool = False) -> None:
if not store.__class__.__module__.startswith("obstore"):
raise TypeError(f"expected ObjectStore class, got {store!r}")
super().__init__(read_only=read_only)
Expand Down
10 changes: 9 additions & 1 deletion src/zarr/testing/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,15 @@ def test_store_supports_partial_writes(self, store: S) -> None: ...
def test_store_supports_listing(self, store: S) -> None: ...

@pytest.fixture
def open_kwargs(self, store_kwargs: dict[str, Any]) -> dict[str, Any]:
def open_kwargs(
self, store_kwargs: dict[str, Any], *args: Any, **kwargs: Any
) -> dict[str, Any]:
"""
Kwargs for opening a store.

By default uses the result of the store_kwargs fixture,
but can be overridden to be different.
"""
return store_kwargs

@pytest.fixture
Expand Down
23 changes: 13 additions & 10 deletions tests/test_store/test_core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import tempfile
from collections.abc import Callable, Coroutine
from pathlib import Path

import pytest
Expand All @@ -22,7 +23,7 @@
@pytest.mark.parametrize("write_group", [True, False])
@pytest.mark.parametrize("zarr_format", [2, 3])
async def test_contains_group(
local_store, path: str, write_group: bool, zarr_format: ZarrFormat
local_store: LocalStore, path: str, write_group: bool, zarr_format: ZarrFormat
) -> None:
"""
Test that the contains_group method correctly reports the existence of a group.
Expand All @@ -38,7 +39,7 @@ async def test_contains_group(
@pytest.mark.parametrize("write_array", [True, False])
@pytest.mark.parametrize("zarr_format", [2, 3])
async def test_contains_array(
local_store, path: str, write_array: bool, zarr_format: ZarrFormat
local_store: LocalStore, path: str, write_array: bool, zarr_format: ZarrFormat
) -> None:
"""
Test that the contains array method correctly reports the existence of an array.
Expand All @@ -51,13 +52,15 @@ async def test_contains_array(


@pytest.mark.parametrize("func", [contains_array, contains_group])
async def test_contains_invalid_format_raises(local_store, func: callable) -> None:
async def test_contains_invalid_format_raises(
local_store: LocalStore, func: Callable[[StorePath, ZarrFormat], Coroutine[None, None, bool]]
) -> None:
"""
Test contains_group and contains_array raise errors for invalid zarr_formats
"""
store_path = StorePath(local_store)
with pytest.raises(ValueError):
assert await func(store_path, zarr_format="3.0")
assert await func(store_path, "3.0") # type: ignore[arg-type]


@pytest.mark.parametrize("path", [None, "", "bar"])
Expand Down Expand Up @@ -110,12 +113,12 @@ async def test_make_store_path_store_path(


@pytest.mark.parametrize("modes", [(True, "w"), (False, "x")])
async def test_store_path_invalid_mode_raises(tmpdir: LEGACY_PATH, modes: tuple) -> None:
async def test_store_path_invalid_mode_raises(tmpdir: LEGACY_PATH, modes: tuple[bool, str]) -> None:
"""
Test that ValueErrors are raise for invalid mode.
"""
with pytest.raises(ValueError):
await StorePath.open(LocalStore(str(tmpdir), read_only=modes[0]), path=None, mode=modes[1])
await StorePath.open(LocalStore(str(tmpdir), read_only=modes[0]), path="", mode=modes[1]) # type: ignore[arg-type]


async def test_make_store_path_invalid() -> None:
Expand All @@ -126,7 +129,7 @@ async def test_make_store_path_invalid() -> None:
await make_store_path(1) # type: ignore[arg-type]


async def test_make_store_path_fsspec(monkeypatch) -> None:
async def test_make_store_path_fsspec() -> None:
pytest.importorskip("fsspec")
pytest.importorskip("requests")
pytest.importorskip("aiohttp")
Expand Down Expand Up @@ -175,12 +178,12 @@ def test_normalize_path_upath() -> None:
assert normalize_path(upath.UPath("foo/bar")) == "foo/bar"


def test_normalize_path_none():
def test_normalize_path_none() -> None:
assert normalize_path(None) == ""


@pytest.mark.parametrize("path", [".", ".."])
def test_normalize_path_invalid(path: str):
def test_normalize_path_invalid(path: str) -> None:
with pytest.raises(ValueError):
normalize_path(path)

Expand Down Expand Up @@ -221,7 +224,7 @@ def test_invalid(paths: tuple[str, str]) -> None:
_normalize_paths(paths)


def test_normalize_path_keys():
def test_normalize_path_keys() -> None:
"""
Test that ``_normalize_path_keys`` just applies the normalize_path function to each key of its
input
Expand Down
42 changes: 22 additions & 20 deletions tests/test_store/test_logging.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import pytest

Expand All @@ -11,60 +11,62 @@
from zarr.testing.store import StoreTests

if TYPE_CHECKING:
from _pytest.compat import LEGACY_PATH
from pathlib import Path

from zarr.abc.store import Store


class TestLoggingStore(StoreTests[LoggingStore, cpu.Buffer]):
class TestLoggingStore(StoreTests[LoggingStore[LocalStore], cpu.Buffer]):
store_cls = LoggingStore
buffer_cls = cpu.Buffer

async def get(self, store: LoggingStore, key: str) -> Buffer:
async def get(self, store: LoggingStore[LocalStore], key: str) -> Buffer:
return self.buffer_cls.from_bytes((store._store.root / key).read_bytes())

async def set(self, store: LoggingStore, key: str, value: Buffer) -> None:
async def set(self, store: LoggingStore[LocalStore], key: str, value: Buffer) -> None:
parent = (store._store.root / key).parent
if not parent.exists():
parent.mkdir(parents=True)
(store._store.root / key).write_bytes(value.to_bytes())

@pytest.fixture
def store_kwargs(self, tmpdir: LEGACY_PATH) -> dict[str, str]:
return {"store": LocalStore(str(tmpdir)), "log_level": "DEBUG"}
def store_kwargs(self, tmp_path: Path) -> dict[str, str | LocalStore]:
return {"store": LocalStore(str(tmp_path)), "log_level": "DEBUG"}

@pytest.fixture
def open_kwargs(self, tmpdir) -> dict[str, str]:
return {"store_cls": LocalStore, "root": str(tmpdir), "log_level": "DEBUG"}
def open_kwargs(self, store_kwargs: dict[str, Any], tmp_path: Path) -> dict[str, Any]:
return {"store_cls": LocalStore, "root": str(tmp_path), "log_level": "DEBUG"}

@pytest.fixture
def store(self, store_kwargs: str | dict[str, Buffer] | None) -> LoggingStore:
async def store(self, store_kwargs: dict[str, Any]) -> LoggingStore[LocalStore]:
return self.store_cls(**store_kwargs)

def test_store_supports_writes(self, store: LoggingStore) -> None:
def test_store_supports_writes(self, store: LoggingStore[LocalStore]) -> None:
assert store.supports_writes

def test_store_supports_partial_writes(self, store: LoggingStore) -> None:
def test_store_supports_partial_writes(self, store: LoggingStore[LocalStore]) -> None:
assert store.supports_partial_writes

def test_store_supports_listing(self, store: LoggingStore) -> None:
def test_store_supports_listing(self, store: LoggingStore[LocalStore]) -> None:
assert store.supports_listing

def test_store_repr(self, store: LoggingStore) -> None:
def test_store_repr(self, store: LoggingStore[LocalStore]) -> None:
assert f"{store!r}" == f"LoggingStore(LocalStore, 'file://{store._store.root.as_posix()}')"

def test_store_str(self, store: LoggingStore) -> None:
def test_store_str(self, store: LoggingStore[LocalStore]) -> None:
assert str(store) == f"logging-file://{store._store.root.as_posix()}"

async def test_default_handler(self, local_store, capsys) -> None:
async def test_default_handler(
self, local_store: LocalStore, capsys: pytest.CaptureFixture[str]
) -> None:
# Store and then remove existing handlers to enter default handler code path
handlers = logging.getLogger().handlers[:]
for h in handlers:
logging.getLogger().removeHandler(h)
# Test logs are sent to stdout
wrapped = LoggingStore(store=local_store)
buffer = default_buffer_prototype().buffer
res = await wrapped.set("foo/bar/c/0", buffer.from_bytes(b"\x01\x02\x03\x04"))
res = await wrapped.set("foo/bar/c/0", buffer.from_bytes(b"\x01\x02\x03\x04")) # type: ignore[func-returns-value]
assert res is None
captured = capsys.readouterr()
assert len(captured) == 2
Expand All @@ -74,7 +76,7 @@ async def test_default_handler(self, local_store, capsys) -> None:
for h in handlers:
logging.getLogger().addHandler(h)

def test_is_open_setter_raises(self, store: LoggingStore) -> None:
def test_is_open_setter_raises(self, store: LoggingStore[LocalStore]) -> None:
"Test that a user cannot change `_is_open` without opening the underlying store."
with pytest.raises(
NotImplementedError, match="LoggingStore must be opened via the `_open` method"
Expand All @@ -83,12 +85,12 @@ def test_is_open_setter_raises(self, store: LoggingStore) -> None:


@pytest.mark.parametrize("store", ["local", "memory", "zip"], indirect=["store"])
async def test_logging_store(store: Store, caplog) -> None:
async def test_logging_store(store: Store, caplog: pytest.LogCaptureFixture) -> None:
wrapped = LoggingStore(store=store, log_level="DEBUG")
buffer = default_buffer_prototype().buffer

caplog.clear()
res = await wrapped.set("foo/bar/c/0", buffer.from_bytes(b"\x01\x02\x03\x04"))
res = await wrapped.set("foo/bar/c/0", buffer.from_bytes(b"\x01\x02\x03\x04")) # type: ignore[func-returns-value]
assert res is None
assert len(caplog.record_tuples) == 2
for tup in caplog.record_tuples:
Expand Down
34 changes: 16 additions & 18 deletions tests/test_store/test_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import pytest

obstore = pytest.importorskip("obstore")
from pathlib import Path

import pytest
from hypothesis.stateful import (
run_state_machine_as_test,
Expand All @@ -16,47 +18,43 @@
from zarr.testing.store import StoreTests


class TestObjectStore(StoreTests[ObjectStore, cpu.Buffer]):
class TestObjectStore(StoreTests[ObjectStore[LocalStore], cpu.Buffer]):
store_cls = ObjectStore
buffer_cls = cpu.Buffer

@pytest.fixture
def store_kwargs(self, tmpdir) -> dict[str, Any]:
store = LocalStore(prefix=tmpdir)
def store_kwargs(self, tmp_path: Path) -> dict[str, Any]:
store = LocalStore(prefix=str(tmp_path))
return {"store": store, "read_only": False}

@pytest.fixture
def store(self, store_kwargs: dict[str, str | bool]) -> ObjectStore:
async def store(self, store_kwargs: dict[str, Any]) -> ObjectStore[LocalStore]:
return self.store_cls(**store_kwargs)

async def get(self, store: ObjectStore, key: str) -> Buffer:
assert isinstance(store.store, LocalStore)
async def get(self, store: ObjectStore[LocalStore], key: str) -> Buffer:
new_local_store = LocalStore(prefix=store.store.prefix)
return self.buffer_cls.from_bytes(obstore.get(new_local_store, key).bytes())

async def set(self, store: ObjectStore, key: str, value: Buffer) -> None:
assert isinstance(store.store, LocalStore)
async def set(self, store: ObjectStore[LocalStore], key: str, value: Buffer) -> None:
new_local_store = LocalStore(prefix=store.store.prefix)
obstore.put(new_local_store, key, value.to_bytes())

def test_store_repr(self, store: ObjectStore) -> None:
def test_store_repr(self, store: ObjectStore[LocalStore]) -> None:
from fnmatch import fnmatch

pattern = "ObjectStore(object_store://LocalStore(*))"
assert fnmatch(f"{store!r}", pattern)

def test_store_supports_writes(self, store: ObjectStore) -> None:
def test_store_supports_writes(self, store: ObjectStore[LocalStore]) -> None:
assert store.supports_writes

async def test_store_supports_partial_writes(self, store: ObjectStore) -> None:
def test_store_supports_partial_writes(self, store: ObjectStore[LocalStore]) -> None:
assert not store.supports_partial_writes
with pytest.raises(NotImplementedError):
await store.set_partial_values([("foo", 0, b"\x01\x02\x03\x04")])

def test_store_supports_listing(self, store: ObjectStore) -> None:
def test_store_supports_listing(self, store: ObjectStore[LocalStore]) -> None:
assert store.supports_listing

def test_store_equal(self, store: ObjectStore) -> None:
def test_store_equal(self, store: ObjectStore[LocalStore]) -> None:
"""Test store equality"""
# Test equality against a different instance type
assert store != 0
Expand All @@ -73,14 +71,14 @@ def test_store_equal(self, store: ObjectStore) -> None:
def test_store_init_raises(self) -> None:
"""Test __init__ raises appropriate error for improper store type"""
with pytest.raises(TypeError):
ObjectStore("path/to/store")
ObjectStore("path/to/store") # type: ignore[type-var]


@pytest.mark.slow_hypothesis
def test_zarr_hierarchy():
def test_zarr_hierarchy() -> None:
sync_store = ObjectStore(MemoryStore())

def mk_test_instance_sync() -> ZarrHierarchyStateMachine:
return ZarrHierarchyStateMachine(sync_store)

run_state_machine_as_test(mk_test_instance_sync)
run_state_machine_as_test(mk_test_instance_sync) # type ignore[no-untyped-call]
Loading