 from __future__ import annotations
 
+from collections.abc import AsyncGenerator
 from typing import TYPE_CHECKING, Any
 
+import fsspec
+
 from zarr.abc.store import Store
-from zarr.buffer import Buffer, BufferPrototype
+from zarr.buffer import Buffer, BufferPrototype, default_buffer_prototype
 from zarr.common import OpenMode
 from zarr.store.core import _dereference_path
 
 if TYPE_CHECKING:
     from fsspec.asyn import AsyncFileSystem
     from upath import UPath
 
+    from zarr.buffer import Buffer
+    from zarr.common import BytesLike
+
 
 class RemoteStore(Store):
+    # based on FSSpec
     supports_writes: bool = True
     supports_partial_writes: bool = False
     supports_listing: bool = True
 
-    root: UPath
+    _fs: AsyncFileSystem
+    path: str
+    allowed_exceptions: tuple[type[Exception], ...]
 
     def __init__(
-        self, url: UPath | str, *, mode: OpenMode = "r", **storage_options: dict[str, Any]
+        self,
+        url: UPath | str,
+        mode: OpenMode = "r",
+        allowed_exceptions: tuple[type[Exception], ...] = (
+            FileNotFoundError,
+            IsADirectoryError,
+            NotADirectoryError,
+        ),
+        **storage_options: Any,
     ):
-        import fsspec
-        from upath import UPath
+        """
+        Parameters
+        ----------
+        url: root of the datastore. In fsspec notation, this is usually like "protocol://path/to".
+            Can also be a upath.UPath instance.
+        allowed_exceptions: when fetching data, these exceptions are treated as missing
+            keys, rather than some other IO failure.
+        storage_options: passed on to fsspec to make the filesystem instance. If url is a UPath,
+            this must not be used.
+        """
 
         super().__init__(mode=mode)
 
         if isinstance(url, str):
-            self.root = UPath(url, **storage_options)
+            self._fs, self.path = fsspec.url_to_fs(url, **storage_options)
+        elif hasattr(url, "protocol") and hasattr(url, "fs"):
+            # looks like a UPath - duck-typed so we don't import upath at runtime
+            if storage_options:
+                raise ValueError(
+                    "If constructed with a UPath object, no additional "
+                    "storage_options are allowed"
+                )
+            self.path = url.path
+            self._fs = url._fs
         else:
-            assert (
-                len(storage_options) == 0
-            ), "If constructed with a UPath object, no additional storage_options are allowed."
-            self.root = url.rstrip("/")
-
+            raise ValueError(f"URL not understood: {url}")
+        self.allowed_exceptions = allowed_exceptions
-        # test instantiate file system
-        fs, _ = fsspec.core.url_to_fs(str(self.root), asynchronous=True, **self.root._kwargs)
-        assert fs.__class__.async_impl, "FileSystem needs to support async operations."
+        # the resulting filesystem must support async operations
+        if not self._fs.async_impl:
+            raise TypeError("FileSystem needs to support async operations")
 
     def __str__(self) -> str:
-        return str(self.root)
+        return f"Remote fsspec store: {type(self._fs).__name__}, {self.path}"
 
     def __repr__(self) -> str:
-        return f"RemoteStore({str(self)!r})"
-
-    def _make_fs(self) -> tuple[AsyncFileSystem, str]:
-        import fsspec
-
-        storage_options = self.root._kwargs.copy()
-        storage_options.pop("_url", None)
-        fs, root = fsspec.core.url_to_fs(str(self.root), asynchronous=True, **self.root._kwargs)
-        assert fs.__class__.async_impl, "FileSystem needs to support async operations."
-        return fs, root
+        return f"<RemoteStore({type(self._fs).__name__}, {self.path})>"
 
     async def get(
         self,
         key: str,
-        prototype: BufferPrototype,
+        prototype: BufferPrototype = default_buffer_prototype,
         byte_range: tuple[int | None, int | None] | None = None,
     ) -> Buffer | None:
-        assert isinstance(key, str)
-        fs, root = self._make_fs()
-        path = _dereference_path(root, key)
+        path = _dereference_path(self.path, key)
 
         try:
-            value: Buffer | None = await (
-                fs._cat_file(path, start=byte_range[0], end=byte_range[1])
-                if byte_range
-                else fs._cat_file(path)
+            if byte_range:
+                # zarr passes (start, length); fsspec's _cat_file expects (start, end)
+                start, length = byte_range
+                if start is not None and length is not None:
+                    end = start + length
+                elif length is not None:
+                    end = length
+                else:
+                    end = None
+            value: Buffer = prototype.buffer.from_bytes(
+                await (
+                    self._fs._cat_file(path, start=byte_range[0], end=end)
+                    if byte_range
+                    else self._fs._cat_file(path)
+                )
             )
-        except (FileNotFoundError, IsADirectoryError, NotADirectoryError):
-            return None
+            return value
 
-        return value
+        except self.allowed_exceptions:
+            # these are treated as missing keys, not IO failures
+            return None
+        except OSError as e:
+            if "not satisfiable" in str(e):
+                # this is an s3-specific condition we probably don't want to leak
+                return prototype.buffer.from_bytes(b"")
+            raise
 
-    async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None:
+    async def set(
+        self,
+        key: str,
+        value: Buffer,
+        byte_range: tuple[int, int] | None = None,
+    ) -> None:
         self._check_writable()
-        assert isinstance(key, str)
-        fs, root = self._make_fs()
-        path = _dereference_path(root, key)
-
+        path = _dereference_path(self.path, key)
         # write data
         if byte_range:
-            with fs._open(path, "r+b") as f:
-                f.seek(byte_range[0])
-                f.write(value)
-        else:
-            await fs._pipe_file(path, value)
+            # partial writes are not supported for remote object stores
+            raise NotImplementedError
+        await self._fs._pipe_file(path, value.to_bytes())
 
     async def delete(self, key: str) -> None:
         self._check_writable()
-        fs, root = self._make_fs()
-        path = _dereference_path(root, key)
-        if await fs._exists(path):
-            await fs._rm(path)
+        path = _dereference_path(self.path, key)
+        try:
+            await self._fs._rm(path)
+        except FileNotFoundError:
+            # deleting a missing key is a no-op
+            pass
+        except self.allowed_exceptions:
+            pass
 
     async def exists(self, key: str) -> bool:
-        fs, root = self._make_fs()
-        path = _dereference_path(root, key)
-        exists: bool = await fs._exists(path)
+        path = _dereference_path(self.path, key)
+        exists: bool = await self._fs._exists(path)
         return exists
+
+    async def get_partial_values(
+        self,
+        prototype: BufferPrototype,
+        key_ranges: list[tuple[str, tuple[int | None, int | None]]],
+    ) -> list[Buffer | None]:
+        if not key_ranges:
+            return []
+        # convert (key, (start, length)) pairs into parallel path/start/stop lists
+        paths, starts, stops = zip(
+            *(
+                (
+                    _dereference_path(self.path, k[0]),
+                    k[1][0],
+                    ((k[1][0] or 0) + k[1][1]) if k[1][1] is not None else None,
+                )
+                for k in key_ranges
+            ),
+            strict=False,
+        )
+        # TODO: expectations for exceptions or missing keys?
+        res = await self._fs._cat_ranges(list(paths), starts, stops, on_error="return")
+        # the following is an s3-specific condition we probably don't want to leak
+        res = [b"" if (isinstance(r, OSError) and "not satisfiable" in str(r)) else r for r in res]
+        for r in res:
+            if isinstance(r, Exception) and not isinstance(r, self.allowed_exceptions):
+                raise r
+
+        return [None if isinstance(r, Exception) else prototype.buffer.from_bytes(r) for r in res]
+
+    async def set_partial_values(self, key_start_values: list[tuple[str, int, BytesLike]]) -> None:
+        raise NotImplementedError
+
+    async def list(self) -> AsyncGenerator[str, None]:
+        allfiles = await self._fs._find(self.path, detail=False, withdirs=False)
+        for onefile in (a.replace(self.path + "/", "") for a in allfiles):
+            yield onefile
+
+    async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:
+        prefix = f"{self.path}/{prefix.rstrip('/')}"
+        try:
+            allfiles = await self._fs._ls(prefix, detail=False)
+        except FileNotFoundError:
+            return
+        for onefile in (a.replace(prefix + "/", "") for a in allfiles):
+            yield onefile
+
+    async def list_prefix(self, prefix: str) -> AsyncGenerator[str, None]:
+        for onefile in await self._fs._ls(prefix, detail=False):
+            yield onefile
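
A minimal usage sketch of the store above. It assumes s3fs is installed, that the module is importable as zarr.store.remote (a hypothetical path, inferred from the zarr.store.core import), and that "my-bucket" is a placeholder bucket you can write to; any async-capable fsspec protocol works the same way:

import asyncio

from zarr.buffer import default_buffer_prototype
from zarr.store.remote import RemoteStore


async def main() -> None:
    # extra keyword arguments (here anon) are forwarded to fsspec.url_to_fs
    store = RemoteStore("s3://my-bucket/dataset.zarr", mode="w", anon=False)
    # keys are slash-separated paths relative to the store root
    await store.set("zarr.json", default_buffer_prototype.buffer.from_bytes(b"{}"))
    buf = await store.get("zarr.json", prototype=default_buffer_prototype)
    assert buf is not None and buf.to_bytes() == b"{}"
    # a missing key is reported as None rather than raising
    assert await store.get("missing", prototype=default_buffer_prototype) is None


asyncio.run(main())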
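The subtlest logic in get() is the byte-range translation: zarr hands the store a (start, length) pair, while fsspec's _cat_file takes (start, end) offsets. Here is the same mapping pulled out into a standalone function; to_fsspec_range is illustrative only, not part of the store:

def to_fsspec_range(byte_range: tuple[int | None, int | None]) -> tuple[int | None, int | None]:
    # mirrors the branching inside RemoteStore.get
    start, length = byte_range
    if start is not None and length is not None:
        end = start + length  # read `length` bytes beginning at `start`
    elif length is not None:
        end = length  # no start given: the second element acts as an absolute end offset
    else:
        end = None  # no length given: read from `start` to the end of the object
    return start, end


assert to_fsspec_range((10, 4)) == (10, 14)
assert to_fsspec_range((None, 4)) == (None, 4)
assert to_fsspec_range((10, None)) == (10, None)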