Skip to content

Commit 58636b3

Browse files
committed
feat: add cache module and typing
1 parent ee58f77 commit 58636b3

File tree

5 files changed

+192
-86
lines changed

5 files changed

+192
-86
lines changed

pyproject.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,10 @@ dev = [
6565
"pytest-asyncio>=1.2.0",
6666
"snakemake>=9.13.4",
6767
]
68+
69+
[tool.basedpyright]
70+
reportImplicitStringConcatenation = false
71+
ignore = ["tests", "src/snakemake_storage_plugin_cached_http/monkeypatch.py"]
72+
reportUnusedCallResult = false
73+
reportAny = false
74+
reportMissingTypeStubs = false

src/snakemake_storage_plugin_cached_http/__init__.py

Lines changed: 86 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,18 @@
55
import asyncio
66
import hashlib
77
import json
8+
from logging import Logger
89
import shutil
910
import time
10-
from collections.abc import Iterable
1111
from contextlib import asynccontextmanager
1212
from dataclasses import dataclass, field
1313
from pathlib import Path
14-
from typing import Any
14+
from typing import override
1515
from urllib.parse import urlparse
1616

1717
import httpx
1818
import platformdirs
19-
import snakemake_storage_plugin_http as http_base
20-
from reretry import retry
19+
from reretry import retry # pyright: ignore[reportUnknownVariableType]
2120
from snakemake_interface_common.exceptions import WorkflowError
2221
from snakemake_interface_common.logging import get_logger
2322
from snakemake_interface_common.plugin_registry.plugin import SettingsBase
@@ -32,16 +31,21 @@
3231
)
3332
from tqdm_loggable.auto import tqdm
3433

34+
from .cache import Cache
35+
from .monkeypatch import is_zenodo_url # noqa: F401 - applies monkeypatch on import
36+
3537
logger = get_logger()
3638

3739

3840
class ReretryLoggerAdapter:
3941
"""Adapter to make Snakemake's logger compatible with reretry's logging expectations."""
4042

41-
def __init__(self, snakemake_logger):
43+
_logger: Logger
44+
45+
def __init__(self, snakemake_logger: Logger):
4246
self._logger = snakemake_logger
4347

44-
def warning(self, msg, *args, **kwargs):
48+
def warning(self, msg: str, *args, **kwargs): # pyright: ignore[reportUnknownParameterType, reportUnusedParameter, reportMissingParameterType]
4549
"""
4650
Format message manually before passing to Snakemake logger.
4751
@@ -55,27 +59,6 @@ def warning(self, msg, *args, **kwargs):
5559
self._logger.warning(msg)
5660

5761

58-
def is_zenodo_url(url):
59-
parsed = urlparse(url)
60-
return parsed.netloc in ("zenodo.org", "sandbox.zenodo.org") and parsed.scheme in (
61-
"http",
62-
"https",
63-
)
64-
65-
66-
# Patch the original HTTP StorageProvider off zenodo urls, so that there is no conflict
67-
orig_valid_query = http_base.StorageProvider.is_valid_query
68-
http_base.StorageProvider.is_valid_query = classmethod(
69-
lambda c, q: (
70-
StorageQueryValidationResult(
71-
query=q, valid=False, reason="Deactivated in favour of cached_http"
72-
)
73-
if is_zenodo_url(q)
74-
else orig_valid_query(q)
75-
)
76-
)
77-
78-
7962
# Define settings for the Zenodo storage plugin
8063
# NB: We derive from SettingsBase rather than StorageProviderSettingsBase to remove the
8164
# unsupported max_requests_per_second option
@@ -115,14 +98,17 @@ class ZenodoFileMetadata:
11598

11699

117100
class WrongChecksum(Exception):
101+
observed: str
102+
expected: str
103+
118104
def __init__(self, observed: str, expected: str):
119105
self.observed = observed
120106
self.expected = expected
121107
super().__init__(f"Checksum mismatch: expected {expected}, got {observed}")
122108

123109

124110
retry_decorator = retry(
125-
exceptions=(
111+
exceptions=( # pyright: ignore[reportArgumentType]
126112
httpx.HTTPError,
127113
TimeoutError,
128114
OSError,
@@ -132,21 +118,22 @@ def __init__(self, observed: str, expected: str):
132118
tries=5,
133119
delay=3,
134120
backoff=2,
135-
logger=ReretryLoggerAdapter(get_logger()),
121+
logger=ReretryLoggerAdapter(get_logger()), # pyright: ignore[reportArgumentType]
136122
)
137123

138124

139125
# Implementation of storage provider
140126
class StorageProvider(StorageProviderBase):
127+
settings: StorageProviderSettings
128+
cache: Cache | None
129+
141130
def __post_init__(self):
142131
super().__post_init__()
143132

144-
# Set up cache directory
145-
if self.settings.cache:
146-
self.cache_dir = Path(self.settings.cache).expanduser()
147-
self.cache_dir.mkdir(exist_ok=True, parents=True)
148-
else:
149-
self.cache_dir = None
133+
# Set up cache
134+
self.cache = (
135+
Cache(cache_dir=Path(self.settings.cache)) if self.settings.cache else None
136+
)
150137

151138
# Initialize shared client for bounding connections and pipelining
152139
self._client: httpx.AsyncClient | None = None
@@ -155,16 +142,20 @@ def __post_init__(self):
155142
# Cache for record metadata to avoid repeated API calls
156143
self._record_cache: dict[str, dict[str, ZenodoFileMetadata]] = {}
157144

145+
@override
158146
def use_rate_limiter(self) -> bool:
159147
"""Return False if no rate limiting is needed for this provider."""
160148
return False
161149

162-
def rate_limiter_key(self, query: str, operation: Operation) -> Any:
150+
@override
151+
def rate_limiter_key(self, query: str, operation: Operation) -> str:
163152
raise NotImplementedError()
164153

154+
@override
165155
def default_max_requests_per_second(self) -> float:
166156
raise NotImplementedError()
167157

158+
@override
168159
@classmethod
169160
def example_queries(cls) -> list[ExampleQuery]:
170161
"""Return an example query with description for this storage provider."""
@@ -176,6 +167,7 @@ def example_queries(cls) -> list[ExampleQuery]:
176167
)
177168
]
178169

170+
@override
179171
@classmethod
180172
def is_valid_query(cls, query: str) -> StorageQueryValidationResult:
181173
"""Only handle zenodo.org URLs"""
@@ -188,13 +180,11 @@ def is_valid_query(cls, query: str) -> StorageQueryValidationResult:
188180
reason="Not a Zenodo URL (only zenodo.org URLs are handled by this plugin)",
189181
)
190182

183+
@override
191184
@classmethod
192185
def get_storage_object_cls(cls):
193186
return StorageObject
194187

195-
def list_objects(self, query: Any) -> Iterable[str]:
196-
raise NotImplementedError()
197-
198188
@asynccontextmanager
199189
async def client(self):
200190
"""
@@ -231,7 +221,7 @@ async def client(self):
231221
await self._client.aclose()
232222
self._client = None
233223

