Skip to content

Commit

Permalink
skip caching for exceptions (#233)
Browse files Browse the repository at this point in the history
* adding test

* _print

* cleaning

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* with completed

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Borda and pre-commit-ci[bot] authored Oct 18, 2024
1 parent 731b66e commit 67bcf72
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 76 deletions.
5 changes: 3 additions & 2 deletions src/cachier/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ class CacheEntry:
value: Any
time: datetime
stale: bool
being_calculated: bool
condition: Optional[threading.Condition] = None
_processing: bool
_condition: Optional[threading.Condition] = None
_completed: bool = False


def _update_with_defaults(
Expand Down
23 changes: 11 additions & 12 deletions src/cachier/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from functools import wraps
from typing import Optional, Union
from typing import Any, Optional, Union
from warnings import warn

from .config import (
Expand Down Expand Up @@ -55,13 +55,11 @@ def _function_thread(core, key, func, args, kwds):
print(f"Function call failed with the following exception:\n{exc}")


def _calc_entry(core, key, func, args, kwds):
def _calc_entry(core, key, func, args, kwds) -> Optional[Any]:
core.mark_entry_being_calculated(key)
try:
core.mark_entry_being_calculated(key)
# _get_executor().submit(core.mark_entry_being_calculated, key)
func_res = func(*args, **kwds)
core.set_entry(key, func_res)
# _get_executor().submit(core.set_entry, key, func_res)
return func_res
finally:
core.mark_entry_not_calculated(key)
Expand Down Expand Up @@ -242,9 +240,8 @@ def func_wrapper(*args, **kwds):
func, _is_method=core.func_is_method, args=args, kwds=kwds
)

_print = lambda x: None # noqa: E731
if verbose:
_print = print
_print = print if verbose else lambda x: None

if ignore_cache or not _global_params.caching_enabled:
return (
func(args[0], **kwargs)
Expand All @@ -254,7 +251,9 @@ def func_wrapper(*args, **kwds):
key, entry = core.get_entry((), kwargs)
if overwrite_cache:
return _calc_entry(core, key, func, args, kwds)
if entry is None:
if entry is None or (
not entry._completed and not entry._processing
):
_print("No entry found. No current calc. Calling like a boss.")
return _calc_entry(core, key, func, args, kwds)
_print("Entry found.")
Expand All @@ -265,7 +264,7 @@ def func_wrapper(*args, **kwds):
_print("And it is fresh!")
return entry.value
_print("But it is stale... :(")
if entry.being_calculated:
if entry._processing:
if _next_time:
_print("Returning stale.")
return entry.value # return stale val
Expand All @@ -276,8 +275,8 @@ def func_wrapper(*args, **kwds):
return _calc_entry(core, key, func, args, kwds)
if _next_time:
_print("Async calc and return stale")
core.mark_entry_being_calculated(key)
try:
core.mark_entry_being_calculated(key)
_get_executor().submit(
_function_thread, core, key, func, args, kwds
)
Expand All @@ -286,7 +285,7 @@ def func_wrapper(*args, **kwds):
return entry.value
_print("Calling decorated function and waiting")
return _calc_entry(core, key, func, args, kwds)
if entry.being_calculated:
if entry._processing:
_print("No value but being calculated. Waiting.")
try:
return core.wait_on_entry_calc(key)
Expand Down
54 changes: 29 additions & 25 deletions src/cachier/cores/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,61 +26,65 @@ def get_entry_by_key(
return key, self.cache.get(self._hash_func_key(key), None)

def set_entry(self, key: str, func_res: Any) -> None:
hash_key = self._hash_func_key(key)
with self.lock:
try:
# we need to retain the existing condition so that
# mark_entry_not_calculated can notify all possibly-waiting
# threads about it
cond = self.cache[self._hash_func_key(key)].condition
cond = self.cache[hash_key]._condition
except KeyError: # pragma: no cover
cond = None
self.cache[self._hash_func_key(key)] = CacheEntry(
self.cache[hash_key] = CacheEntry(
value=func_res,
time=datetime.now(),
stale=False,
being_calculated=False,
condition=cond,
_processing=False,
_condition=cond,
_completed=True,
)

def mark_entry_being_calculated(self, key: str) -> None:
with self.lock:
condition = threading.Condition()
hash_key = self._hash_func_key(key)
if hash_key in self.cache:
self.cache[hash_key]._processing = True
self.cache[hash_key]._condition = condition
# condition.acquire()
try:
self.cache[self._hash_func_key(key)].being_calculated = True
self.cache[self._hash_func_key(key)].condition = condition
except KeyError:
self.cache[self._hash_func_key(key)] = CacheEntry(
else:
self.cache[hash_key] = CacheEntry(
value=None,
time=datetime.now(),
stale=False,
being_calculated=True,
condition=condition,
_processing=True,
_condition=condition,
)

def mark_entry_not_calculated(self, key: str) -> None:
hash_key = self._hash_func_key(key)
with self.lock:
try:
entry = self.cache[self._hash_func_key(key)]
except KeyError: # pragma: no cover
if hash_key not in self.cache:
return # that's ok, we don't need an entry in that case
entry.being_calculated = False
cond = entry.condition
entry = self.cache[hash_key]
entry._processing = False
cond = entry._condition
if cond:
cond.acquire()
cond.notify_all()
cond.release()
entry.condition = None
entry._condition = None

def wait_on_entry_calc(self, key: str) -> Any:
hash_key = self._hash_func_key(key)
with self.lock: # pragma: no cover
entry = self.cache[self._hash_func_key(key)]
if not entry.being_calculated:
entry = self.cache[hash_key]
if not entry._processing:
return entry.value
entry.condition.acquire()
entry.condition.wait()
entry.condition.release()
return self.cache[self._hash_func_key(key)].value
entry._condition.acquire()
entry._condition.wait()
entry._condition.release()
return self.cache[hash_key].value

def clear_cache(self) -> None:
with self.lock:
Expand All @@ -89,5 +93,5 @@ def clear_cache(self) -> None:
def clear_being_calculated(self) -> None:
with self.lock:
for entry in self.cache.values():
entry.being_calculated = False
entry.condition = None
entry._processing = False
entry._condition = None
35 changes: 15 additions & 20 deletions src/cachier/cores/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,20 +73,14 @@ def get_entry_by_key(self, key: str) -> Tuple[str, Optional[CacheEntry]]:
)
if not res:
return key, None
try:
entry = CacheEntry(
value=pickle.loads(res["value"]), # noqa: S301
time=res.get("time", None),
stale=res.get("stale", False),
being_calculated=res.get("being_calculated", False),
)
except KeyError:
entry = CacheEntry(
value=None,
time=res.get("time", None),
stale=res.get("stale", False),
being_calculated=res.get("being_calculated", False),
)
val = pickle.loads(res["value"]) if "value" in res else None # noqa: S301
entry = CacheEntry(
value=val,
time=res.get("time", None),
stale=res.get("stale", False),
_processing=res.get("processing", False),
_completed=res.get("completed", False),
)
return key, entry

def set_entry(self, key: str, func_res: Any) -> None:
Expand All @@ -100,7 +94,8 @@ def set_entry(self, key: str, func_res: Any) -> None:
"value": Binary(thebytes),
"time": datetime.now(),
"stale": False,
"being_calculated": False,
"processing": False,
"completed": True,
}
},
upsert=True,
Expand All @@ -109,7 +104,7 @@ def set_entry(self, key: str, func_res: Any) -> None:
def mark_entry_being_calculated(self, key: str) -> None:
self.mongo_collection.update_one(
filter={"func": self._func_str, "key": key},
update={"$set": {"being_calculated": True}},
update={"$set": {"processing": True}},
upsert=True,
)

Expand All @@ -120,7 +115,7 @@ def mark_entry_not_calculated(self, key: str) -> None:
"func": self._func_str,
"key": key,
},
update={"$set": {"being_calculated": False}},
update={"$set": {"processing": False}},
upsert=False, # should not insert in this case
)

Expand All @@ -132,7 +127,7 @@ def wait_on_entry_calc(self, key: str) -> Any:
key, entry = self.get_entry_by_key(key)
if entry is None:
raise RecalculationNeeded()
if not entry.being_calculated:
if not entry._processing:
return entry.value
self.check_calc_timeout(time_spent)

Expand All @@ -143,7 +138,7 @@ def clear_being_calculated(self) -> None:
self.mongo_collection.update_many(
filter={
"func": self._func_str,
"being_calculated": True,
"processing": True,
},
update={"$set": {"being_calculated": False}},
update={"$set": {"processing": False}},
)
28 changes: 13 additions & 15 deletions src/cachier/cores/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,14 @@ def inject_observer(self, observer) -> None:
self.observer = observer

def _check_calculation(self) -> None:
# print('checking calc')
entry = self.core.get_entry_by_key(self.key, True)[1]
# print(self.key)
# print(entry)
try:
if not entry.being_calculated:
if not entry._processing:
# print('stopping observer!')
self.value = entry.value
self.observer.stop()
# else:
# print('NOT stopping observer... :(')
# print('NOT stopping observer... :(')
except TypeError:
self.value = None
self.observer.stop()
Expand Down Expand Up @@ -169,7 +166,8 @@ def set_entry(self, key: str, func_res: Any) -> None:
value=func_res,
time=datetime.now(),
stale=False,
being_calculated=False,
_processing=False,
_completed=True,
)
if self.separate_files:
self._save_cache(key_data, key)
Expand All @@ -186,14 +184,14 @@ def mark_entry_being_calculated_separate_files(self, key: str) -> None:
value=None,
time=datetime.now(),
stale=False,
being_calculated=True,
_processing=True,
),
key=key,
)

