Skip to content

Commit

Permalink
Merge pull request #81 from opendilab/dev/int
Browse files Browse the repository at this point in the history
dev(hansbug): add generic_flatten, generic_unflatten and register_integrate_container for integration module
  • Loading branch information
HansBug authored Feb 27, 2023
2 parents c9a27c3 + dd9ff51 commit fbb7ad1
Show file tree
Hide file tree
Showing 6 changed files with 440 additions and 1 deletion.
33 changes: 33 additions & 0 deletions docs/source/api_doc/tree/integration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,36 @@ register_treevalue_class

.. autofunction:: register_treevalue_class


.. _apidoc_tree_integration_register_integrate_container:

register_integrate_container
--------------------------------

.. autofunction:: register_integrate_container


.. _apidoc_tree_integration_generic_flatten:

generic_flatten
--------------------------------

.. autofunction:: generic_flatten


.. _apidoc_tree_integration_generic_unflatten:

generic_unflatten
--------------------------------

.. autofunction:: generic_unflatten


.. _apidoc_tree_integration_generic_mapping:

generic_mapping
--------------------------------

.. autofunction:: generic_mapping


95 changes: 95 additions & 0 deletions test/tree/integration/test_general.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from collections import namedtuple

import pytest
from easydict import EasyDict

from treevalue import generic_flatten, generic_unflatten, FastTreeValue, register_integrate_container, generic_mapping

nt = namedtuple('nt', ['a', 'b'])


class MyTreeValue(FastTreeValue):
pass


@pytest.mark.unittest
class TestTreeIntegrationGeneral:
def test_general_flatten_and_unflatten(self):
demo_data = {
'a': 1,
'b': [2, 3, 'f'],
'c': (2, 5, 'ds', EasyDict({
'x': None,
'z': [34, '1.2'],
})),
'd': nt('f', 100),
'e': MyTreeValue({'x': 1, 'y': 'dsfljk'})
}
v, spec = generic_flatten(demo_data)
assert v == [1, [2, 3, 'f'], [2, 5, 'ds', [None, [34, '1.2']]], ['f', 100], [1, 'dsfljk']]

rv = generic_unflatten(v, spec)
assert rv == demo_data
assert isinstance(rv['c'][-1], EasyDict)
assert isinstance(rv['d'], nt)
assert isinstance(rv['c'][-1]['z'], list)
assert isinstance(rv['e'], MyTreeValue)

def test_register_my_class(self):
class MyDC:
def __init__(self, x, y):
self.x = x
self.y = y

def __eq__(self, other):
return isinstance(other, MyDC) and self.x == other.x and self.y == other.y

def _mydc_flatten(v):
return [v.x, v.y], MyDC

def _mydc_unflatten(v, spec):
return spec(*v)

register_integrate_container(MyDC, _mydc_flatten, _mydc_unflatten)

demo_data = {
'a': 1,
'b': [2, 3, 'f'],
'c': (2, 5, 'ds', EasyDict({
'x': None,
'z': MyDC(34, '1.2'),
})),
'd': nt('f', 100),
'e': MyTreeValue({'x': 1, 'y': 'dsfljk'})
}
v, spec = generic_flatten(demo_data)
assert v == [1, [2, 3, 'f'], [2, 5, 'ds', [None, [34, '1.2']]], ['f', 100], [1, 'dsfljk']]

rv = generic_unflatten(v, spec)
assert rv == demo_data
assert isinstance(rv['c'][-1], EasyDict)
assert isinstance(rv['d'], nt)
assert isinstance(rv['c'][-1]['z'], MyDC)
assert isinstance(rv['e'], MyTreeValue)

def test_generic_mapping(self):
demo_data = {
'a': 1,
'b': [2, 3, 'f'],
'c': (2, 5, 'ds', EasyDict({
'x': None,
'z': (34, '1.2'),
})),
'd': nt('f', 100),
'e': MyTreeValue({'x': 1, 'y': 'dsfljk'})
}
assert generic_mapping(demo_data, str) == {
'a': '1',
'b': ['2', '3', 'f'],
'c': ('2', '5', 'ds', EasyDict({
'x': 'None',
'z': ('34', '1.2'),
})),
'd': nt('f', '100'),
'e': MyTreeValue({'x': '1', 'y': 'dsfljk'})
}
1 change: 1 addition & 0 deletions treevalue/tree/integration/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Type

from .general import generic_flatten, generic_unflatten, register_integrate_container, generic_mapping
from .jax import register_for_jax
from .torch import register_for_torch
from ..tree import TreeValue
Expand Down
1 change: 0 additions & 1 deletion treevalue/tree/integration/base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ cdef inline tuple _c_flatten_for_integration(object tv):
values.append(value)

return values, (type(tv), paths)
pass

cdef inline object _c_unflatten_for_integration(object values, tuple spec):
cdef object type_
Expand Down
27 changes: 27 additions & 0 deletions treevalue/tree/integration/general.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# distutils:language=c++
# cython:language_level=3

from libcpp cimport bool

cdef tuple _dict_flatten(object d)
cdef object _dict_unflatten(list values, tuple spec)

cdef tuple _list_and_tuple_flatten(object l)
cdef object _list_and_tuple_unflatten(list values, object spec)

cdef tuple _namedtuple_flatten(object l)
cdef object _namedtuple_unflatten(list values, object spec)

cdef tuple _treevalue_flatten(object l)
cdef object _treevalue_unflatten(list values, tuple spec)

cdef bool _is_namedtuple_instance(pytree) except*

cpdef void register_integrate_container(object type_, object flatten_func, object unflatten_func) except*

cdef tuple _c_get_flatted_values_and_spec(object v)
cdef object _c_get_object_from_flatted(object values, object type_, object spec)

cpdef object generic_flatten(object v)
cpdef object generic_unflatten(object v, tuple gspec)
cpdef object generic_mapping(object v, object func)
Loading

0 comments on commit fbb7ad1

Please sign in to comment.