import asyncio
import hashlib
import json
+from logging import Logger
import shutil
import time
-from collections.abc import Iterable
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from pathlib import Path
-from typing import Any
+from typing import override
from urllib.parse import urlparse

import httpx
import platformdirs
-import snakemake_storage_plugin_http as http_base
-from reretry import retry
+from reretry import retry  # pyright: ignore[reportUnknownVariableType]
from snakemake_interface_common.exceptions import WorkflowError
from snakemake_interface_common.logging import get_logger
from snakemake_interface_common.plugin_registry.plugin import SettingsBase
)
from tqdm_loggable.auto import tqdm

+from .cache import Cache
+from .monkeypatch import is_zenodo_url  # noqa: F401 - applies monkeypatch on import
+
logger = get_logger()


class ReretryLoggerAdapter:
    """Adapter to make Snakemake's logger compatible with reretry's logging expectations."""

-    def __init__(self, snakemake_logger):
+    _logger: Logger
+
+    def __init__(self, snakemake_logger: Logger):
        self._logger = snakemake_logger

-    def warning(self, msg, *args, **kwargs):
+    def warning(self, msg: str, *args, **kwargs):  # pyright: ignore[reportUnknownParameterType, reportUnusedParameter, reportMissingParameterType]
        """
        Format message manually before passing to Snakemake logger.

@@ -55,27 +59,6 @@ def warning(self, msg, *args, **kwargs):
        self._logger.warning(msg)


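For context on why the adapter exists: reretry invokes its logger printf-style, e.g. `logger.warning("%s, retrying in %s seconds...", exc, delay)`, while the docstring above says the message is formatted manually before reaching Snakemake's logger. The hunk elides the middle of the method; a minimal sketch, assuming the hidden body is a plain `%`-interpolation:

    def warning(self, msg, *args, **kwargs):
        if args:
            msg = msg % args  # assumption: collapse printf-style args eagerly
        self._logger.warning(msg)
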
-def is_zenodo_url(url):
-    parsed = urlparse(url)
-    return parsed.netloc in ("zenodo.org", "sandbox.zenodo.org") and parsed.scheme in (
-        "http",
-        "https",
-    )
-
-
-# Patch the original HTTP StorageProvider off zenodo urls, so that there is no conflict
-orig_valid_query = http_base.StorageProvider.is_valid_query
-http_base.StorageProvider.is_valid_query = classmethod(
-    lambda c, q: (
-        StorageQueryValidationResult(
-            query=q, valid=False, reason="Deactivated in favour of cached_http"
-        )
-        if is_zenodo_url(q)
-        else orig_valid_query(q)
-    )
-)
-
-
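The removed patch has not been dropped: it moved into the new `.monkeypatch` module, imported at the top of the file for its side effect (and re-exporting `is_zenodo_url`). A hypothetical sketch of that module, assuming it carries the logic over essentially verbatim:

    # monkeypatch.py -- hypothetical reconstruction of the relocated patch
    from urllib.parse import urlparse

    import snakemake_storage_plugin_http as http_base
    from snakemake_interface_storage_plugins.storage_provider import (
        StorageQueryValidationResult,
    )


    def is_zenodo_url(url: str) -> bool:
        parsed = urlparse(url)
        return parsed.netloc in ("zenodo.org", "sandbox.zenodo.org") and parsed.scheme in (
            "http",
            "https",
        )


    # Deactivate the stock HTTP provider for Zenodo URLs so the two plugins
    # never claim the same query.
    _orig_valid_query = http_base.StorageProvider.is_valid_query
    http_base.StorageProvider.is_valid_query = classmethod(
        lambda cls, query: (
            StorageQueryValidationResult(
                query=query, valid=False, reason="Deactivated in favour of cached_http"
            )
            if is_zenodo_url(query)
            else _orig_valid_query(query)
        )
    )
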
# Define settings for the Zenodo storage plugin
# NB: We derive from SettingsBase rather than StorageProviderSettingsBase to remove the
# unsupported max_requests_per_second option
@@ -115,14 +98,17 @@ class ZenodoFileMetadata:


class WrongChecksum(Exception):
+    observed: str
+    expected: str
+
    def __init__(self, observed: str, expected: str):
        self.observed = observed
        self.expected = expected
        super().__init__(f"Checksum mismatch: expected {expected}, got {observed}")


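Illustrative use of the exception (hash values made up):

    try:
        raise WrongChecksum(observed="9a0364b9e99b...", expected="d41d8cd98f00...")
    except WrongChecksum as err:
        print(f"expected {err.expected}, got {err.observed}")
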
retry_decorator = retry(
-    exceptions=(
+    exceptions=(  # pyright: ignore[reportArgumentType]
        httpx.HTTPError,
        TimeoutError,
        OSError,
@@ -132,21 +118,22 @@ def __init__(self, observed: str, expected: str):
    tries=5,
    delay=3,
    backoff=2,
-    logger=ReretryLoggerAdapter(get_logger()),
+    logger=ReretryLoggerAdapter(get_logger()),  # pyright: ignore[reportArgumentType]
)

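A callable wrapped with this decorator is retried up to five times with exponential backoff (3 s, 6 s, 12 s, ...) on the listed exception types. A hedged usage sketch, assuming reretry's support for async callables (the function itself is illustrative, not part of the plugin):

    @retry_decorator
    async def fetch_json(client: httpx.AsyncClient, url: str) -> dict:
        response = await client.get(url)
        response.raise_for_status()  # raises an httpx.HTTPError subtype, triggering a retry
        return response.json()
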

# Implementation of storage provider
class StorageProvider(StorageProviderBase):
+    settings: StorageProviderSettings
+    cache: Cache | None
+
    def __post_init__(self):
        super().__post_init__()

-        # Set up cache directory
-        if self.settings.cache:
-            self.cache_dir = Path(self.settings.cache).expanduser()
-            self.cache_dir.mkdir(exist_ok=True, parents=True)
-        else:
-            self.cache_dir = None
+        # Set up cache
+        self.cache = (
+            Cache(cache_dir=Path(self.settings.cache)) if self.settings.cache else None
+        )

        # Initialize shared client for bounding connections and pipelining
        self._client: httpx.AsyncClient | None = None
@@ -155,16 +142,20 @@ def __post_init__(self):
        # Cache for record metadata to avoid repeated API calls
        self._record_cache: dict[str, dict[str, ZenodoFileMetadata]] = {}

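The `Cache` implementation lives in the sibling `.cache` module and is not shown in this diff; from its use here it needs a `cache_dir` constructor argument, a `get(key) -> Path | None`, and a `put(key, path)`. A hypothetical minimal version consistent with that usage:

    # cache.py -- hypothetical sketch inferred from usage in this file
    import hashlib
    import shutil
    from dataclasses import dataclass
    from pathlib import Path


    @dataclass
    class Cache:
        cache_dir: Path

        def _path_for(self, key: str) -> Path:
            # Hash the query so arbitrary URLs map to safe file names (assumption).
            return self.cache_dir / hashlib.sha256(key.encode()).hexdigest()

        def get(self, key: str) -> Path | None:
            path = self._path_for(key)
            return path if path.exists() else None

        def put(self, key: str, src: Path) -> None:
            self.cache_dir.mkdir(parents=True, exist_ok=True)
            shutil.copy2(src, self._path_for(key))
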
+    @override
    def use_rate_limiter(self) -> bool:
        """Return False if no rate limiting is needed for this provider."""
        return False

-    def rate_limiter_key(self, query: str, operation: Operation) -> Any:
+    @override
+    def rate_limiter_key(self, query: str, operation: Operation) -> str:
        raise NotImplementedError()

+    @override
    def default_max_requests_per_second(self) -> float:
        raise NotImplementedError()

+    @override
    @classmethod
    def example_queries(cls) -> list[ExampleQuery]:
        """Return an example query with description for this storage provider."""
@@ -176,6 +167,7 @@ def example_queries(cls) -> list[ExampleQuery]:
            )
        ]

+    @override
    @classmethod
    def is_valid_query(cls, query: str) -> StorageQueryValidationResult:
        """Only handle zenodo.org URLs"""
@@ -188,13 +180,11 @@ def is_valid_query(cls, query: str) -> StorageQueryValidationResult:
                reason="Not a Zenodo URL (only zenodo.org URLs are handled by this plugin)",
            )

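Validation is purely syntactic, so it can be exercised without network access (record ID made up):

    >>> StorageProvider.is_valid_query("https://zenodo.org/records/1234567/files/data.csv").valid
    True
    >>> StorageProvider.is_valid_query("https://example.org/records/1234567/files/data.csv").valid
    False
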
+    @override
    @classmethod
    def get_storage_object_cls(cls):
        return StorageObject

-    def list_objects(self, query: Any) -> Iterable[str]:
-        raise NotImplementedError()
-
    @asynccontextmanager
    async def client(self):
        """
@@ -231,7 +221,7 @@ async def client(self):
            await self._client.aclose()
            self._client = None

-    def _get_rate_limit_wait_time(self, headers) -> float | None:
+    def _get_rate_limit_wait_time(self, headers: httpx.Headers) -> float | None:
        """
        Calculate wait time based on rate limit headers.

@@ -306,12 +296,12 @@ async def get_metadata(
        data = json.loads(content)

        # Parse files array and build metadata dict
-        metadata = {}
+        metadata: dict[str, ZenodoFileMetadata] = {}
        files = data.get("files", [])
        for file_info in files:
-            filename = file_info.get("key")
-            checksum = file_info.get("checksum")
-            size = file_info.get("size", 0)
+            filename: str | None = file_info.get("key")
+            checksum: str | None = file_info.get("checksum")
+            size: int = file_info.get("size", 0)

            if not filename:
                continue
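For reference, the relevant slice of a Zenodo record API response looks roughly like this (abridged, values illustrative); the loop above folds it into the `metadata` dict keyed by filename:

    data = {
        "files": [
            {"key": "data.csv", "checksum": "md5:9a0364b9e99bb480dd25e1f0284c8555", "size": 1048576},
            {"key": "README.md", "checksum": "md5:6f1ed002ab5595859014ebf0951522d9", "size": 2048},
        ]
    }
    # metadata == {"data.csv": ZenodoFileMetadata(...), "README.md": ZenodoFileMetadata(...)}
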
@@ -326,13 +316,13 @@ async def get_metadata(

# Implementation of storage object
class StorageObject(StorageObjectRead):
+    provider: StorageProvider  # pyright: ignore[reportIncompatibleVariableOverride]
+    record_id: str
+    filename: str
+    netloc: str
+
    def __post_init__(self):
        super().__post_init__()
-        if self.provider.cache_dir is not None:
-            self.query_path: Path = self.provider.cache_dir / self.local_suffix()
-            self.query_path.parent.mkdir(exist_ok=True, parents=True)
-        else:
-            self.query_path = None

        # Parse URL to extract record ID and filename
        # URL format: https://zenodo.org/records/{record_id}/files/{filename}
@@ -344,49 +334,56 @@ def __post_init__(self):
        if _records != "records" or _files != "files":
            raise WorkflowError(
                f"Invalid Zenodo URL format: {self.query}. "
-                "Expected format: https://zenodo.org/records/{{record_id}}/files/{{filename}}"
+                f"Expected format: https://zenodo.org/records/{{record_id}}/files/{{filename}}"
            )

        self.record_id = record_id
        self.filename = filename
        self.netloc = parsed.netloc

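The unpacking above relies on the fixed shape of Zenodo file URLs; with a made-up record:

    >>> urlparse("https://zenodo.org/records/1234567/files/data.csv").path.split("/")
    ['', 'records', '1234567', 'files', 'data.csv']
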
-    def local_suffix(self):
-        parsed = urlparse(self.query)
+    @override
+    def local_suffix(self) -> str:
+        """Return the local suffix for this object (used by parent class)."""
+        parsed = urlparse(str(self.query))
        return f"{parsed.netloc}{parsed.path}"

+    @override
    def get_inventory_parent(self) -> str | None:
        """Return the parent directory of this object."""
        # this is optional and can be left as is
        return None

+    @override
    async def managed_exists(self) -> bool:
        if self.provider.settings.skip_remote_checks:
            return True

-        exists = self.query_path and self.query_path.exists()
-        if exists:
-            return True
+        if self.provider.cache:
+            cached = self.provider.cache.get(str(self.query))
+            if cached is not None:
+                return True

        metadata = await self.provider.get_metadata(self.record_id, self.netloc)
        return self.filename in metadata

+    @override
    async def managed_mtime(self) -> float:
        return 0

+    @override
    async def managed_size(self) -> int:
        if self.provider.settings.skip_remote_checks:
            return 0

-        exists = self.query_path and self.query_path.exists()
-        if exists:
-            return self.query_path.stat().st_size
+        if self.provider.cache:
+            cached = self.provider.cache.get(str(self.query))
+            if cached is not None:
+                return cached.stat().st_size

        metadata = await self.provider.get_metadata(self.record_id, self.netloc)
        return metadata[self.filename].size if self.filename in metadata else 0

-    managed_local_footprint = managed_size
-
+    @override
    async def inventory(self, cache: IOCacheStorageInterface) -> None:
        """
        Gather file metadata (existence, size) from cache or remote.
@@ -399,38 +396,44 @@ async def inventory(self, cache: IOCacheStorageInterface) -> None:

        if self.provider.settings.skip_remote_checks:
            cache.exists_in_storage[key] = True
-            cache.mtime[key] = 0
+            cache.mtime[key] = Mtime(storage=0)
            cache.size[key] = 0
            return

-        exists = self.query_path and self.query_path.exists()
-        if exists:
-            cache.exists_in_storage[key] = exists
-            cache.mtime[key] = 0
-            cache.size[key] = self.query_path.stat().st_size
-            return
+        if self.provider.cache:
+            cached = self.provider.cache.get(str(self.query))
+            if cached is not None:
+                cache.exists_in_storage[key] = True
+                cache.mtime[key] = Mtime(storage=0)
+                cache.size[key] = cached.stat().st_size
+                return

        metadata = await self.provider.get_metadata(self.record_id, self.netloc)
        exists = self.filename in metadata
        cache.exists_in_storage[key] = exists
        cache.mtime[key] = Mtime(storage=0)
        cache.size[key] = metadata[self.filename].size if exists else 0

+    @override
    def cleanup(self):
        """Nothing to cleanup"""
        pass

-    def exists(self):
+    @override
+    def exists(self) -> bool:
        raise NotImplementedError()

-    def size(self):
+    @override
+    def size(self) -> int:
        raise NotImplementedError()

-    def mtime(self):
+    @override
+    def mtime(self) -> float:
        raise NotImplementedError()

-    def retrieve_object(self):
-        return NotImplementedError()
+    @override
+    def retrieve_object(self) -> None:
+        raise NotImplementedError()

    async def verify_checksum(self, path: Path) -> None:
        """
@@ -453,7 +456,7 @@ async def verify_checksum(self, path: Path) -> None:
        digest, checksum_expected = checksum.split(":", maxsplit=1)

        # Compute checksum asynchronously (hashlib releases GIL)
-        def compute_hash(digest=digest):
+        def compute_hash(digest: str = digest):
            with open(path, "rb") as f:
                return hashlib.file_digest(f, digest).hexdigest().lower()

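Zenodo reports checksums as "<algorithm>:<hexdigest>" (typically md5), so the split yields exactly the digest name that `hashlib.file_digest` accepts (illustrative value):

    >>> "md5:9a0364b9e99bb480dd25e1f0284c8555".split(":", maxsplit=1)
    ['md5', '9a0364b9e99bb480dd25e1f0284c8555']
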
@@ -469,13 +472,14 @@ async def managed_retrieve(self):
        local_path.parent.mkdir(parents=True, exist_ok=True)

        # If already in cache, just copy
-        if self.query_path and self.query_path.exists():
-            # Verify cached file checksum
-            logger.info(
-                f"Retrieved {self.filename} of zenodo record {self.record_id} from cache"
-            )
-            shutil.copy2(self.query_path, local_path)
-            return
+        if self.provider.cache:
+            cached = self.provider.cache.get(str(self.query))
+            if cached is not None:
+                logger.info(
+                    f"Retrieved {self.filename} of zenodo record {self.record_id} from cache"
+                )
+                shutil.copy2(cached, local_path)
+                return

        try:
            # Download from Zenodo using a get request, rate limit errors are detected and
@@ -505,10 +509,8 @@ async def managed_retrieve(self):
            await self.verify_checksum(local_path)

            # Copy to cache after successful verification
-            if self.query_path:
-                self.query_path.parent.mkdir(parents=True, exist_ok=True)
-                shutil.copy2(local_path, self.query_path)
-                logger.info(f"Cached {self.filename} to {self.provider.cache_dir}")
+            if self.provider.cache:
+                self.provider.cache.put(str(self.query), local_path)

        except:
            if local_path.exists():