234-
def _get_rate_limit_wait_time(self, headers) -> float | None:
224+
def _get_rate_limit_wait_time(self, headers: httpx.Headers) -> float | None:
235225
"""
236226
Calculate wait time based on rate limit headers.
237227
@@ -306,12 +296,12 @@ async def get_metadata(
306296
data = json.loads(content)
307297

308298
# Parse files array and build metadata dict
309-
metadata = {}
299+
metadata: dict[str, ZenodoFileMetadata] = {}
310300
files = data.get("files", [])
311301
for file_info in files:
312-
filename = file_info.get("key")
313-
checksum = file_info.get("checksum")
314-
size = file_info.get("size", 0)
302+
filename: str | None = file_info.get("key")
303+
checksum: str | None = file_info.get("checksum")
304+
size: int = file_info.get("size", 0)
315305

316306
if not filename:
317307
continue
@@ -326,13 +316,13 @@ async def get_metadata(
326316

327317
# Implementation of storage object
328318
class StorageObject(StorageObjectRead):
319+
provider: StorageProvider # pyright: ignore[reportIncompatibleVariableOverride]
320+
record_id: str
321+
filename: str
322+
netloc: str
323+
329324
def __post_init__(self):
330325
super().__post_init__()
331-
if self.provider.cache_dir is not None:
332-
self.query_path: Path = self.provider.cache_dir / self.local_suffix()
333-
self.query_path.parent.mkdir(exist_ok=True, parents=True)
334-
else:
335-
self.query_path = None
336326

337327
# Parse URL to extract record ID and filename
338328
# URL format: https://zenodo.org/records/{record_id}/files/{filename}
@@ -344,49 +334,56 @@ def __post_init__(self):
344334
if _records != "records" or _files != "files":
345335
raise WorkflowError(
346336
f"Invalid Zenodo URL format: {self.query}. "
347-
"Expected format: https://zenodo.org/records/{{record_id}}/files/{filename}"
337+
f"Expected format: https://zenodo.org/records/{{record_id}}/files/{filename}"
348338
)
349339

350340
self.record_id = record_id
351341
self.filename = filename
352342
self.netloc = parsed.netloc
353343

354-
def local_suffix(self):
355-
parsed = urlparse(self.query)
344+
@override
345+
def local_suffix(self) -> str:
346+
"""Return the local suffix for this object (used by parent class)."""
347+
parsed = urlparse(str(self.query))
356348
return f"{parsed.netloc}{parsed.path}"
357349

350+
@override
358351
def get_inventory_parent(self) -> str | None:
359352
"""Return the parent directory of this object."""
360353
# this is optional and can be left as is
361354
return None
362355

356+
@override
363357
async def managed_exists(self) -> bool:
364358
if self.provider.settings.skip_remote_checks:
365359
return True
366360

367-
exists = self.query_path and self.query_path.exists()
368-
if exists:
369-
return True
361+
if self.provider.cache:
362+
cached = self.provider.cache.get(str(self.query))
363+
if cached is not None:
364+
return True
370365

371366
metadata = await self.provider.get_metadata(self.record_id, self.netloc)
372367
return self.filename in metadata
373368

369+
@override
374370
async def managed_mtime(self) -> float:
375371
return 0
376372

373+
@override
377374
async def managed_size(self) -> int:
378375
if self.provider.settings.skip_remote_checks:
379376
return 0
380377

381-
exists = self.query_path and self.query_path.exists()
382-
if exists:
383-
return self.query_path.stat().st_size
378+
if self.provider.cache:
379+
cached = self.provider.cache.get(str(self.query))
380+
if cached is not None:
381+
return cached.stat().st_size
384382

385383
metadata = await self.provider.get_metadata(self.record_id, self.netloc)
386384
return metadata[self.filename].size if self.filename in metadata else 0
387385

388-
managed_local_footprint = managed_size
389-
386+
@override
390387
async def inventory(self, cache: IOCacheStorageInterface) -> None:
391388
"""
392389
Gather file metadata (existence, size) from cache or remote.
@@ -399,38 +396,44 @@ async def inventory(self, cache: IOCacheStorageInterface) -> None:
399396

