Skip to content

Commit c85c4e2

Browse files
author
Matthieu Ancellin
committed
Add cache decorator.
1 parent de2c9b5 commit c85c4e2

File tree

2 files changed

+25
-17
lines changed

2 files changed

+25
-17
lines changed

labelled_functions/decorators.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from labelled_functions.abstract import Unknown
66
from labelled_functions.labels import label, LabelledFunction
7+
from copy import copy
78

89

910
# API
@@ -49,22 +50,15 @@ def timed_func(*args, **kwargs):
4950
return timed_func
5051

5152

52-
# def cache(func, memory=None):
53-
# """Return a copy of func but with results cached with joblib.
53+
def decorate(func, decorator):
54+
new_func = copy(func)
55+
new_func.function = decorator(func.function)
56+
return new_func
5457

55-
# TODO: When used on a pipeline, do not merge the pipeline into a single function.
56-
# """
57-
# from labelled_functions.labels import LabelledFunction
5858

59-
# if memory is None:
60-
# from joblib import Memory
61-
# memory = Memory("/tmp", verbose=0)
62-
63-
# return LabelledFunction(
64-
# memory.cache(func),
65-
# name=func.name,
66-
# input_names=func.input_names,
67-
# output_names=func.output_names,
68-
# default_values=func.default_values,
69-
# )
59+
def cache(func, memory=None):
60+
if memory is None:
61+
from joblib import Memory
62+
memory = Memory("/tmp", verbose=0)
63+
return decorate(func, memory.cache)
7064

test/test_decorators.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33

44
import pytest
55

6-
from labelled_functions.decorators import keeping_inputs
6+
from labelled_functions import label, pipeline
7+
from labelled_functions.decorators import keeping_inputs, decorate, cache
78
from example_functions import *
89

910

@@ -26,3 +27,16 @@ def test_keeping_inputs():
2627
with pytest.raises(TypeError):
2728
keeping_inputs(all_kinds_of_args)(0, 1, 2, 3)
2829

30+
31+
def test_cache():
32+
# Cache one function
33+
pipe = cache(label(random_radius)) | label(cylinder_volume)
34+
a = pipe(length=1.0)
35+
b = pipe(length=1.0)
36+
assert a == b
37+
38+
# Cache whole function
39+
pipe = cache(pipeline([random_radius, cylinder_volume]))
40+
a = pipe(length=1.0)
41+
b = pipe(length=1.0)
42+
assert a == b

0 commit comments

Comments
 (0)