diff --git a/src/cachier/config.py b/src/cachier/config.py index 54d45d1..6731a4d 100644 --- a/src/cachier/config.py +++ b/src/cachier/config.py @@ -46,8 +46,9 @@ class CacheEntry: value: Any time: datetime stale: bool - being_calculated: bool - condition: Optional[threading.Condition] = None + _processing: bool + _condition: Optional[threading.Condition] = None + _completed: bool = False def _update_with_defaults( diff --git a/src/cachier/core.py b/src/cachier/core.py index b86c16e..a56a7e2 100644 --- a/src/cachier/core.py +++ b/src/cachier/core.py @@ -14,7 +14,7 @@ from collections import OrderedDict from concurrent.futures import ThreadPoolExecutor from functools import wraps -from typing import Optional, Union +from typing import Any, Optional, Union from warnings import warn from .config import ( @@ -55,13 +55,11 @@ def _function_thread(core, key, func, args, kwds): print(f"Function call failed with the following exception:\n{exc}") -def _calc_entry(core, key, func, args, kwds): +def _calc_entry(core, key, func, args, kwds) -> Optional[Any]: + core.mark_entry_being_calculated(key) try: - core.mark_entry_being_calculated(key) - # _get_executor().submit(core.mark_entry_being_calculated, key) func_res = func(*args, **kwds) core.set_entry(key, func_res) - # _get_executor().submit(core.set_entry, key, func_res) return func_res finally: core.mark_entry_not_calculated(key) @@ -242,9 +240,8 @@ def func_wrapper(*args, **kwds): func, _is_method=core.func_is_method, args=args, kwds=kwds ) - _print = lambda x: None # noqa: E731 - if verbose: - _print = print + _print = print if verbose else lambda x: None + if ignore_cache or not _global_params.caching_enabled: return ( func(args[0], **kwargs) @@ -254,7 +251,9 @@ def func_wrapper(*args, **kwds): key, entry = core.get_entry((), kwargs) if overwrite_cache: return _calc_entry(core, key, func, args, kwds) - if entry is None: + if entry is None or ( + not entry._completed and not entry._processing + ): _print("No entry found. No current calc. Calling like a boss.") return _calc_entry(core, key, func, args, kwds) _print("Entry found.") @@ -265,7 +264,7 @@ def func_wrapper(*args, **kwds): _print("And it is fresh!") return entry.value _print("But it is stale... :(") - if entry.being_calculated: + if entry._processing: if _next_time: _print("Returning stale.") return entry.value # return stale val @@ -276,8 +275,8 @@ def func_wrapper(*args, **kwds): return _calc_entry(core, key, func, args, kwds) if _next_time: _print("Async calc and return stale") + core.mark_entry_being_calculated(key) try: - core.mark_entry_being_calculated(key) _get_executor().submit( _function_thread, core, key, func, args, kwds ) @@ -286,7 +285,7 @@ def func_wrapper(*args, **kwds): return entry.value _print("Calling decorated function and waiting") return _calc_entry(core, key, func, args, kwds) - if entry.being_calculated: + if entry._processing: _print("No value but being calculated. Waiting.") try: return core.wait_on_entry_calc(key) diff --git a/src/cachier/cores/memory.py b/src/cachier/cores/memory.py index 2a84666..221b4a7 100644 --- a/src/cachier/cores/memory.py +++ b/src/cachier/cores/memory.py @@ -26,61 +26,65 @@ def get_entry_by_key( return key, self.cache.get(self._hash_func_key(key), None) def set_entry(self, key: str, func_res: Any) -> None: + hash_key = self._hash_func_key(key) with self.lock: try: # we need to retain the existing condition so that # mark_entry_not_calculated can notify all possibly-waiting # threads about it - cond = self.cache[self._hash_func_key(key)].condition + cond = self.cache[hash_key]._condition except KeyError: # pragma: no cover cond = None - self.cache[self._hash_func_key(key)] = CacheEntry( + self.cache[hash_key] = CacheEntry( value=func_res, time=datetime.now(), stale=False, - being_calculated=False, - condition=cond, + _processing=False, + _condition=cond, + _completed=True, ) def mark_entry_being_calculated(self, key: str) -> None: with self.lock: condition = threading.Condition() + hash_key = self._hash_func_key(key) + if hash_key in self.cache: + self.cache[hash_key]._processing = True + self.cache[hash_key]._condition = condition # condition.acquire() - try: - self.cache[self._hash_func_key(key)].being_calculated = True - self.cache[self._hash_func_key(key)].condition = condition - except KeyError: - self.cache[self._hash_func_key(key)] = CacheEntry( + else: + self.cache[hash_key] = CacheEntry( value=None, time=datetime.now(), stale=False, - being_calculated=True, - condition=condition, + _processing=True, + _condition=condition, ) def mark_entry_not_calculated(self, key: str) -> None: + hash_key = self._hash_func_key(key) with self.lock: - try: - entry = self.cache[self._hash_func_key(key)] - except KeyError: # pragma: no cover + if hash_key not in self.cache: return # that's ok, we don't need an entry in that case - entry.being_calculated = False - cond = entry.condition + entry = self.cache[hash_key] + entry._processing = False + cond = entry._condition if cond: cond.acquire() cond.notify_all() cond.release() - entry.condition = None + entry._condition = None def wait_on_entry_calc(self, key: str) -> Any: + hash_key = self._hash_func_key(key) with self.lock: # pragma: no cover - entry = self.cache[self._hash_func_key(key)] - if not entry.being_calculated: + entry = self.cache[hash_key] + if not entry._processing: return entry.value - entry.condition.acquire() - entry.condition.wait() - entry.condition.release() - return self.cache[self._hash_func_key(key)].value + entry._condition.acquire() + entry._condition.wait() + entry._condition.release() + return self.cache[hash_key].value def clear_cache(self) -> None: with self.lock: @@ -89,5 +93,5 @@ def clear_cache(self) -> None: def clear_being_calculated(self) -> None: with self.lock: for entry in self.cache.values(): - entry.being_calculated = False - entry.condition = None + entry._processing = False + entry._condition = None diff --git a/src/cachier/cores/mongo.py b/src/cachier/cores/mongo.py index d0a7041..7182039 100644 --- a/src/cachier/cores/mongo.py +++ b/src/cachier/cores/mongo.py @@ -73,20 +73,14 @@ def get_entry_by_key(self, key: str) -> Tuple[str, Optional[CacheEntry]]: ) if not res: return key, None - try: - entry = CacheEntry( - value=pickle.loads(res["value"]), # noqa: S301 - time=res.get("time", None), - stale=res.get("stale", False), - being_calculated=res.get("being_calculated", False), - ) - except KeyError: - entry = CacheEntry( - value=None, - time=res.get("time", None), - stale=res.get("stale", False), - being_calculated=res.get("being_calculated", False), - ) + val = pickle.loads(res["value"]) if "value" in res else None # noqa: S301 + entry = CacheEntry( + value=val, + time=res.get("time", None), + stale=res.get("stale", False), + _processing=res.get("processing", False), + _completed=res.get("completed", False), + ) return key, entry def set_entry(self, key: str, func_res: Any) -> None: @@ -100,7 +94,8 @@ def set_entry(self, key: str, func_res: Any) -> None: "value": Binary(thebytes), "time": datetime.now(), "stale": False, - "being_calculated": False, + "processing": False, + "completed": True, } }, upsert=True, @@ -109,7 +104,7 @@ def set_entry(self, key: str, func_res: Any) -> None: def mark_entry_being_calculated(self, key: str) -> None: self.mongo_collection.update_one( filter={"func": self._func_str, "key": key}, - update={"$set": {"being_calculated": True}}, + update={"$set": {"processing": True}}, upsert=True, ) @@ -120,7 +115,7 @@ def mark_entry_not_calculated(self, key: str) -> None: "func": self._func_str, "key": key, }, - update={"$set": {"being_calculated": False}}, + update={"$set": {"processing": False}}, upsert=False, # should not insert in this case ) @@ -132,7 +127,7 @@ def wait_on_entry_calc(self, key: str) -> Any: key, entry = self.get_entry_by_key(key) if entry is None: raise RecalculationNeeded() - if not entry.being_calculated: + if not entry._processing: return entry.value self.check_calc_timeout(time_spent) @@ -143,7 +138,7 @@ def clear_being_calculated(self) -> None: self.mongo_collection.update_many( filter={ "func": self._func_str, - "being_calculated": True, + "processing": True, }, - update={"$set": {"being_calculated": False}}, + update={"$set": {"processing": False}}, ) diff --git a/src/cachier/cores/pickle.py b/src/cachier/cores/pickle.py index 196ea6f..4b0d269 100644 --- a/src/cachier/cores/pickle.py +++ b/src/cachier/cores/pickle.py @@ -46,17 +46,14 @@ def inject_observer(self, observer) -> None: self.observer = observer def _check_calculation(self) -> None: - # print('checking calc') entry = self.core.get_entry_by_key(self.key, True)[1] - # print(self.key) - # print(entry) try: - if not entry.being_calculated: + if not entry._processing: # print('stopping observer!') self.value = entry.value self.observer.stop() # else: - # print('NOT stopping observer... :(') + # print('NOT stopping observer... :(') except TypeError: self.value = None self.observer.stop() @@ -169,7 +166,8 @@ def set_entry(self, key: str, func_res: Any) -> None: value=func_res, time=datetime.now(), stale=False, - being_calculated=False, + _processing=False, + _completed=True, ) if self.separate_files: self._save_cache(key_data, key) @@ -186,14 +184,14 @@ def mark_entry_being_calculated_separate_files(self, key: str) -> None: value=None, time=datetime.now(), stale=False, - being_calculated=True, + _processing=True, ), key=key, ) def mark_entry_not_calculated_separate_files(self, key: str) -> None: _, entry = self.get_entry_by_key(key) - entry.being_calculated = False + entry._processing = False self._save_cache(entry, key=key) def mark_entry_being_calculated(self, key: str) -> None: @@ -203,14 +201,14 @@ def mark_entry_being_calculated(self, key: str) -> None: with self.lock: cache = self._get_cache() - try: - cache[key].being_calculated = True - except KeyError: + if key in cache: + cache[key]._processing = True + else: cache[key] = CacheEntry( value=None, time=datetime.now(), stale=False, - being_calculated=True, + _processing=True, ) self._save_cache(cache) @@ -221,7 +219,7 @@ def mark_entry_not_calculated(self, key: str) -> None: cache = self._get_cache() # that's ok, we don't need an entry in that case if isinstance(cache, dict) and key in cache: - cache[key].being_calculated = False + cache[key]._processing = False self._save_cache(cache) def wait_on_entry_calc(self, key: str) -> Any: @@ -233,7 +231,7 @@ def wait_on_entry_calc(self, key: str) -> Any: self._reload_cache() entry = self._get_cache()[key] filename = self.cache_fname - if not entry.being_calculated: + if not entry._processing: return entry.value event_handler = _PickleCore.CacheChangeHandler( filename=filename, core=self, key=key @@ -263,5 +261,5 @@ def clear_being_calculated(self) -> None: with self.lock: cache = self._get_cache() for key in cache: - cache[key].being_calculated = False + cache[key]._processing = False self._save_cache(cache) diff --git a/tests/test_general.py b/tests/test_general.py index b21f5e8..35b2e18 100644 --- a/tests/test_general.py +++ b/tests/test_general.py @@ -492,3 +492,15 @@ def fn_minus(a, b=2): assert count_p == 1 assert count_m == 1 + + +@pytest.mark.parametrize("backend", ["memory", "pickle"]) +def test_raise_exception(tmpdir, backend: str): + @cachier.cachier(cache_dir=tmpdir, backend=backend, allow_none=True) + def tmp_test(_): + raise RuntimeError("always raise") + + with pytest.raises(RuntimeError): + tmp_test(123) + with pytest.raises(RuntimeError): + tmp_test(123) diff --git a/tests/test_mongo_core.py b/tests/test_mongo_core.py index 21a0622..92371f7 100644 --- a/tests/test_mongo_core.py +++ b/tests/test_mongo_core.py @@ -275,7 +275,7 @@ def _stalled_func(): def test_stalled_mong_db_core(monkeypatch): def mock_get_entry(self, args, kwargs): return "key", CacheEntry( - being_calculated=True, value=None, time=None, stale=None + _processing=True, value=None, time=None, stale=None ) def mock_get_entry_by_key(self, key): @@ -300,7 +300,7 @@ def mock_get_entry_2(self, args, kwargs): return "key", CacheEntry( value=1, time=datetime.datetime.now() - datetime.timedelta(seconds=10), - being_calculated=True, + _processing=True, stale=None, )