From 88eea2699a531d168853fea0d71c5fd14ece56a9 Mon Sep 17 00:00:00 2001 From: Yongjie Zhao Date: Tue, 2 Nov 2021 18:52:04 +0800 Subject: [PATCH] fix: memory leak in memorized decorator update doc wip wip fix lint --- superset/db_engine_specs/base.py | 8 +++- superset/models/core.py | 23 +++++++++ superset/utils/memoized.py | 76 +++++++++--------------------- tests/unit_tests/memoized_tests.py | 76 ++++++++++++------------------ 4 files changed, 80 insertions(+), 103 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index d7e457baa8c01..326af9270ab79 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -372,7 +372,7 @@ def get_engine( cls, database: "Database", schema: Optional[str] = None, - source: Optional[str] = None, + source: Optional[utils.QuerySource] = None, ) -> Engine: user_name = utils.get_username() return database.get_sqla_engine( @@ -1144,7 +1144,11 @@ def process_statement( @classmethod def estimate_query_cost( - cls, database: "Database", schema: str, sql: str, source: Optional[str] = None + cls, + database: "Database", + schema: str, + sql: str, + source: Optional[utils.QuerySource] = None, ) -> List[Dict[str, Any]]: """ Estimate the cost of a multiple statement SQL query. diff --git a/superset/models/core.py b/superset/models/core.py index 7798ddf05930d..95f046e061599 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -349,6 +349,29 @@ def get_sqla_engine( nullpool: bool = True, user_name: Optional[str] = None, source: Optional[utils.QuerySource] = None, + ) -> Engine: + cache_key = ( + f"{self.impersonate_user}" + f"{self.sqlalchemy_uri_decrypted}" + f"{json.dumps(self.get_extra())}" + ) + return self._get_sqla_engine( + schema=schema, + nullpool=nullpool, + user_name=user_name, + source=source, + residual_cache_key=cache_key, + ) + + # pylint: disable=too-many-arguments,unused-argument + @memoized + def _get_sqla_engine( + self, + schema: Optional[str] = None, + nullpool: bool = True, + user_name: Optional[str] = None, + source: Optional[utils.QuerySource] = None, + residual_cache_key: Optional[str] = None, ) -> Engine: extra = self.get_extra() sqlalchemy_url = make_url(self.sqlalchemy_uri_decrypted) diff --git a/superset/utils/memoized.py b/superset/utils/memoized.py index 153542fbb7b16..118dc6140fb06 100644 --- a/superset/utils/memoized.py +++ b/superset/utils/memoized.py @@ -15,67 +15,33 @@ # specific language governing permissions and limitations # under the License. import functools -from typing import Any, Callable, Dict, Optional, Tuple, Type +from datetime import datetime, timedelta +from typing import Any, Callable -class _memoized: - """Decorator that caches a function's return value each time it is called - - If called later with the same arguments, the cached value is returned, and - not re-evaluated. - - Define ``watch`` as a tuple of attribute names if this Decorator - should account for instance variable changes. +def _memoized(seconds: int = 24 * 60 * 60, maxsize: int = 1024,) -> Callable[..., Any]: + """ + A simple wrapper of functools.lru_cache, encapsulated for thread safety + :param seconds: LRU expired time, seconds + :param maxsize: LRU size + :return: a wrapped function by LRU """ - def __init__( - self, func: Callable[..., Any], watch: Optional[Tuple[str, ...]] = None - ) -> None: - self.func = func - self.cache: Dict[Any, Any] = {} - self.is_method = False - self.watch = watch or () - - def __call__(self, *args: Any, **kwargs: Any) -> Any: - key = [args, frozenset(kwargs.items())] - if self.is_method: - key.append(tuple(getattr(args[0], v, None) for v in self.watch)) - key = tuple(key) # type: ignore - try: - if key in self.cache: - return self.cache[key] - except TypeError as ex: - # Uncachable -- for instance, passing a list as an argument. - raise TypeError("Function cannot be memoized") from ex - value = self.func(*args, **kwargs) - try: - self.cache[key] = value - except TypeError as ex: - raise TypeError("Function cannot be memoized") from ex - return value - - def __repr__(self) -> str: - """Return the function's docstring.""" - return self.func.__doc__ or "" + def wrapper_cache(func: Callable[..., Any]) -> Callable[..., Any]: + lru: Any = functools.lru_cache(maxsize=maxsize)(func) + lru.lifetime = timedelta(seconds=seconds) + lru.expiration = datetime.utcnow() + lru.lifetime - def __get__( - self, obj: Any, objtype: Type[Any] - ) -> functools.partial: # type: ignore - if not self.is_method: - self.is_method = True - # Support instance methods. - func = functools.partial(self.__call__, obj) - func.__func__ = self.func # type: ignore - return func + @functools.wraps(func) + def wrapped_func(*args: Any, **kwargs: Any) -> Callable[..., Any]: + if datetime.utcnow() >= lru.expiration: + lru.cache_clear() + lru.expiration = datetime.utcnow() + lru.lifetime + return lru(*args, **kwargs) + return wrapped_func -def memoized( - func: Optional[Callable[..., Any]] = None, watch: Optional[Tuple[str, ...]] = None -) -> Callable[..., Any]: - if func: - return _memoized(func) + return wrapper_cache - def wrapper(f: Callable[..., Any]) -> Callable[..., Any]: - return _memoized(f, watch) - return wrapper +memoized = _memoized() diff --git a/tests/unit_tests/memoized_tests.py b/tests/unit_tests/memoized_tests.py index 3b3f436606f51..c81ece441eae3 100644 --- a/tests/unit_tests/memoized_tests.py +++ b/tests/unit_tests/memoized_tests.py @@ -14,26 +14,23 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import time from pytest import mark -from superset.utils.memoized import memoized +from superset.utils.memoized import _memoized, memoized @mark.unittest class TestMemoized: def test_memoized_on_functions(self): - watcher = {"val": 0} - @memoized def test_function(a, b, c): - watcher["val"] += 1 - return a * b * c + return {"key": a + b + c} result1 = test_function(1, 2, 3) result2 = test_function(1, 2, 3) - assert result1 == result2 - assert watcher["val"] == 1 + assert result1 is result2 def test_memoized_on_methods(self): class test_class: @@ -49,48 +46,35 @@ def test_method(self, a, b, c): instance = test_class(5) result1 = instance.test_method(1, 2, 3) result2 = instance.test_method(1, 2, 3) - assert result1 == result2 + assert result1 is result2 assert instance.watcher == 1 instance.num = 10 assert result2 == instance.test_method(1, 2, 3) - def test_memoized_on_methods_with_watches(self): - class test_class: - def __init__(self, x, y): - self.x = x - self.y = y - self.watcher = 0 + def test_memorized_size(self): + new_memoized = _memoized(maxsize=1) - @memoized(watch=("x", "y")) - def test_method(self, a, b, c): - self.watcher += 1 - return a * b * c * self.x * self.y + @new_memoized + def test_add(a, b): + # return a reference type instead of primal type + return {"key": a + b} - instance = test_class(3, 12) - result1 = instance.test_method(1, 2, 3) - result2 = instance.test_method(1, 2, 3) - assert result1 == result2 - assert instance.watcher == 1 - result3 = instance.test_method(2, 3, 4) - assert instance.watcher == 2 - result4 = instance.test_method(2, 3, 4) - assert instance.watcher == 2 - assert result3 == result4 - assert result3 != result1 - instance.x = 1 - result5 = instance.test_method(2, 3, 4) - assert instance.watcher == 3 - assert result5 != result4 - result6 = instance.test_method(2, 3, 4) - assert instance.watcher == 3 - assert result6 == result5 - instance.x = 10 - instance.y = 10 - result7 = instance.test_method(2, 3, 4) - assert instance.watcher == 4 - assert result7 != result6 - instance.x = 3 - instance.y = 12 - result8 = instance.test_method(1, 2, 3) - assert instance.watcher == 4 - assert result1 == result8 + result1 = test_add(1, 2) + # clear cache + test_add(2, 3) + result2 = test_add(1, 2) + assert result1 is not result2 + + def test_memorized_expire(self): + new_memoized = _memoized(seconds=1) + + @new_memoized + def test_add(a, b): + # return a reference type instead of primal type + return {"key": a + b} + + result1 = test_add(1, 2) + # clear cache + time.sleep(2) + result2 = test_add(1, 2) + assert result1 is not result2