-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #81 from opendilab/dev/int
dev(hansbug): add generic_flatten, generic_unflatten and register_integrate_container for integration module
- Loading branch information
Showing
6 changed files
with
440 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
from collections import namedtuple | ||
|
||
import pytest | ||
from easydict import EasyDict | ||
|
||
from treevalue import generic_flatten, generic_unflatten, FastTreeValue, register_integrate_container, generic_mapping | ||
|
||
nt = namedtuple('nt', ['a', 'b']) | ||
|
||
|
||
class MyTreeValue(FastTreeValue): | ||
pass | ||
|
||
|
||
@pytest.mark.unittest | ||
class TestTreeIntegrationGeneral: | ||
def test_general_flatten_and_unflatten(self): | ||
demo_data = { | ||
'a': 1, | ||
'b': [2, 3, 'f'], | ||
'c': (2, 5, 'ds', EasyDict({ | ||
'x': None, | ||
'z': [34, '1.2'], | ||
})), | ||
'd': nt('f', 100), | ||
'e': MyTreeValue({'x': 1, 'y': 'dsfljk'}) | ||
} | ||
v, spec = generic_flatten(demo_data) | ||
assert v == [1, [2, 3, 'f'], [2, 5, 'ds', [None, [34, '1.2']]], ['f', 100], [1, 'dsfljk']] | ||
|
||
rv = generic_unflatten(v, spec) | ||
assert rv == demo_data | ||
assert isinstance(rv['c'][-1], EasyDict) | ||
assert isinstance(rv['d'], nt) | ||
assert isinstance(rv['c'][-1]['z'], list) | ||
assert isinstance(rv['e'], MyTreeValue) | ||
|
||
def test_register_my_class(self): | ||
class MyDC: | ||
def __init__(self, x, y): | ||
self.x = x | ||
self.y = y | ||
|
||
def __eq__(self, other): | ||
return isinstance(other, MyDC) and self.x == other.x and self.y == other.y | ||
|
||
def _mydc_flatten(v): | ||
return [v.x, v.y], MyDC | ||
|
||
def _mydc_unflatten(v, spec): | ||
return spec(*v) | ||
|
||
register_integrate_container(MyDC, _mydc_flatten, _mydc_unflatten) | ||
|
||
demo_data = { | ||
'a': 1, | ||
'b': [2, 3, 'f'], | ||
'c': (2, 5, 'ds', EasyDict({ | ||
'x': None, | ||
'z': MyDC(34, '1.2'), | ||
})), | ||
'd': nt('f', 100), | ||
'e': MyTreeValue({'x': 1, 'y': 'dsfljk'}) | ||
} | ||
v, spec = generic_flatten(demo_data) | ||
assert v == [1, [2, 3, 'f'], [2, 5, 'ds', [None, [34, '1.2']]], ['f', 100], [1, 'dsfljk']] | ||
|
||
rv = generic_unflatten(v, spec) | ||
assert rv == demo_data | ||
assert isinstance(rv['c'][-1], EasyDict) | ||
assert isinstance(rv['d'], nt) | ||
assert isinstance(rv['c'][-1]['z'], MyDC) | ||
assert isinstance(rv['e'], MyTreeValue) | ||
|
||
def test_generic_mapping(self): | ||
demo_data = { | ||
'a': 1, | ||
'b': [2, 3, 'f'], | ||
'c': (2, 5, 'ds', EasyDict({ | ||
'x': None, | ||
'z': (34, '1.2'), | ||
})), | ||
'd': nt('f', 100), | ||
'e': MyTreeValue({'x': 1, 'y': 'dsfljk'}) | ||
} | ||
assert generic_mapping(demo_data, str) == { | ||
'a': '1', | ||
'b': ['2', '3', 'f'], | ||
'c': ('2', '5', 'ds', EasyDict({ | ||
'x': 'None', | ||
'z': ('34', '1.2'), | ||
})), | ||
'd': nt('f', '100'), | ||
'e': MyTreeValue({'x': '1', 'y': 'dsfljk'}) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# distutils:language=c++ | ||
# cython:language_level=3 | ||
|
||
from libcpp cimport bool | ||
|
||
cdef tuple _dict_flatten(object d) | ||
cdef object _dict_unflatten(list values, tuple spec) | ||
|
||
cdef tuple _list_and_tuple_flatten(object l) | ||
cdef object _list_and_tuple_unflatten(list values, object spec) | ||
|
||
cdef tuple _namedtuple_flatten(object l) | ||
cdef object _namedtuple_unflatten(list values, object spec) | ||
|
||
cdef tuple _treevalue_flatten(object l) | ||
cdef object _treevalue_unflatten(list values, tuple spec) | ||
|
||
cdef bool _is_namedtuple_instance(pytree) except* | ||
|
||
cpdef void register_integrate_container(object type_, object flatten_func, object unflatten_func) except* | ||
|
||
cdef tuple _c_get_flatted_values_and_spec(object v) | ||
cdef object _c_get_object_from_flatted(object values, object type_, object spec) | ||
|
||
cpdef object generic_flatten(object v) | ||
cpdef object generic_unflatten(object v, tuple gspec) | ||
cpdef object generic_mapping(object v, object func) |
Oops, something went wrong.