def mark_entry_not_calculated_separate_files(self, key: str) -> None:
_, entry = self.get_entry_by_key(key)
entry.being_calculated = False
entry._processing = False
self._save_cache(entry, key=key)

def mark_entry_being_calculated(self, key: str) -> None:
Expand All @@ -203,14 +201,14 @@ def mark_entry_being_calculated(self, key: str) -> None:

with self.lock:
cache = self._get_cache()
try:
cache[key].being_calculated = True
except KeyError:
if key in cache:
cache[key]._processing = True
else:
cache[key] = CacheEntry(
value=None,
time=datetime.now(),
stale=False,
being_calculated=True,
_processing=True,
)
self._save_cache(cache)

Expand All @@ -221,7 +219,7 @@ def mark_entry_not_calculated(self, key: str) -> None:
cache = self._get_cache()
# that's ok, we don't need an entry in that case
if isinstance(cache, dict) and key in cache:
cache[key].being_calculated = False
cache[key]._processing = False
self._save_cache(cache)

def wait_on_entry_calc(self, key: str) -> Any:
Expand All @@ -233,7 +231,7 @@ def wait_on_entry_calc(self, key: str) -> Any:
self._reload_cache()
entry = self._get_cache()[key]
filename = self.cache_fname
if not entry.being_calculated:
if not entry._processing:
return entry.value
event_handler = _PickleCore.CacheChangeHandler(
filename=filename, core=self, key=key
Expand Down Expand Up @@ -263,5 +261,5 @@ def clear_being_calculated(self) -> None:
with self.lock:
cache = self._get_cache()
for key in cache:
cache[key].being_calculated = False
cache[key]._processing = False
self._save_cache(cache)
12 changes: 12 additions & 0 deletions tests/test_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,3 +492,15 @@ def fn_minus(a, b=2):

assert count_p == 1
assert count_m == 1


@pytest.mark.parametrize("backend", ["memory", "pickle"])
def test_raise_exception(tmpdir, backend: str):
    """Check that a function which always raises keeps raising when called again.

    The decorated function is invoked twice with the same argument; each
    invocation is required to raise ``RuntimeError``, so the second call
    must not be answered with a stored value.
    """

    @cachier.cachier(cache_dir=tmpdir, backend=backend, allow_none=True)
    def tmp_test(_):
        raise RuntimeError("always raise")

    # Same argument both times, so the second attempt hits the same cache key.
    for _attempt in range(2):
        with pytest.raises(RuntimeError):
            tmp_test(123)
4 changes: 2 additions & 2 deletions tests/test_mongo_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def _stalled_func():
def test_stalled_mong_db_core(monkeypatch):
def mock_get_entry(self, args, kwargs):
return "key", CacheEntry(
being_calculated=True, value=None, time=None, stale=None
_processing=True, value=None, time=None, stale=None
)

def mock_get_entry_by_key(self, key):
Expand All @@ -300,7 +300,7 @@ def mock_get_entry_2(self, args, kwargs):
return "key", CacheEntry(
value=1,
time=datetime.datetime.now() - datetime.timedelta(seconds=10),
being_calculated=True,
_processing=True,
stale=None,
)

Expand Down

0 comments on commit 67bcf72

Please sign in to comment.