Skip to content

Commit

Permalink
fix: memory leak in memorized decorator
Browse files Browse the repository at this point in the history
update doc

wip

wip

fix lint
  • Loading branch information
zhaoyongjie committed Mar 1, 2022
1 parent ca93d63 commit 88eea26
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 103 deletions.
8 changes: 6 additions & 2 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def get_engine(
cls,
database: "Database",
schema: Optional[str] = None,
source: Optional[str] = None,
source: Optional[utils.QuerySource] = None,
) -> Engine:
user_name = utils.get_username()
return database.get_sqla_engine(
Expand Down Expand Up @@ -1144,7 +1144,11 @@ def process_statement(

@classmethod
def estimate_query_cost(
cls, database: "Database", schema: str, sql: str, source: Optional[str] = None
cls,
database: "Database",
schema: str,
sql: str,
source: Optional[utils.QuerySource] = None,
) -> List[Dict[str, Any]]:
"""
Estimate the cost of a multiple statement SQL query.
Expand Down
23 changes: 23 additions & 0 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,29 @@ def get_sqla_engine(
nullpool: bool = True,
user_name: Optional[str] = None,
source: Optional[utils.QuerySource] = None,
) -> Engine:
cache_key = (
f"{self.impersonate_user}"
f"{self.sqlalchemy_uri_decrypted}"
f"{json.dumps(self.get_extra())}"
)
return self._get_sqla_engine(
schema=schema,
nullpool=nullpool,
user_name=user_name,
source=source,
residual_cache_key=cache_key,
)

# pylint: disable=too-many-arguments,unused-argument
@memoized
def _get_sqla_engine(
self,
schema: Optional[str] = None,
nullpool: bool = True,
user_name: Optional[str] = None,
source: Optional[utils.QuerySource] = None,
residual_cache_key: Optional[str] = None,
) -> Engine:
extra = self.get_extra()
sqlalchemy_url = make_url(self.sqlalchemy_uri_decrypted)
Expand Down
76 changes: 21 additions & 55 deletions superset/utils/memoized.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,67 +15,33 @@
# specific language governing permissions and limitations
# under the License.
import functools
from typing import Any, Callable, Dict, Optional, Tuple, Type
from datetime import datetime, timedelta
from typing import Any, Callable


class _memoized:
"""Decorator that caches a function's return value each time it is called
If called later with the same arguments, the cached value is returned, and
not re-evaluated.
Define ``watch`` as a tuple of attribute names if this Decorator
should account for instance variable changes.
def _memoized(seconds: int = 24 * 60 * 60, maxsize: int = 1024,) -> Callable[..., Any]:
"""
A simple wrapper of functools.lru_cache, encapsulated for thread safety
:param seconds: LRU expired time, seconds
:param maxsize: LRU size
:return: a wrapped function by LRU
"""

def __init__(
self, func: Callable[..., Any], watch: Optional[Tuple[str, ...]] = None
) -> None:
self.func = func
self.cache: Dict[Any, Any] = {}
self.is_method = False
self.watch = watch or ()

def __call__(self, *args: Any, **kwargs: Any) -> Any:
key = [args, frozenset(kwargs.items())]
if self.is_method:
key.append(tuple(getattr(args[0], v, None) for v in self.watch))
key = tuple(key) # type: ignore
try:
if key in self.cache:
return self.cache[key]
except TypeError as ex:
# Uncachable -- for instance, passing a list as an argument.
raise TypeError("Function cannot be memoized") from ex
value = self.func(*args, **kwargs)
try:
self.cache[key] = value
except TypeError as ex:
raise TypeError("Function cannot be memoized") from ex
return value

def __repr__(self) -> str:
"""Return the function's docstring."""
return self.func.__doc__ or ""
def wrapper_cache(func: Callable[..., Any]) -> Callable[..., Any]:
lru: Any = functools.lru_cache(maxsize=maxsize)(func)
lru.lifetime = timedelta(seconds=seconds)
lru.expiration = datetime.utcnow() + lru.lifetime

def __get__(
self, obj: Any, objtype: Type[Any]
) -> functools.partial: # type: ignore
if not self.is_method:
self.is_method = True
# Support instance methods.
func = functools.partial(self.__call__, obj)
func.__func__ = self.func # type: ignore
return func
@functools.wraps(func)
def wrapped_func(*args: Any, **kwargs: Any) -> Callable[..., Any]:
if datetime.utcnow() >= lru.expiration:
lru.cache_clear()
lru.expiration = datetime.utcnow() + lru.lifetime
return lru(*args, **kwargs)

return wrapped_func

def memoized(
func: Optional[Callable[..., Any]] = None, watch: Optional[Tuple[str, ...]] = None
) -> Callable[..., Any]:
if func:
return _memoized(func)
return wrapper_cache

def wrapper(f: Callable[..., Any]) -> Callable[..., Any]:
return _memoized(f, watch)

return wrapper
memoized = _memoized()
76 changes: 30 additions & 46 deletions tests/unit_tests/memoized_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,23 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import time

from pytest import mark

from superset.utils.memoized import memoized
from superset.utils.memoized import _memoized, memoized


@mark.unittest
class TestMemoized:
def test_memoized_on_functions(self):
watcher = {"val": 0}

@memoized
def test_function(a, b, c):
watcher["val"] += 1
return a * b * c
return {"key": a + b + c}

result1 = test_function(1, 2, 3)
result2 = test_function(1, 2, 3)
assert result1 == result2
assert watcher["val"] == 1
assert result1 is result2

def test_memoized_on_methods(self):
class test_class:
Expand All @@ -49,48 +46,35 @@ def test_method(self, a, b, c):
instance = test_class(5)
result1 = instance.test_method(1, 2, 3)
result2 = instance.test_method(1, 2, 3)
assert result1 == result2
assert result1 is result2
assert instance.watcher == 1
instance.num = 10
assert result2 == instance.test_method(1, 2, 3)

def test_memoized_on_methods_with_watches(self):
class test_class:
def __init__(self, x, y):
self.x = x
self.y = y
self.watcher = 0
def test_memorized_size(self):
new_memoized = _memoized(maxsize=1)

@memoized(watch=("x", "y"))
def test_method(self, a, b, c):
self.watcher += 1
return a * b * c * self.x * self.y
@new_memoized
def test_add(a, b):
# return a reference type instead of primal type
return {"key": a + b}

instance = test_class(3, 12)
result1 = instance.test_method(1, 2, 3)
result2 = instance.test_method(1, 2, 3)
assert result1 == result2
assert instance.watcher == 1
result3 = instance.test_method(2, 3, 4)
assert instance.watcher == 2
result4 = instance.test_method(2, 3, 4)
assert instance.watcher == 2
assert result3 == result4
assert result3 != result1
instance.x = 1
result5 = instance.test_method(2, 3, 4)
assert instance.watcher == 3
assert result5 != result4
result6 = instance.test_method(2, 3, 4)
assert instance.watcher == 3
assert result6 == result5
instance.x = 10
instance.y = 10
result7 = instance.test_method(2, 3, 4)
assert instance.watcher == 4
assert result7 != result6
instance.x = 3
instance.y = 12
result8 = instance.test_method(1, 2, 3)
assert instance.watcher == 4
assert result1 == result8
result1 = test_add(1, 2)
# clear cache
test_add(2, 3)
result2 = test_add(1, 2)
assert result1 is not result2

def test_memorized_expire(self):
new_memoized = _memoized(seconds=1)

@new_memoized
def test_add(a, b):
# return a reference type instead of primal type
return {"key": a + b}

result1 = test_add(1, 2)
# clear cache
time.sleep(2)
result2 = test_add(1, 2)
assert result1 is not result2

0 comments on commit 88eea26

Please sign in to comment.