Add CacheMapper to map from remote URL to local cached basename (#1296)
* Add CacheMapper to map from remote URL to local cached basename

* Raise exception if 'fn' not in cached metadata

* Fix tests on Windows
ianthomas23 authored Jul 26, 2023
1 parent 285094f commit aacc5f2
Showing 4 changed files with 157 additions and 29 deletions.
57 changes: 57 additions & 0 deletions fsspec/implementations/cache_mapper.py
@@ -0,0 +1,57 @@
from __future__ import annotations

import abc
import hashlib
import os
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing import Any


class AbstractCacheMapper(abc.ABC):
    """Abstract super-class for mappers from remote URLs to local cached
    basenames.
    """

    @abc.abstractmethod
    def __call__(self, path: str) -> str:
        ...

    def __eq__(self, other: Any) -> bool:
        # Identity only depends on class. When derived classes have attributes
        # they will need to be included.
        return isinstance(other, type(self))

    def __hash__(self) -> int:
        # Identity only depends on class. When derived classes have attributes
        # they will need to be included.
        return hash(type(self))


class BasenameCacheMapper(AbstractCacheMapper):
    """Cache mapper that uses the basename of the remote URL.
    Different paths with the same basename will therefore have the same cached
    basename.
    """

    def __call__(self, path: str) -> str:
        return os.path.basename(path)


class HashCacheMapper(AbstractCacheMapper):
    """Cache mapper that uses a hash of the remote URL."""

    def __call__(self, path: str) -> str:
        return hashlib.sha256(path.encode()).hexdigest()


def create_cache_mapper(same_names: bool) -> AbstractCacheMapper:
    """Factory method to create cache mapper for backward compatibility with
    ``CachingFileSystem`` constructor using ``same_names`` kwarg.
    """
    if same_names:
        return BasenameCacheMapper()
    else:
        return HashCacheMapper()
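
The factory reproduces the old same_names behaviour through the two concrete mappers. A short sketch of the mapping each one produces (the example paths and digests below are the same ones exercised by the new test_mapper test added further down):

    from fsspec.implementations.cache_mapper import create_cache_mapper

    # same_names=True -> BasenameCacheMapper: only the basename survives,
    # so distinct remote paths with the same final component collide.
    mapper = create_cache_mapper(True)
    assert mapper("/somedir/somefile") == "somefile"
    assert mapper("/otherdir/somefile") == "somefile"

    # same_names=False -> HashCacheMapper: the full path is hashed, so
    # every distinct remote path gets a unique local basename.
    mapper = create_cache_mapper(False)
    assert (
        mapper("/somedir/somefile")
        == "67a6956e5a5f95231263f03758c1fd9254fdb1c564d311674cec56b0372d2056"
    )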
49 changes: 21 additions & 28 deletions fsspec/implementations/cached.py
@@ -1,21 +1,21 @@
from __future__ import annotations

import contextlib
-import hashlib
import inspect
import logging
import os
import pickle
import tempfile
import time
from shutil import rmtree
-from typing import ClassVar
+from typing import Any, ClassVar

from fsspec import AbstractFileSystem, filesystem
from fsspec.callbacks import _DEFAULT_CALLBACK
from fsspec.compression import compr
from fsspec.core import BaseCache, MMapCache
from fsspec.exceptions import BlocksizeMismatchError
+from fsspec.implementations.cache_mapper import create_cache_mapper
from fsspec.spec import AbstractBufferedFile
from fsspec.utils import infer_compression

@@ -115,9 +115,7 @@ def __init__(
        self.check_files = check_files
        self.expiry = expiry_time
        self.compression = compression
-        # TODO: same_names should allow for variable prefix, not only
-        # to keep the basename
-        self.same_names = same_names
+        self._mapper = create_cache_mapper(same_names)
        self.target_protocol = (
            target_protocol
            if isinstance(target_protocol, str)
@@ -255,11 +253,12 @@ def clear_expired_cache(self, expiry_time=None):

        for path, detail in self.cached_files[-1].copy().items():
            if time.time() - detail["time"] > expiry_time:
-                if self.same_names:
-                    basename = os.path.basename(detail["original"])
-                    fn = os.path.join(self.storage[-1], basename)
-                else:
-                    fn = os.path.join(self.storage[-1], detail["fn"])
+                fn = detail.get("fn", "")
+                if not fn:
+                    raise RuntimeError(
+                        f"Cache metadata does not contain 'fn' for {path}"
+                    )
+                fn = os.path.join(self.storage[-1], fn)
                if os.path.exists(fn):
                    os.remove(fn)
                self.cached_files[-1].pop(path)
@@ -339,7 +338,7 @@ def _open(
            # TODO: action where partial file exists in read-only cache
            logger.debug("Opening partially cached copy of %s" % path)
        else:
-            hash = self.hash_name(path, self.same_names)
+            hash = self._mapper(path)
            fn = os.path.join(self.storage[-1], hash)
            blocks = set()
            detail = {
@@ -385,8 +384,10 @@ def _open(
            self.save_cache()
        return f

-    def hash_name(self, path, same_name):
-        return hash_name(path, same_name=same_name)
+    def hash_name(self, path: str, *args: Any) -> str:
+        # Kept for backward compatibility with downstream libraries.
+        # Ignores extra arguments, previously same_name boolean.
+        return self._mapper(path)

    def close_and_update(self, f, close):
        """Called when a file is closing, so store the set of blocks"""
@@ -488,7 +489,7 @@ def __eq__(self, other):
            and self.check_files == other.check_files
            and self.expiry == other.expiry
            and self.compression == other.compression
-            and self.same_names == other.same_names
+            and self._mapper == other._mapper
            and self.target_protocol == other.target_protocol
        )

@@ -501,7 +502,7 @@ def __hash__(self):
            ^ hash(self.check_files)
            ^ hash(self.expiry)
            ^ hash(self.compression)
-            ^ hash(self.same_names)
+            ^ hash(self._mapper)
            ^ hash(self.target_protocol)
        )

@@ -546,7 +547,7 @@ def open_many(self, open_files):
        details = [self._check_file(sp) for sp in paths]
        downpath = [p for p, d in zip(paths, details) if not d]
        downfn0 = [
-            os.path.join(self.storage[-1], self.hash_name(p, self.same_names))
+            os.path.join(self.storage[-1], self._mapper(p))
            for p, d in zip(paths, details)
        ]  # keep these path names for opening later
        downfn = [fn for fn, d in zip(downfn0, details) if not d]
@@ -558,7 +559,7 @@ def open_many(self, open_files):
        newdetail = [
            {
                "original": path,
-                "fn": self.hash_name(path, self.same_names),
+                "fn": self._mapper(path),
                "blocks": True,
                "time": time.time(),
                "uid": self.fs.ukey(path),
@@ -590,7 +591,7 @@ def commit_many(self, open_files):
                pass

    def _make_local_details(self, path):
-        hash = self.hash_name(path, self.same_names)
+        hash = self._mapper(path)
        fn = os.path.join(self.storage[-1], hash)
        detail = {
            "original": path,
@@ -731,7 +732,7 @@ def __init__(self, **kwargs):

    def _check_file(self, path):
        self._check_cache()
-        sha = self.hash_name(path, self.same_names)
+        sha = self._mapper(path)
        for storage in self.storage:
            fn = os.path.join(storage, sha)
            if os.path.exists(fn):
@@ -752,7 +753,7 @@ def _open(self, path, mode="rb", **kwargs):
        if fn:
            return open(fn, mode)

-        sha = self.hash_name(path, self.same_names)
+        sha = self._mapper(path)
        fn = os.path.join(self.storage[-1], sha)
        logger.debug("Copying %s to local cache" % path)
        kwargs["mode"] = mode
@@ -838,14 +839,6 @@ def __getattr__(self, item):
        return getattr(self.fh, item)


-def hash_name(path, same_name):
-    if same_name:
-        hash = os.path.basename(path)
-    else:
-        hash = hashlib.sha256(path.encode()).hexdigest()
-    return hash
-
-
@contextlib.contextmanager
def atomic_write(path, mode="wb"):
    """
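
The TODO removed above ("same_names should allow for variable prefix, not only to keep the basename") points at what the new abstraction enables: any AbstractCacheMapper subclass can define its own local naming scheme. As an illustration only (this class is hypothetical, not part of the commit, and the commit wires mappers in solely via the same_names kwarg), a mapper combining both built-in behaviours might look like:

    import hashlib
    import os

    from fsspec.implementations.cache_mapper import AbstractCacheMapper


    class PrefixedBasenameCacheMapper(AbstractCacheMapper):
        # Keep the human-readable basename, but prefix it with a short hash
        # of the full path so that equal basenames cannot collide.
        def __call__(self, path: str) -> str:
            prefix = hashlib.sha256(path.encode()).hexdigest()[:8]
            return f"{prefix}_{os.path.basename(path)}"

Because this subclass adds no attributes, the identity-by-class __eq__ and __hash__ inherited from AbstractCacheMapper remain valid as-is.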
79 changes: 78 additions & 1 deletion fsspec/implementations/tests/test_cached.py
@@ -8,7 +8,9 @@
import fsspec
from fsspec.compression import compr
from fsspec.exceptions import BlocksizeMismatchError
from fsspec.implementations.cache_mapper import create_cache_mapper
from fsspec.implementations.cached import CachingFileSystem, LocalTempFile
from fsspec.implementations.local import make_path_posix

from .test_ftp import FTPFileSystem

@@ -32,6 +34,61 @@ def local_filecache():
    return data, original_file, cache_location, fs


def test_mapper():
    mapper0 = create_cache_mapper(True)
    assert mapper0("/somedir/somefile") == "somefile"
    assert mapper0("/otherdir/somefile") == "somefile"

    mapper1 = create_cache_mapper(False)
    assert (
        mapper1("/somedir/somefile")
        == "67a6956e5a5f95231263f03758c1fd9254fdb1c564d311674cec56b0372d2056"
    )
    assert (
        mapper1("/otherdir/somefile")
        == "f043dee01ab9b752c7f2ecaeb1a5e1b2d872018e2d0a1a26c43835ebf34e7d3e"
    )

    assert mapper0 != mapper1
    assert create_cache_mapper(True) == mapper0
    assert create_cache_mapper(False) == mapper1

    assert hash(mapper0) != hash(mapper1)
    assert hash(create_cache_mapper(True)) == hash(mapper0)
    assert hash(create_cache_mapper(False)) == hash(mapper1)


@pytest.mark.parametrize("same_names", [False, True])
def test_metadata(tmpdir, same_names):
    source = os.path.join(tmpdir, "source")
    afile = os.path.join(source, "afile")
    os.mkdir(source)
    open(afile, "w").write("test")

    fs = fsspec.filesystem(
        "filecache",
        target_protocol="file",
        cache_storage=os.path.join(tmpdir, "cache"),
        same_names=same_names,
    )

    with fs.open(afile, "rb") as f:
        assert f.read(5) == b"test"

    afile_posix = make_path_posix(afile)
    detail = fs.cached_files[0][afile_posix]
    assert sorted(detail.keys()) == ["blocks", "fn", "original", "time", "uid"]
    assert isinstance(detail["blocks"], bool)
    assert isinstance(detail["fn"], str)
    assert isinstance(detail["time"], float)
    assert isinstance(detail["uid"], str)

    assert detail["original"] == afile_posix
    assert detail["fn"] == fs._mapper(afile_posix)
    if same_names:
        assert detail["fn"] == "afile"


def test_idempotent():
    fs = CachingFileSystem("file")
    fs2 = CachingFileSystem("file")
@@ -154,7 +211,7 @@ def test_clear():


def test_clear_expired(tmp_path):
-    def __ager(cache_fn, fn):
+    def __ager(cache_fn, fn, del_fn=False):
        """
        Modify the cache file to virtually add time lag to selected files.
@@ -164,6 +221,8 @@ def __ager(cache_fn, fn):
            cache path
        fn: str
            file name to be modified
+        del_fn: bool
+            whether or not to delete 'fn' from cache details
        """
        import pathlib
        import time
@@ -174,6 +233,8 @@ def __ager(cache_fn, fn):
        fn_posix = pathlib.Path(fn).as_posix()
        cached_files[fn_posix]["time"] = cached_files[fn_posix]["time"] - 691200
        assert os.access(cache_fn, os.W_OK), "Cache is not writable"
+        if del_fn:
+            del cached_files[fn_posix]["fn"]
        with open(cache_fn, "wb") as f:
            pickle.dump(cached_files, f)
        time.sleep(1)
@@ -255,6 +316,22 @@ def __ager(cache_fn, fn):
    fs.clear_expired_cache()
    assert not fs._check_file(str(f4))

    # check cache metadata lacking 'fn' raises RuntimeError.
    fs = fsspec.filesystem(
        "filecache",
        target_protocol="file",
        cache_storage=str(cache1),
        same_names=True,
        cache_check=1,
    )
    assert fs.cat(str(f1)) == data

    cache_fn = os.path.join(fs.storage[-1], "cache")
    __ager(cache_fn, f1, del_fn=True)

    with pytest.raises(RuntimeError, match="Cache metadata does not contain 'fn' for"):
        fs.clear_expired_cache()


def test_pop():
    import tempfile
1 change: 1 addition & 0 deletions fsspec/tests/test_api.py
@@ -308,6 +308,7 @@ def test_chained_equivalent():
    # since the parameters don't quite match. Also, the url understood by the two
    # of s are not the same (path gets munged a bit differently)
    assert of.fs == of2.fs
    assert hash(of.fs) == hash(of2.fs)
    assert of.open().read() == of2.open().read()


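
The added hash assertion holds because CachingFileSystem.__hash__ now folds in hash(self._mapper) instead of hash(self.same_names), and the mappers define identity purely by class. A small sketch of the resulting semantics (assuming default cache storage):

    import fsspec

    fs1 = fsspec.filesystem("filecache", target_protocol="file", same_names=True)
    fs2 = fsspec.filesystem("filecache", target_protocol="file", same_names=True)
    fs3 = fsspec.filesystem("filecache", target_protocol="file", same_names=False)

    # Equality and hashing delegate to the mapper rather than a raw boolean.
    assert fs1 == fs2 and hash(fs1) == hash(fs2)
    assert fs1 != fs3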
