Skip to content

Commit

Permalink
[v3] Feature: Store open mode (zarr-developers#1911)
Browse files Browse the repository at this point in the history
* wip

* feature(store): set open mode on store initialization
  • Loading branch information
jhamman authored May 29, 2024
1 parent fc7fa4f commit ef15e20
Show file tree
Hide file tree
Showing 11 changed files with 116 additions and 33 deletions.
27 changes: 26 additions & 1 deletion src/zarr/abc/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,31 @@
from typing import Protocol, runtime_checkable

from zarr.buffer import Buffer
from zarr.common import BytesLike
from zarr.common import BytesLike, OpenMode


class Store(ABC):
_mode: OpenMode

def __init__(self, mode: OpenMode = "r"):
    """Validate ``mode`` and remember it as the store's access mode.

    Raises ``ValueError`` when ``mode`` is not one of the five accepted
    literals; the stored value is later exposed through the ``mode`` property.
    """
    accepted_modes = ("r", "r+", "w", "w-", "a")
    if mode in accepted_modes:
        self._mode = mode
    else:
        raise ValueError("mode must be one of 'r', 'r+', 'w', 'w-', 'a'")

@property
def mode(self) -> OpenMode:
    """Return the access mode recorded on this store (getter only, no setter)."""
    current_mode = self._mode
    return current_mode

@property
def writeable(self) -> bool:
    """Whether the store's access mode permits writing.

    ``"r"`` is the only read-only mode. ``"r+"`` (read/write on an existing
    store), ``"a"``, ``"w"`` and ``"w-"`` all allow writes, so each of them
    must report ``True`` here; otherwise ``_check_writable`` would reject
    writes on a store that was explicitly opened for read/write.
    """
    return self.mode in ("r+", "a", "w", "w-")

def _check_writable(self) -> None:
    """Guard for mutating operations: raise ``ValueError`` unless ``self.writeable``."""
    if self.writeable:
        return
    raise ValueError("store mode does not support writing")

@abstractmethod
async def get(
self, key: str, byte_range: tuple[int | None, int | None] | None = None
Expand Down Expand Up @@ -147,6 +168,10 @@ def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
"""
...

def close(self) -> None:  # noqa: B027
    """Close the store.

    The base implementation is a deliberate no-op (hence the B027
    suppression for an empty, non-abstract method on an abstract class).
    """
    return None


@runtime_checkable
class ByteGetter(Protocol):
Expand Down
1 change: 1 addition & 0 deletions src/zarr/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
Selection = slice | SliceSelection
ZarrFormat = Literal[2, 3]
JSON = None | str | int | float | Enum | dict[str, "JSON"] | list["JSON"] | tuple["JSON", ...]
OpenMode = Literal["r", "r+", "a", "w", "w-"]


def product(tup: ChunkCoords) -> int:
Expand Down
10 changes: 8 additions & 2 deletions src/zarr/store/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from zarr.abc.store import Store
from zarr.buffer import Buffer
from zarr.common import OpenMode
from zarr.store.local import LocalStore


Expand Down Expand Up @@ -60,13 +61,18 @@ def __eq__(self, other: Any) -> bool:
StoreLike = Store | StorePath | Path | str


def make_store_path(store_like: StoreLike) -> StorePath:
def make_store_path(store_like: StoreLike, *, mode: OpenMode | None = None) -> StorePath:
    """Coerce ``store_like`` into a ``StorePath``.

    Parameters
    ----------
    store_like
        A ``StorePath`` (returned unchanged), a ``Store`` (wrapped in a new
        ``StorePath``), or a filesystem location given as ``str`` or ``Path``
        (opened as a ``LocalStore``).
    mode
        Requested open mode. For an existing store it must match the store's
        own mode when given. For a filesystem location it is required and is
        forwarded to ``LocalStore``.

    Raises
    ------
    ValueError
        If ``mode`` disagrees with an existing store's mode, or is omitted
        for a filesystem location.
    TypeError
        If ``store_like`` is none of the supported types.
    """
    # Validation uses explicit raises rather than ``assert`` so the checks
    # still run when Python is started with ``-O``.
    if isinstance(store_like, StorePath):
        if mode is not None and mode != store_like.store.mode:
            raise ValueError(
                f"mode {mode!r} does not match the store's mode {store_like.store.mode!r}"
            )
        return store_like
    if isinstance(store_like, Store):
        if mode is not None and mode != store_like.mode:
            raise ValueError(f"mode {mode!r} does not match the store's mode {store_like.mode!r}")
        return StorePath(store_like)
    # ``Path`` is a member of ``StoreLike`` and is handled the same way as ``str``.
    if isinstance(store_like, (str, Path)):
        if mode is None:
            raise ValueError("mode must be provided when creating a store from a path")
        return StorePath(LocalStore(Path(store_like), mode=mode))
    raise TypeError(f"cannot make a StorePath from object of type {type(store_like).__name__}")


Expand Down
8 changes: 6 additions & 2 deletions src/zarr/store/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from zarr.abc.store import Store
from zarr.buffer import Buffer
from zarr.common import concurrent_map, to_thread
from zarr.common import OpenMode, concurrent_map, to_thread


def _get(path: Path, byte_range: tuple[int | None, int | None] | None) -> Buffer:
Expand Down Expand Up @@ -69,7 +69,8 @@ class LocalStore(Store):

root: Path

def __init__(self, root: Path | str):
def __init__(self, root: Path | str, *, mode: OpenMode = "r"):
super().__init__(mode=mode)
if isinstance(root, str):
root = Path(root)
assert isinstance(root, Path)
Expand Down Expand Up @@ -117,6 +118,7 @@ async def get_partial_values(
return await concurrent_map(args, to_thread, limit=None) # TODO: fix limit

async def set(self, key: str, value: Buffer) -> None:
self._check_writable()
assert isinstance(key, str)
if isinstance(value, bytes | bytearray):
# TODO: to support the v2 tests, we convert bytes to Buffer here
Expand All @@ -127,6 +129,7 @@ async def set(self, key: str, value: Buffer) -> None:
await to_thread(_put, path, value)

async def set_partial_values(self, key_start_values: list[tuple[str, int, bytes]]) -> None:
self._check_writable()
args = []
for key, start, value in key_start_values:
assert isinstance(key, str)
Expand All @@ -138,6 +141,7 @@ async def set_partial_values(self, key_start_values: list[tuple[str, int, bytes]
await concurrent_map(args, to_thread, limit=None) # TODO: fix limit

async def delete(self, key: str) -> None:
self._check_writable()
path = self.root / key
if path.is_dir(): # TODO: support deleting directories? shutil.rmtree?
shutil.rmtree(path)
Expand Down
9 changes: 7 additions & 2 deletions src/zarr/store/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from zarr.abc.store import Store
from zarr.buffer import Buffer
from zarr.common import concurrent_map
from zarr.common import OpenMode, concurrent_map
from zarr.store.core import _normalize_interval_index


Expand All @@ -17,7 +17,10 @@ class MemoryStore(Store):

_store_dict: MutableMapping[str, Buffer]

def __init__(self, store_dict: MutableMapping[str, Buffer] | None = None):
def __init__(
    self, store_dict: MutableMapping[str, Buffer] | None = None, *, mode: OpenMode = "r"
):
    """Create an in-memory store backed by ``store_dict``.

    Parameters
    ----------
    store_dict
        Mapping used as backing storage. When ``None`` a new empty ``dict``
        is created; any mapping that is passed in is used as-is.
    mode
        Access mode, validated and stored by the base-class initializer.
    """
    super().__init__(mode=mode)
    # Compare against None instead of relying on truthiness: an empty mapping
    # supplied by the caller is falsy, and ``store_dict or {}`` would replace
    # it with an unrelated dict, so writes would never reach the caller's object.
    self._store_dict = {} if store_dict is None else store_dict

def __str__(self) -> str:
Expand Down Expand Up @@ -47,6 +50,7 @@ async def exists(self, key: str) -> bool:
return key in self._store_dict

async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None:
self._check_writable()
assert isinstance(key, str)
if isinstance(value, bytes | bytearray):
# TODO: to support the v2 tests, we convert bytes to Buffer here
Expand All @@ -62,6 +66,7 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None
self._store_dict[key] = value

async def delete(self, key: str) -> None:
self._check_writable()
try:
del self._store_dict[key]
except KeyError:
Expand Down
10 changes: 9 additions & 1 deletion src/zarr/store/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from zarr.abc.store import Store
from zarr.buffer import Buffer
from zarr.common import OpenMode
from zarr.store.core import _dereference_path

if TYPE_CHECKING:
Expand All @@ -18,17 +19,22 @@ class RemoteStore(Store):

root: UPath

def __init__(self, url: UPath | str, **storage_options: dict[str, Any]):
def __init__(
self, url: UPath | str, *, mode: OpenMode = "r", **storage_options: dict[str, Any]
):
import fsspec
from upath import UPath

super().__init__(mode=mode)

if isinstance(url, str):
self.root = UPath(url, **storage_options)
else:
assert (
len(storage_options) == 0
), "If constructed with a UPath object, no additional storage_options are allowed."
self.root = url.rstrip("/")

# test instantiate file system
fs, _ = fsspec.core.url_to_fs(str(self.root), asynchronous=True, **self.root._kwargs)
assert fs.__class__.async_impl, "FileSystem needs to support async operations."
Expand Down Expand Up @@ -67,6 +73,7 @@ async def get(
return value

async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None:
self._check_writable()
assert isinstance(key, str)
fs, root = self._make_fs()
path = _dereference_path(root, key)
Expand All @@ -80,6 +87,7 @@ async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None
await fs._pipe_file(path, value)

async def delete(self, key: str) -> None:
self._check_writable()
fs, root = self._make_fs()
path = _dereference_path(root, key)
if await fs._exists(path):
Expand Down
37 changes: 34 additions & 3 deletions src/zarr/testing/store.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Generic, TypeVar
from typing import Any, Generic, TypeVar

import pytest

Expand Down Expand Up @@ -31,13 +31,43 @@ def get(self, store: S, key: str) -> Buffer:
raise NotImplementedError

@pytest.fixture(scope="function")
def store(self) -> Store:
return self.store_cls()
def store_kwargs(self) -> dict[str, Any]:
return {"mode": "w"}

@pytest.fixture(scope="function")
def store(self, store_kwargs: dict[str, Any]) -> Store:
    """Build a fresh instance of the store class under test from ``store_kwargs``."""
    store_factory = self.store_cls
    return store_factory(**store_kwargs)

def test_store_type(self, store: S) -> None:
    """The fixture must yield a ``Store`` that is also an instance of ``store_cls``."""
    for expected_type in (Store, self.store_cls):
        assert isinstance(store, expected_type)

def test_store_mode(self, store: S, store_kwargs: dict[str, Any]) -> None:
    """Check the mode reported by the fixture store and by a read-only sibling.

    The fixture store must report mode ``"w"`` and be writeable, and ``mode``
    must not be assignable. A second store built from the same kwargs with
    ``mode="r"`` must report ``"r"`` and not be writeable.
    """
    assert store.mode == "w", store.mode
    assert store.writeable

    # ``mode`` is exposed without a setter, so assignment must fail.
    with pytest.raises(AttributeError):
        store.mode = "w"  # type: ignore

    # read-only
    read_only_kwargs = dict(store_kwargs)
    read_only_kwargs["mode"] = "r"
    read_store = self.store_cls(**read_only_kwargs)
    assert read_store.mode == "r", read_store.mode
    assert not read_store.writeable

async def test_not_writable_store_raises(self, store_kwargs: dict[str, Any]) -> None:
    """``set`` and ``delete`` on a store opened with mode ``"r"`` must raise ``ValueError``."""
    read_only_kwargs = dict(store_kwargs)
    read_only_kwargs["mode"] = "r"
    store = self.store_cls(**read_only_kwargs)
    assert not store.writeable

    # set
    with pytest.raises(ValueError):
        await store.set("foo", Buffer.from_bytes(b"bar"))

    # delete
    with pytest.raises(ValueError):
        await store.delete("foo")

def test_store_repr(self, store: S) -> None:
    """Placeholder: the generic suite has no expected string form for a store."""
    raise NotImplementedError

Expand Down Expand Up @@ -72,6 +102,7 @@ async def test_set(self, store: S, key: str, data: bytes) -> None:
"""
Ensure that data can be written to the store using the store.set method.
"""
assert store.writeable
data_buf = Buffer.from_bytes(data)
await store.set(key, data_buf)
observed = self.get(store, key)
Expand Down
14 changes: 7 additions & 7 deletions tests/v3/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ def parse_store(
store: Literal["local", "memory", "remote"], path: str
) -> LocalStore | MemoryStore | RemoteStore:
if store == "local":
return LocalStore(path)
return LocalStore(path, mode="w")
if store == "memory":
return MemoryStore()
return MemoryStore(mode="w")
if store == "remote":
return RemoteStore()
return RemoteStore(mode="w")
raise AssertionError


Expand All @@ -38,24 +38,24 @@ def path_type(request):
# todo: harmonize this with local_store fixture
@pytest.fixture
def store_path(tmpdir):
store = LocalStore(str(tmpdir))
store = LocalStore(str(tmpdir), mode="w")
p = StorePath(store)
return p


@pytest.fixture(scope="function")
def local_store(tmpdir):
return LocalStore(str(tmpdir))
return LocalStore(str(tmpdir), mode="w")


@pytest.fixture(scope="function")
def remote_store():
return RemoteStore()
return RemoteStore(mode="w")


@pytest.fixture(scope="function")
def memory_store():
return MemoryStore()
return MemoryStore(mode="w")


@pytest.fixture(scope="function")
Expand Down
2 changes: 1 addition & 1 deletion tests/v3/test_codecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ async def set(self, value: np.ndarray):

@pytest.fixture
def store() -> Iterator[Store]:
yield StorePath(MemoryStore())
yield StorePath(MemoryStore(mode="w"))


@pytest.fixture
Expand Down
29 changes: 16 additions & 13 deletions tests/v3/test_store.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from collections.abc import MutableMapping
from typing import Any

import pytest

Expand All @@ -10,7 +10,6 @@
from zarr.testing.store import StoreTests


@pytest.mark.parametrize("store_dict", (None, {}))
class TestMemoryStore(StoreTests[MemoryStore]):
store_cls = MemoryStore

Expand All @@ -20,21 +19,25 @@ def set(self, store: MemoryStore, key: str, value: Buffer) -> None:
def get(self, store: MemoryStore, key: str) -> Buffer:
return store._store_dict[key]

@pytest.fixture(scope="function", params=[None, {}])
def store_kwargs(self, request) -> dict[str, Any]:
return {"store_dict": request.param, "mode": "w"}

@pytest.fixture(scope="function")
def store(self, store_dict: MutableMapping[str, Buffer] | None):
return MemoryStore(store_dict=store_dict)
def store(self, store_kwargs: dict[str, Any]) -> MemoryStore:
return self.store_cls(**store_kwargs)

def test_store_repr(self, store: MemoryStore) -> None:
assert str(store) == f"memory://{id(store._store_dict)}"

def test_store_supports_writes(self, store: MemoryStore) -> None:
assert True
assert store.supports_writes

def test_store_supports_listing(self, store: MemoryStore) -> None:
assert True
assert store.supports_listing

def test_store_supports_partial_writes(self, store: MemoryStore) -> None:
assert True
assert store.supports_partial_writes

def test_list_prefix(self, store: MemoryStore) -> None:
assert True
Expand All @@ -52,21 +55,21 @@ def set(self, store: LocalStore, key: str, value: Buffer) -> None:
parent.mkdir(parents=True)
(store.root / key).write_bytes(value.to_bytes())

@pytest.fixture(scope="function")
def store(self, tmpdir) -> LocalStore:
return self.store_cls(str(tmpdir))
@pytest.fixture
def store_kwargs(self, tmpdir) -> dict[str, str]:
return {"root": str(tmpdir), "mode": "w"}

def test_store_repr(self, store: LocalStore) -> None:
assert str(store) == f"file://{store.root!s}"

def test_store_supports_writes(self, store: LocalStore) -> None:
assert True
assert store.supports_writes

def test_store_supports_partial_writes(self, store: LocalStore) -> None:
assert True
assert store.supports_partial_writes

def test_store_supports_listing(self, store: LocalStore) -> None:
assert True
assert store.supports_listing

def test_list_prefix(self, store: LocalStore) -> None:
assert True
Loading

0 comments on commit ef15e20

Please sign in to comment.