Skip to content

dev(hansbug): add generic_flatten, generic_unflatten and register_integrate_container for integration module #81

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions docs/source/api_doc/tree/integration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,36 @@ register_treevalue_class

.. autofunction:: register_treevalue_class


.. _apidoc_tree_integration_register_integrate_container:

register_integrate_container
--------------------------------

.. autofunction:: register_integrate_container


.. _apidoc_tree_integration_generic_flatten:

generic_flatten
--------------------------------

.. autofunction:: generic_flatten


.. _apidoc_tree_integration_generic_unflatten:

generic_unflatten
--------------------------------

.. autofunction:: generic_unflatten


.. _apidoc_tree_integration_generic_mapping:

generic_mapping
--------------------------------

.. autofunction:: generic_mapping


95 changes: 95 additions & 0 deletions test/tree/integration/test_general.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from collections import namedtuple

import pytest
from easydict import EasyDict

from treevalue import generic_flatten, generic_unflatten, FastTreeValue, register_integrate_container, generic_mapping

nt = namedtuple('nt', ['a', 'b'])


class MyTreeValue(FastTreeValue):
pass


@pytest.mark.unittest
class TestTreeIntegrationGeneral:
def test_general_flatten_and_unflatten(self):
demo_data = {
'a': 1,
'b': [2, 3, 'f'],
'c': (2, 5, 'ds', EasyDict({
'x': None,
'z': [34, '1.2'],
})),
'd': nt('f', 100),
'e': MyTreeValue({'x': 1, 'y': 'dsfljk'})
}
v, spec = generic_flatten(demo_data)
assert v == [1, [2, 3, 'f'], [2, 5, 'ds', [None, [34, '1.2']]], ['f', 100], [1, 'dsfljk']]

rv = generic_unflatten(v, spec)
assert rv == demo_data
assert isinstance(rv['c'][-1], EasyDict)
assert isinstance(rv['d'], nt)
assert isinstance(rv['c'][-1]['z'], list)
assert isinstance(rv['e'], MyTreeValue)

def test_register_my_class(self):
class MyDC:
def __init__(self, x, y):
self.x = x
self.y = y

def __eq__(self, other):
return isinstance(other, MyDC) and self.x == other.x and self.y == other.y

def _mydc_flatten(v):
return [v.x, v.y], MyDC

def _mydc_unflatten(v, spec):
return spec(*v)

register_integrate_container(MyDC, _mydc_flatten, _mydc_unflatten)

demo_data = {
'a': 1,
'b': [2, 3, 'f'],
'c': (2, 5, 'ds', EasyDict({
'x': None,
'z': MyDC(34, '1.2'),
})),
'd': nt('f', 100),
'e': MyTreeValue({'x': 1, 'y': 'dsfljk'})
}
v, spec = generic_flatten(demo_data)
assert v == [1, [2, 3, 'f'], [2, 5, 'ds', [None, [34, '1.2']]], ['f', 100], [1, 'dsfljk']]

rv = generic_unflatten(v, spec)
assert rv == demo_data
assert isinstance(rv['c'][-1], EasyDict)
assert isinstance(rv['d'], nt)
assert isinstance(rv['c'][-1]['z'], MyDC)
assert isinstance(rv['e'], MyTreeValue)

def test_generic_mapping(self):
demo_data = {
'a': 1,
'b': [2, 3, 'f'],
'c': (2, 5, 'ds', EasyDict({
'x': None,
'z': (34, '1.2'),
})),
'd': nt('f', 100),
'e': MyTreeValue({'x': 1, 'y': 'dsfljk'})
}
assert generic_mapping(demo_data, str) == {
'a': '1',
'b': ['2', '3', 'f'],
'c': ('2', '5', 'ds', EasyDict({
'x': 'None',
'z': ('34', '1.2'),
})),
'd': nt('f', '100'),
'e': MyTreeValue({'x': '1', 'y': 'dsfljk'})
}
1 change: 1 addition & 0 deletions treevalue/tree/integration/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Type

from .general import generic_flatten, generic_unflatten, register_integrate_container, generic_mapping
from .jax import register_for_jax
from .torch import register_for_torch
from ..tree import TreeValue
Expand Down
1 change: 0 additions & 1 deletion treevalue/tree/integration/base.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ cdef inline tuple _c_flatten_for_integration(object tv):
values.append(value)

return values, (type(tv), paths)
pass

cdef inline object _c_unflatten_for_integration(object values, tuple spec):
cdef object type_
Expand Down
27 changes: 27 additions & 0 deletions treevalue/tree/integration/general.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# distutils:language=c++
# cython:language_level=3

from libcpp cimport bool

cdef tuple _dict_flatten(object d)
cdef object _dict_unflatten(list values, tuple spec)

cdef tuple _list_and_tuple_flatten(object l)
cdef object _list_and_tuple_unflatten(list values, object spec)

cdef tuple _namedtuple_flatten(object l)
cdef object _namedtuple_unflatten(list values, object spec)

cdef tuple _treevalue_flatten(object l)
cdef object _treevalue_unflatten(list values, tuple spec)

cdef bool _is_namedtuple_instance(pytree) except*

cpdef void register_integrate_container(object type_, object flatten_func, object unflatten_func) except*

cdef tuple _c_get_flatted_values_and_spec(object v)
cdef object _c_get_object_from_flatted(object values, object type_, object spec)

cpdef object generic_flatten(object v)
cpdef object generic_unflatten(object v, tuple gspec)
cpdef object generic_mapping(object v, object func)
Loading