Skip to content

Commit

Permalink
convert _default_params to dataclass (#237)
Browse files Browse the repository at this point in the history
* convert `_default_params` to dataclass

* lint

* copy

* set

* keys

* assert

* global

* global func

* warning

* fix import

* lint

---------

Co-authored-by: Shay Palachy-Affek <shaypal5@users.noreply.github.com>
  • Loading branch information
Borda and shaypal5 authored Oct 17, 2024
1 parent 3bbe33b commit 9dd89b0
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 65 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ repos:
args: ["--number"]

- repo: https://github.com/pre-commit/mirrors-prettier
rev: v4.0.0-alpha.8
rev: v3.1.0
hooks:
- id: prettier
files: \.(json|yml|yaml|toml)
Expand Down
8 changes: 4 additions & 4 deletions src/cachier/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
from .config import (
disable_caching,
enable_caching,
get_default_params,
set_default_params,
get_global_params,
set_global_params,
)
from .core import cachier

__all__ = [
"cachier",
"set_default_params",
"get_default_params",
"set_global_params",
"get_global_params",
"enable_caching",
"disable_caching",
]
101 changes: 62 additions & 39 deletions src/cachier/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import hashlib
import os
import pickle
from typing import Optional, TypedDict, Union
from collections.abc import Mapping
from dataclasses import dataclass, replace
from typing import Optional, Union

from ._types import Backend, HashFunc, Mongetter

Expand All @@ -16,35 +18,24 @@ def _default_hash_func(args, kwds):
return hashlib.sha256(serialized).hexdigest()


class Params(TypedDict):
"""Type definition for cachier parameters."""

caching_enabled: bool
hash_func: HashFunc
backend: Backend
mongetter: Optional[Mongetter]
stale_after: datetime.timedelta
next_time: bool
cache_dir: Union[str, os.PathLike]
pickle_reload: bool
separate_files: bool
wait_for_calc_timeout: int
allow_none: bool


_default_params: Params = {
"caching_enabled": True,
"hash_func": _default_hash_func,
"backend": "pickle",
"mongetter": None,
"stale_after": datetime.timedelta.max,
"next_time": False,
"cache_dir": "~/.cachier/",
"pickle_reload": True,
"separate_files": False,
"wait_for_calc_timeout": 0,
"allow_none": False,
}
@dataclass
class Params:
"""Default definition for cachier parameters."""

caching_enabled: bool = True
hash_func: HashFunc = _default_hash_func
backend: Backend = "pickle"
mongetter: Optional[Mongetter] = None
stale_after: datetime.timedelta = datetime.timedelta.max
next_time: bool = False
cache_dir: Union[str, os.PathLike] = "~/.cachier/"
pickle_reload: bool = True
separate_files: bool = False
wait_for_calc_timeout: int = 0
allow_none: bool = False


_global_params = Params()


def _update_with_defaults(
Expand All @@ -57,11 +48,25 @@ def _update_with_defaults(
if kw_name in func_kwargs:
return func_kwargs.pop(kw_name)
if param is None:
return cachier.config._default_params[name]
return getattr(cachier.config._global_params, name)
return param


def set_default_params(**params):
def set_default_params(**params: Mapping) -> None:
"""Configure default parameters applicable to all memoized functions."""
# It is kept for backwards compatibility with desperation warning
import warnings

warnings.warn(
"Called `set_default_params` is deprecated and will be removed."
" Please use `set_global_params` instead.",
DeprecationWarning,
stacklevel=2,
)
set_global_params(**params)


def set_global_params(**params: Mapping) -> None:
"""Configure global parameters applicable to all memoized functions.
This function takes the same keyword parameters as the ones defined in the
Expand All @@ -76,28 +81,46 @@ def set_default_params(**params):
"""
import cachier

valid_params = (
p for p in params.items() if p[0] in cachier.config._default_params
valid_params = {
k: v
for k, v in params.items()
if hasattr(cachier.config._global_params, k)
}
cachier.config._global_params = replace(
cachier.config._global_params, **valid_params
)


def get_default_params() -> Params:
"""Get current set of default parameters."""
# It is kept for backwards compatibility with desperation warning
import warnings

warnings.warn(
"Called `get_default_params` is deprecated and will be removed."
" Please use `get_global_params` instead.",
DeprecationWarning,
stacklevel=2,
)
_default_params.update(valid_params)
return get_global_params()


def get_default_params():
def get_global_params() -> Params:
"""Get current set of default parameters."""
import cachier

return cachier.config._default_params
return cachier.config._global_params


def enable_caching():
"""Enable caching globally."""
import cachier

cachier.config._default_params["caching_enabled"] = True
cachier.config._global_params.caching_enabled = True


def disable_caching():
"""Disable caching globally."""
import cachier

cachier.config._default_params["caching_enabled"] = False
cachier.config._global_params.caching_enabled = False
5 changes: 3 additions & 2 deletions src/cachier/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
Backend,
HashFunc,
Mongetter,
_default_params,
_update_with_defaults,
)
from .cores.base import RecalculationNeeded, _BaseCore
Expand Down Expand Up @@ -176,6 +175,8 @@ def cachier(
None will not be cached and are recalculated every call.
"""
from .config import _global_params

# Check for deprecated parameters
if hash_params is not None:
message = (
Expand Down Expand Up @@ -244,7 +245,7 @@ def func_wrapper(*args, **kwds):
_print = lambda x: None # noqa: E731
if verbose:
_print = print
if ignore_cache or not _default_params["caching_enabled"]:
if ignore_cache or not _global_params.caching_enabled:
return (
func(args[0], **kwargs)
if core.func_is_method
Expand Down
8 changes: 4 additions & 4 deletions tests/test_core_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

import pytest

from cachier import cachier, get_default_params
from cachier import cachier, get_global_params
from cachier.cores.mongo import MissingMongetter


def test_get_default_params():
params = get_default_params()
assert tuple(sorted(params)) == (
params = get_global_params()
assert sorted(vars(params).keys()) == [
"allow_none",
"backend",
"cache_dir",
Expand All @@ -20,7 +20,7 @@ def test_get_default_params():
"separate_files",
"stale_after",
"wait_for_calc_timeout",
)
]


def test_bad_name(name="nope"):
Expand Down
25 changes: 13 additions & 12 deletions tests/test_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,23 @@
import random
import threading
import time
from dataclasses import replace

import pytest

import cachier
from tests.test_mongo_core import _test_mongetter

MONGO_DELTA = datetime.timedelta(seconds=3)
_default_params = cachier.get_default_params().copy()
_copied_defaults = replace(cachier.get_global_params())


def setup_function():
cachier.set_default_params(**_default_params)
cachier.set_global_params(**vars(_copied_defaults))


def teardown_function():
cachier.set_default_params(**_default_params)
cachier.set_global_params(**vars(_copied_defaults))


def test_hash_func_default_param():
Expand All @@ -30,7 +31,7 @@ def slow_hash_func(args, kwds):
def fast_hash_func(args, kwds):
return "hash"

cachier.set_default_params(hash_func=slow_hash_func)
cachier.set_global_params(hash_func=slow_hash_func)

@cachier.cachier()
def global_test_1():
Expand All @@ -51,7 +52,7 @@ def global_test_2():


def test_backend_default_param():
cachier.set_default_params(backend="memory")
cachier.set_global_params(backend="memory")

@cachier.cachier()
def global_test_1():
Expand All @@ -67,7 +68,7 @@ def global_test_2():

@pytest.mark.mongo
def test_mongetter_default_param():
cachier.set_default_params(mongetter=_test_mongetter)
cachier.set_global_params(mongetter=_test_mongetter)

@cachier.cachier()
def global_test_1():
Expand All @@ -82,7 +83,7 @@ def global_test_2():


def test_cache_dir_default_param(tmpdir):
cachier.set_default_params(cache_dir=tmpdir / "1")
cachier.set_global_params(cache_dir=tmpdir / "1")

@cachier.cachier()
def global_test_1():
Expand All @@ -97,7 +98,7 @@ def global_test_2():


def test_separate_files_default_param(tmpdir):
cachier.set_default_params(separate_files=True)
cachier.set_global_params(separate_files=True)

@cachier.cachier(cache_dir=tmpdir / "1")
def global_test_1(arg_1, arg_2):
Expand All @@ -117,7 +118,7 @@ def global_test_2(arg_1, arg_2):


def test_allow_none_default_param(tmpdir):
cachier.set_default_params(
cachier.set_global_params(
allow_none=True,
separate_files=True,
verbose_cache=True,
Expand Down Expand Up @@ -167,7 +168,7 @@ def _stale_after_test(arg_1, arg_2):
"""Some function."""
return random.random() + arg_1 + arg_2

cachier.set_default_params(stale_after=MONGO_DELTA)
cachier.set_global_params(stale_after=MONGO_DELTA)

_stale_after_test.clear_cache()
val1 = _stale_after_test(1, 2)
Expand All @@ -187,7 +188,7 @@ def _stale_after_next_time(arg_1, arg_2):
"""Some function."""
return random.random()

cachier.set_default_params(stale_after=NEXT_AFTER_DELTA, next_time=True)
cachier.set_global_params(stale_after=NEXT_AFTER_DELTA, next_time=True)

_stale_after_next_time.clear_cache()
val1 = _stale_after_next_time(1, 2)
Expand Down Expand Up @@ -217,7 +218,7 @@ def _calls_wait_for_calc_timeout_slow(res_queue):
res = _wait_for_calc_timeout_slow(1, 2)
res_queue.put(res)

cachier.set_default_params(wait_for_calc_timeout=2)
cachier.set_global_params(wait_for_calc_timeout=2)
_wait_for_calc_timeout_slow.clear_cache()
res_queue = queue.Queue()
thread1 = threading.Thread(
Expand Down
4 changes: 3 additions & 1 deletion tests/test_general.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,15 +229,17 @@ def test_separate_processes():

def test_global_disable():
@cachier.cachier()
def get_random():
def get_random() -> float:
return random()

get_random.clear_cache()
result_1 = get_random()
result_2 = get_random()
cachier.disable_caching()
assert cachier.config._global_params.caching_enabled is False
result_3 = get_random()
cachier.enable_caching()
assert cachier.config._global_params.caching_enabled is True
result_4 = get_random()
assert result_1 == result_2 == result_4
assert result_1 != result_3
Expand Down
4 changes: 2 additions & 2 deletions tests/test_pickle_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
import pandas as pd

from cachier import cachier
from cachier.core import _default_params
from cachier.config import _global_params


def _get_decorated_func(func, **kwargs):
Expand Down Expand Up @@ -329,7 +329,7 @@ def _bad_cache(arg_1, arg_2):
".tests.test_pickle_core._bad_cache_"
f"{hashlib.sha256(pickle.dumps((0.13, 0.02))).hexdigest()}"
)
EXPANDED_CACHIER_DIR = os.path.expanduser(_default_params["cache_dir"])
EXPANDED_CACHIER_DIR = os.path.expanduser(_global_params.cache_dir)
_BAD_CACHE_FPATH = os.path.join(EXPANDED_CACHIER_DIR, _BAD_CACHE_FNAME)
_BAD_CACHE_FPATH_SEPARATE_FILES = os.path.join(
EXPANDED_CACHIER_DIR, _BAD_CACHE_FNAME_SEPARATE_FILES
Expand Down

0 comments on commit 9dd89b0

Please sign in to comment.