400397
if self.provider.settings.skip_remote_checks:
401398
cache.exists_in_storage[key] = True
402-
cache.mtime[key] = 0
399+
cache.mtime[key] = Mtime(storage=0)
403400
cache.size[key] = 0
404401
return
405402

406-
exists = self.query_path and self.query_path.exists()
407-
if exists:
408-
cache.exists_in_storage[key] = exists
409-
cache.mtime[key] = 0
410-
cache.size[key] = self.query_path.stat().st_size
411-
return
403+
if self.provider.cache:
404+
cached = self.provider.cache.get(str(self.query))
405+
if cached is not None:
406+
cache.exists_in_storage[key] = True
407+
cache.mtime[key] = Mtime(storage=0)
408+
cache.size[key] = cached.stat().st_size
409+
return
412410

413411
metadata = await self.provider.get_metadata(self.record_id, self.netloc)
414412
exists = self.filename in metadata
415413
cache.exists_in_storage[key] = exists
416414
cache.mtime[key] = Mtime(storage=0)
417415
cache.size[key] = metadata[self.filename].size if exists else 0
418416

417+
@override
419418
def cleanup(self):
420419
"""Nothing to cleanup"""
421420
pass
422421

423-
def exists(self):
422+
@override
423+
def exists(self) -> bool:
424424
raise NotImplementedError()
425425

426-
def size(self):
426+
@override
427+
def size(self) -> int:
427428
raise NotImplementedError()
428429

429-
def mtime(self):
430+
@override
431+
def mtime(self) -> float:
430432
raise NotImplementedError()
431433

432-
def retrieve_object(self):
433-
return NotImplementedError()
434+
@override
435+
def retrieve_object(self) -> None:
436+
raise NotImplementedError()
434437

435438
async def verify_checksum(self, path: Path) -> None:
436439
"""
@@ -453,7 +456,7 @@ async def verify_checksum(self, path: Path) -> None:
453456
digest, checksum_expected = checksum.split(":", maxsplit=1)
454457

455458
# Compute checksum asynchronously (hashlib releases GIL)
456-
def compute_hash(digest=digest):
459+
def compute_hash(digest: str = digest):
457460
with open(path, "rb") as f:
458461
return hashlib.file_digest(f, digest).hexdigest().lower()
459462

@@ -469,13 +472,14 @@ async def managed_retrieve(self):
469472
local_path.parent.mkdir(parents=True, exist_ok=True)
470473

471474
# If already in cache, just copy
472-
if self.query_path and self.query_path.exists():
473-
# Verify cached file checksum
474-
logger.info(
475-
f"Retrieved {self.filename} of zenodo record {self.record_id} from cache"
476-
)
477-
shutil.copy2(self.query_path, local_path)
478-
return
475+
if self.provider.cache:
476+
cached = self.provider.cache.get(str(self.query))
477+
if cached is not None:
478+
logger.info(
479+
f"Retrieved {self.filename} of zenodo record {self.record_id} from cache"
480+
)
481+
shutil.copy2(cached, local_path)
482+
return
479483

480484
try:
481485
# Download from Zenodo using a get request, rate limit errors are detected and
@@ -505,10 +509,8 @@ async def managed_retrieve(self):
505509
await self.verify_checksum(local_path)
506510

507511
# Copy to cache after successful verification
508-
if self.query_path:
509-
self.query_path.parent.mkdir(parents=True, exist_ok=True)
510-
shutil.copy2(local_path, self.query_path)
511-
logger.info(f"Cached {self.filename} to {self.provider.cache_dir}")
512+
if self.provider.cache:
513+
self.provider.cache.put(str(self.query), local_path)
512514

513515
except:
514516
if local_path.exists():

0 commit comments

Comments
 (0)