Skip to content

Commit fbb7ad1

Browse files
authored
Merge pull request #81 from opendilab/dev/int
dev(hansbug): add generic_flatten, generic_unflatten and register_integrate_container for integration module
2 parents c9a27c3 + dd9ff51 commit fbb7ad1

File tree

6 files changed

+440
-1
lines changed

6 files changed

+440
-1
lines changed

docs/source/api_doc/tree/integration.rst

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,36 @@ register_treevalue_class
2727

2828
.. autofunction:: register_treevalue_class
2929

30+
31+
.. _apidoc_tree_integration_register_integrate_container:
32+
33+
register_integrate_container
34+
--------------------------------
35+
36+
.. autofunction:: register_integrate_container
37+
38+
39+
.. _apidoc_tree_integration_generic_flatten:
40+
41+
generic_flatten
42+
--------------------------------
43+
44+
.. autofunction:: generic_flatten
45+
46+
47+
.. _apidoc_tree_integration_generic_unflatten:
48+
49+
generic_unflatten
50+
--------------------------------
51+
52+
.. autofunction:: generic_unflatten
53+
54+
55+
.. _apidoc_tree_integration_generic_mapping:
56+
57+
generic_mapping
58+
--------------------------------
59+
60+
.. autofunction:: generic_mapping
61+
62+

test/tree/integration/test_general.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
from collections import namedtuple
2+
3+
import pytest
4+
from easydict import EasyDict
5+
6+
from treevalue import generic_flatten, generic_unflatten, FastTreeValue, register_integrate_container, generic_mapping
7+
8+
nt = namedtuple('nt', ['a', 'b'])
9+
10+
11+
class MyTreeValue(FastTreeValue):
12+
pass
13+
14+
15+
@pytest.mark.unittest
16+
class TestTreeIntegrationGeneral:
17+
def test_general_flatten_and_unflatten(self):
18+
demo_data = {
19+
'a': 1,
20+
'b': [2, 3, 'f'],
21+
'c': (2, 5, 'ds', EasyDict({
22+
'x': None,
23+
'z': [34, '1.2'],
24+
})),
25+
'd': nt('f', 100),
26+
'e': MyTreeValue({'x': 1, 'y': 'dsfljk'})
27+
}
28+
v, spec = generic_flatten(demo_data)
29+
assert v == [1, [2, 3, 'f'], [2, 5, 'ds', [None, [34, '1.2']]], ['f', 100], [1, 'dsfljk']]
30+
31+
rv = generic_unflatten(v, spec)
32+
assert rv == demo_data
33+
assert isinstance(rv['c'][-1], EasyDict)
34+
assert isinstance(rv['d'], nt)
35+
assert isinstance(rv['c'][-1]['z'], list)
36+
assert isinstance(rv['e'], MyTreeValue)
37+
38+
def test_register_my_class(self):
39+
class MyDC:
40+
def __init__(self, x, y):
41+
self.x = x
42+
self.y = y
43+
44+
def __eq__(self, other):
45+
return isinstance(other, MyDC) and self.x == other.x and self.y == other.y
46+
47+
def _mydc_flatten(v):
48+
return [v.x, v.y], MyDC
49+
50+
def _mydc_unflatten(v, spec):
51+
return spec(*v)
52+
53+
register_integrate_container(MyDC, _mydc_flatten, _mydc_unflatten)
54+
55+
demo_data = {
56+
'a': 1,
57+
'b': [2, 3, 'f'],
58+
'c': (2, 5, 'ds', EasyDict({
59+
'x': None,
60+
'z': MyDC(34, '1.2'),
61+
})),
62+
'd': nt('f', 100),
63+
'e': MyTreeValue({'x': 1, 'y': 'dsfljk'})
64+
}
65+
v, spec = generic_flatten(demo_data)
66+
assert v == [1, [2, 3, 'f'], [2, 5, 'ds', [None, [34, '1.2']]], ['f', 100], [1, 'dsfljk']]
67+
68+
rv = generic_unflatten(v, spec)
69+
assert rv == demo_data
70+
assert isinstance(rv['c'][-1], EasyDict)
71+
assert isinstance(rv['d'], nt)
72+
assert isinstance(rv['c'][-1]['z'], MyDC)
73+
assert isinstance(rv['e'], MyTreeValue)
74+
75+
def test_generic_mapping(self):
76+
demo_data = {
77+
'a': 1,
78+
'b': [2, 3, 'f'],
79+
'c': (2, 5, 'ds', EasyDict({
80+
'x': None,
81+
'z': (34, '1.2'),
82+
})),
83+
'd': nt('f', 100),
84+
'e': MyTreeValue({'x': 1, 'y': 'dsfljk'})
85+
}
86+
assert generic_mapping(demo_data, str) == {
87+
'a': '1',
88+
'b': ['2', '3', 'f'],
89+
'c': ('2', '5', 'ds', EasyDict({
90+
'x': 'None',
91+
'z': ('34', '1.2'),
92+
})),
93+
'd': nt('f', '100'),
94+
'e': MyTreeValue({'x': '1', 'y': 'dsfljk'})
95+
}

treevalue/tree/integration/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Type
22

3+
from .general import generic_flatten, generic_unflatten, register_integrate_container, generic_mapping
34
from .jax import register_for_jax
45
from .torch import register_for_torch
56
from ..tree import TreeValue

treevalue/tree/integration/base.pyx

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ cdef inline tuple _c_flatten_for_integration(object tv):
1414
values.append(value)
1515

1616
return values, (type(tv), paths)
17-
pass
1817

1918
cdef inline object _c_unflatten_for_integration(object values, tuple spec):
2019
cdef object type_
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# distutils:language=c++
2+
# cython:language_level=3
3+
4+
from libcpp cimport bool
5+
6+
cdef tuple _dict_flatten(object d)
7+
cdef object _dict_unflatten(list values, tuple spec)
8+
9+
cdef tuple _list_and_tuple_flatten(object l)
10+
cdef object _list_and_tuple_unflatten(list values, object spec)
11+
12+
cdef tuple _namedtuple_flatten(object l)
13+
cdef object _namedtuple_unflatten(list values, object spec)
14+
15+
cdef tuple _treevalue_flatten(object l)
16+
cdef object _treevalue_unflatten(list values, tuple spec)
17+
18+
cdef bool _is_namedtuple_instance(pytree) except*
19+
20+
cpdef void register_integrate_container(object type_, object flatten_func, object unflatten_func) except*
21+
22+
cdef tuple _c_get_flatted_values_and_spec(object v)
23+
cdef object _c_get_object_from_flatted(object values, object type_, object spec)
24+
25+
cpdef object generic_flatten(object v)
26+
cpdef object generic_unflatten(object v, tuple gspec)
27+
cpdef object generic_mapping(object v, object func)

0 commit comments

Comments
 (0)