Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(hansbug): fix bug of #82, add more unittests #83

Merged
merged 1 commit into from
Mar 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
fix(hansbug): fix bug of #82, add more unittests
  • Loading branch information
HansBug committed Mar 6, 2023
commit 80d4e9b59c14bfe0e8df157f755c6f4589f983ee
27 changes: 27 additions & 0 deletions test/tree/func/test_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import pytest

from treevalue import FastTreeValue
from treevalue.tree import func_treelize, TreeValue, method_treelize, classmethod_treelize, delayed


Expand Down Expand Up @@ -401,3 +402,29 @@ def total(a, b):
'v': {'a': 12, 'b': 25, 'x': {'c': 38, 'd': 51}},
})
assert cnt_1 == 4

def test_return_treevalue(self):
def func(x):
return FastTreeValue({
'x': x, 'y': x ** 2,
})

f = FastTreeValue({
'x': func,
'y': {
'z': func,
}
})
v = FastTreeValue({'x': 2, 'y': {'z': 34}})
assert f(v) == FastTreeValue({
'x': {
'x': v.x,
'y': v.x ** 2,
},
'y': {
'z': {
'x': v.y.z,
'y': v.y.z ** 2,
}
}
})
7 changes: 6 additions & 1 deletion treevalue/tree/func/cfunc.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ cdef object _c_func_treelize_run(object func, list args, dict kwargs, _e_tree_mo

cdef list _a_args
cdef dict _a_kwargs
cdef object _a_ret
if not has_tree:
_a_args = []
for v in args:
Expand All @@ -72,7 +73,11 @@ cdef object _c_func_treelize_run(object func, list args, dict kwargs, _e_tree_mo
else:
_a_kwargs[k] = missing_func()

return func(*_a_args, **_a_kwargs)
_a_ret = func(*_a_args, **_a_kwargs)
if isinstance(_a_ret, TreeValue):
return _a_ret._detach()
else:
return _a_ret

cdef dict _d_res = {}
cdef str ak
Expand Down