From a504c0e637699ddcdc75c63d9297bbe848e754a4 Mon Sep 17 00:00:00 2001 From: HansBug Date: Mon, 27 Feb 2023 17:46:36 +0800 Subject: [PATCH 1/3] dev(hansbug): add generic_flatten and generic_unflatten --- docs/source/api_doc/tree/integration.rst | 25 +++ test/tree/integration/test_general.py | 81 +++++++ treevalue/tree/integration/__init__.py | 1 + treevalue/tree/integration/base.pyx | 1 - treevalue/tree/integration/general.pxd | 26 +++ treevalue/tree/integration/general.pyx | 259 +++++++++++++++++++++++ 6 files changed, 392 insertions(+), 1 deletion(-) create mode 100644 test/tree/integration/test_general.py create mode 100644 treevalue/tree/integration/general.pxd create mode 100644 treevalue/tree/integration/general.pyx diff --git a/docs/source/api_doc/tree/integration.rst b/docs/source/api_doc/tree/integration.rst index 1e8bebf6af..46d8377eed 100644 --- a/docs/source/api_doc/tree/integration.rst +++ b/docs/source/api_doc/tree/integration.rst @@ -27,3 +27,28 @@ register_treevalue_class .. autofunction:: register_treevalue_class + +.. _apidoc_tree_integration_register_integrate_container: + +register_integrate_container +-------------------------------- + +.. autofunction:: register_integrate_container + + +.. _apidoc_tree_integration_generic_flatten: + +generic_flatten +-------------------------------- + +.. autofunction:: generic_flatten + + +.. _apidoc_tree_integration_generic_unflatten: + +generic_unflatten +-------------------------------- + +.. autofunction:: generic_unflatten + + diff --git a/test/tree/integration/test_general.py b/test/tree/integration/test_general.py new file mode 100644 index 0000000000..0fedc3aef8 --- /dev/null +++ b/test/tree/integration/test_general.py @@ -0,0 +1,81 @@ +from collections import namedtuple +from dataclasses import dataclass + +import pytest +from easydict import EasyDict + +from treevalue import generic_flatten, generic_unflatten, FastTreeValue, register_integrate_container + + +@dataclass +class DC: + x: int + y: str + + +nt = namedtuple('nt', ['a', 'b']) + + +class MyTreeValue(FastTreeValue): + pass + + +@pytest.mark.unittest +class TestTreeIntegrationGeneral: + def test_general_flatten_and_unflatten(self): + demo_data = { + 'a': 1, + 'b': [2, 3, 'f'], + 'c': (2, 5, 'ds', EasyDict({ + 'x': None, + 'z': DC(34, '1.2'), + })), + 'd': nt('f', 100), + 'e': MyTreeValue({'x': 1, 'y': 'dsfljk'}) + } + v, spec = generic_flatten(demo_data) + assert v == [1, [2, 3, 'f'], [2, 5, 'ds', [None, [34, '1.2']]], ['f', 100], [1, 'dsfljk']] + + rv = generic_unflatten(v, spec) + assert rv == demo_data + assert isinstance(rv['c'][-1], EasyDict) + assert isinstance(rv['d'], nt) + assert isinstance(rv['c'][-1]['z'], DC) + assert isinstance(rv['e'], MyTreeValue) + + def test_register_my_class(self): + class MyDC: + def __init__(self, x, y): + self.x = x + self.y = y + + def __eq__(self, other): + return isinstance(other, MyDC) and self.x == other.x and self.y == other.y + + def _mydc_flatten(v): + return [v.x, v.y], MyDC + + def _mydc_unflatten(v, spec): + return spec(*v) + + register_integrate_container(MyDC, _mydc_flatten, _mydc_unflatten) + + demo_data = { + 'a': 1, + 'b': [2, 3, 'f'], + 'c': (2, 5, 'ds', EasyDict({ + 'x': None, + 'z': MyDC(34, '1.2'), + })), + 'd': nt('f', 100), + 'e': MyTreeValue({'x': 1, 'y': 'dsfljk'}) + } + v, spec = generic_flatten(demo_data) + assert v == [1, [2, 3, 'f'], [2, 5, 'ds', [None, [34, '1.2']]], ['f', 100], [1, 'dsfljk']] + + rv = generic_unflatten(v, spec) + assert rv == demo_data + assert isinstance(rv['c'][-1], EasyDict) + assert isinstance(rv['d'], nt) + assert isinstance(rv['c'][-1]['z'], MyDC) + assert isinstance(rv['e'], MyTreeValue) diff --git a/treevalue/tree/integration/__init__.py b/treevalue/tree/integration/__init__.py index 11231ab66b..af957b9eb7 100644 --- a/treevalue/tree/integration/__init__.py +++ b/treevalue/tree/integration/__init__.py @@ -1,5 +1,6 @@ from typing import Type +from .general import generic_flatten, generic_unflatten, register_integrate_container from .jax import register_for_jax from .torch import register_for_torch from ..tree import TreeValue diff --git a/treevalue/tree/integration/base.pyx b/treevalue/tree/integration/base.pyx index 43bb761eb6..8ac5f40674 100644 --- a/treevalue/tree/integration/base.pyx +++ b/treevalue/tree/integration/base.pyx @@ -14,7 +14,6 @@ cdef inline tuple _c_flatten_for_integration(object tv): values.append(value) return values, (type(tv), paths) - pass cdef inline object _c_unflatten_for_integration(object values, tuple spec): cdef object type_ diff --git a/treevalue/tree/integration/general.pxd b/treevalue/tree/integration/general.pxd new file mode 100644 index 0000000000..d79327a851 --- /dev/null +++ b/treevalue/tree/integration/general.pxd @@ -0,0 +1,26 @@ +# distutils:language=c++ +# cython:language_level=3 + +from libcpp cimport bool + +cdef tuple _dict_flatten(object d) +cdef object _dict_unflatten(list values, tuple spec) + +cdef tuple _list_and_tuple_flatten(object l) +cdef object _list_and_tuple_unflatten(list values, object spec) + +cdef tuple _namedtuple_flatten(object l) +cdef object _namedtuple_unflatten(list values, object spec) + +cdef tuple _dataclass_flatten(object l) +cdef object _dataclass_unflatten(list values, tuple spec) + +cdef tuple _treevalue_flatten(object l) +cdef object _treevalue_unflatten(list values, tuple spec) + +cdef bool _is_namedtuple_instance(pytree) except* + +cpdef void register_integrate_container(object type_, object flatten_func, object unflatten_func) except* + +cpdef object generic_flatten(object v) +cpdef object generic_unflatten(object v, tuple gspec) diff --git a/treevalue/tree/integration/general.pyx b/treevalue/tree/integration/general.pyx new file mode 100644 index 0000000000..3a593d633f --- /dev/null +++ b/treevalue/tree/integration/general.pyx @@ -0,0 +1,259 @@ +# distutils:language=c++ +# cython:language_level=3 + +from collections import namedtuple +from dataclasses import dataclass, is_dataclass + +import cython +from libcpp cimport bool + +from .base cimport _c_flatten_for_integration, _c_unflatten_for_integration +from ..tree.tree cimport TreeValue + +_REGISTERED_CONTAINERS = {} + +cdef inline tuple _dict_flatten(object d): + cdef list values = [] + cdef list keys = [] + + cdef object key, value + for key, value in d.items(): + keys.append(key) + values.append(value) + + return values, (type(d), keys) + +cdef inline object _dict_unflatten(list values, tuple spec): + cdef object type_ + cdef list keys + type_, keys = spec + + cdef dict retval = {} + for key, value in zip(keys, values): + retval[key] = value + + return type_(retval) + +cdef inline tuple _list_and_tuple_flatten(object l): + return list(l), type(l) + +cdef inline object _list_and_tuple_unflatten(list values, object spec): + return spec(values) + +cdef inline tuple _namedtuple_flatten(object l): + return list(l), type(l) + +cdef inline object _namedtuple_unflatten(list values, object spec): + return spec(*values) + +cdef inline tuple _dataclass_flatten(object l): + cdef object type_ = type(l) + cdef list keys = [] + cdef list values = [] + for key in type_.__dataclass_fields__.keys(): + keys.append(key) + values.append(getattr(l, key)) + + return values, (type_, keys) + +cdef inline object _dataclass_unflatten(list values, tuple spec): + cdef object type_ + cdef list keys + type_, keys = spec + + return type_(**{key: value for key, value in zip(keys, values)}) + +cdef inline tuple _treevalue_flatten(object l): + return _c_flatten_for_integration(l) + +cdef inline object _treevalue_unflatten(list values, tuple spec): + return _c_unflatten_for_integration(values, spec) + +cdef inline bool _is_namedtuple_instance(pytree) except*: + cdef object typ = type(pytree) + cdef tuple bases = typ.__bases__ + if len(bases) != 1 or bases[0] != tuple: + return False + + fields = getattr(typ, '_fields', None) + if not isinstance(fields, tuple): + return False # pragma: no cover + + return all(type(entry) == str for entry in fields) + +@cython.binding(True) +cpdef inline void register_integrate_container(object type_, object flatten_func, object unflatten_func) except*: + """ + Overview: + Register custom data class for generic flatten and unflatten. + + :param type_: Class of data to be registered. + :param flatten_func: Function for flattening. + :param unflatten_func: Function for unflattening. + + Examples:: + + >>> from treevalue import register_integrate_container, generic_flatten, FastTreeValue, generic_unflatten + >>> + >>> class MyDC: + ... def __init__(self, x, y): + ... self.x = x + ... self.y = y + ... + ... def __eq__(self, other): + ... return isinstance(other, MyDC) and self.x == other.x and self.y == other.y + >>> + >>> def _mydc_flatten(v): + ... return [v.x, v.y], MyDC + >>> + >>> def _mydc_unflatten(v, spec): # spec will be MyDC + ... return spec(*v) + + >>> + >>> register_integrate_container(MyDC, _mydc_flatten, _mydc_unflatten) # register MyDC + >>> + >>> v, spec = generic_flatten({'a': MyDC(2, 3), 'b': MyDC((4, 5), FastTreeValue({'x': 1, 'y': 'f'}))}) + >>> v + [[2, 3], [[4, 5], [1, 'f']]] + >>> + >>> rt=generic_unflatten(v, spec) + >>> rt + {'a': <__main__.MyDC object at 0x7fbda613f9d0>, 'b': <__main__.MyDC object at 0x7fbda6148150>} + >>> rt['a'].x + 2 + >>> rt['a'].y + 3 + >>> rt['b'].x + (4, 5) + >>> rt['b'].y + + ├── 'x' --> 1 + └── 'y' --> 'f' + """ + _REGISTERED_CONTAINERS[type_] = (flatten_func, unflatten_func) + +@cython.binding(True) +cpdef inline object generic_flatten(object v): + """ + Overview: + Flatten generic data, including native objects, ``TreeValue``, namedtuples and dataclasses. + + :param v: Value to be flatted. + :return: Flatted value. + + Examples:: + >>> from collections import namedtuple + >>> from dataclasses import dataclass + >>> from easydict import EasyDict + >>> from treevalue import FastTreeValue, generic_flatten, generic_unflatten + >>> + >>> class MyTreeValue(FastTreeValue): + ... pass + >>> + >>> @dataclass + ... class DC: + ... x: int + ... y: float + ... + ... def __repr__(self): + ... return f'DC({self.x}, {self.y})' + >>> + >>> nt = namedtuple('nt', ['a', 'b']) + >>> + >>> origin = { + ... 'a': 1, + ... 'b': [2, 3, 'f', ], + ... 'c': (2, 5, 'ds', EasyDict({ # dict's child class + ... 'x': None, + ... 'z': DC(34, '1.2'), # dataclass + ... })), + ... 'd': nt('f', 100), # namedtuple + ... 'e': MyTreeValue({'x': 1, 'y': 'dsfljk'}) # treevalue + ... } + >>> v, spec = generic_flatten(origin) + >>> v + [1, [2, 3, 'f'], [2, 5, 'ds', [None, [34, '1.2']]], ['f', 100], [1, 'dsfljk']] + >>> + >>> rv = generic_unflatten(v, spec) + >>> rv # all the data, including types, are recovered + {'a': 1, 'b': [2, 3, 'f'], 'c': (2, 5, 'ds', {'x': None, 'z': DC(34, 1.2)}), 'd': nt(a='f', b=100), 'e': + ├── 'x' --> 1 + └── 'y' --> 'dsfljk' + } + >>> type(rv['c'][-1]) + + """ + cdef list values + cdef object spec, type_ + cdef object flatten_func + if isinstance(v, dict): + values, spec = _dict_flatten(v) + type_ = dict + elif _is_namedtuple_instance(v): + values, spec = _namedtuple_flatten(v) + type_ = namedtuple + elif isinstance(v, (list, tuple)): + values, spec = _list_and_tuple_flatten(v) + type_ = list + elif is_dataclass(v): + values, spec = _dataclass_flatten(v) + type_ = dataclass + elif isinstance(v, TreeValue): + values, spec = _treevalue_flatten(v) + type_ = TreeValue + elif type(v) in _REGISTERED_CONTAINERS: + flatten_func, _ = _REGISTERED_CONTAINERS[type(v)] + values, spec = flatten_func(v) + type_ = type(v) + else: + return v, (None, None, None) + + cdef list child_values = [] + cdef list child_specs = [] + cdef object value, cval, cspec + for value in values: + cval, cspec = generic_flatten(value) + child_values.append(cval) + child_specs.append(cspec) + + return child_values, (type_, spec, child_specs) + +@cython.binding(True) +cpdef inline object generic_unflatten(object v, tuple gspec): + """ + Overview: + Inverse operation of :func:`generic_flatten`. + + :param v: Flatted values. + :param gspec: Spec data of original object. + + Examples:: + See :func:`generic_flatten`. + """ + cdef object type_, spec + cdef list child_specs + type_, spec, child_specs = gspec + if type_ is None: + return v + + cdef list values = [] + cdef object _i_value, _i_spec + for _i_value, _i_spec in zip(v, child_specs): + values.append(generic_unflatten(_i_value, _i_spec)) + + cdef object unflatten_func + if type_ is dict: + return _dict_unflatten(values, spec) + elif type_ is namedtuple: + return _namedtuple_unflatten(values, spec) + elif type_ is list: + return _list_and_tuple_unflatten(values, spec) + elif type_ is dataclass: + return _dataclass_unflatten(values, spec) + elif type_ is TreeValue: + return _treevalue_unflatten(values, spec) + elif type_ in _REGISTERED_CONTAINERS: + _, unflatten_func = _REGISTERED_CONTAINERS[type_] + return unflatten_func(values, spec) + else: + raise TypeError(f'Unknown type for unflatten - {values!r}, {gspec!r}.') # pragma: no cover From c38bf2339f7b707faa5443820a6a224033fc165f Mon Sep 17 00:00:00 2001 From: HansBug Date: Mon, 27 Feb 2023 19:00:04 +0800 Subject: [PATCH 2/3] dev(hansbug): add generic_mapping function --- test/tree/integration/test_general.py | 36 +++++-- treevalue/tree/integration/__init__.py | 2 +- treevalue/tree/integration/general.pxd | 7 +- treevalue/tree/integration/general.pyx | 136 ++++++++++++------------- 4 files changed, 94 insertions(+), 87 deletions(-) diff --git a/test/tree/integration/test_general.py b/test/tree/integration/test_general.py index 0fedc3aef8..7c959b82f7 100644 --- a/test/tree/integration/test_general.py +++ b/test/tree/integration/test_general.py @@ -1,17 +1,9 @@ from collections import namedtuple -from dataclasses import dataclass import pytest from easydict import EasyDict -from treevalue import generic_flatten, generic_unflatten, FastTreeValue, register_integrate_container - - -@dataclass -class DC: - x: int - y: str - +from treevalue import generic_flatten, generic_unflatten, FastTreeValue, register_integrate_container, generic_mapping nt = namedtuple('nt', ['a', 'b']) @@ -28,7 +20,7 @@ def test_general_flatten_and_unflatten(self): 'b': [2, 3, 'f'], 'c': (2, 5, 'ds', EasyDict({ 'x': None, - 'z': DC(34, '1.2'), + 'z': [34, '1.2'], })), 'd': nt('f', 100), 'e': MyTreeValue({'x': 1, 'y': 'dsfljk'}) @@ -40,7 +32,7 @@ def test_general_flatten_and_unflatten(self): assert rv == demo_data assert isinstance(rv['c'][-1], EasyDict) assert isinstance(rv['d'], nt) - assert isinstance(rv['c'][-1]['z'], DC) + assert isinstance(rv['c'][-1]['z'], list) assert isinstance(rv['e'], MyTreeValue) def test_register_my_class(self): @@ -79,3 +71,25 @@ def _mydc_unflatten(v, spec): assert isinstance(rv['d'], nt) assert isinstance(rv['c'][-1]['z'], MyDC) assert isinstance(rv['e'], MyTreeValue) + + def test_generic_mapping(self): + demo_data = { + 'a': 1, + 'b': [2, 3, 'f'], + 'c': (2, 5, 'ds', EasyDict({ + 'x': None, + 'z': (34, '1.2'), + })), + 'd': nt('f', 100), + 'e': MyTreeValue({'x': 1, 'y': 'dsfljk'}) + } + assert generic_mapping(demo_data, str) == { + 'a': '1', + 'b': ['2', '3', 'f'], + 'c': ('2', '5', 'ds', EasyDict({ + 'x': 'None', + 'z': ('34', '1.2'), + })), + 'd': nt('f', '100'), + 'e': MyTreeValue({'x': '1', 'y': 'dsfljk'}) + } diff --git a/treevalue/tree/integration/__init__.py b/treevalue/tree/integration/__init__.py index 5e5a2d4220..7899b7242e 100644 --- a/treevalue/tree/integration/__init__.py +++ b/treevalue/tree/integration/__init__.py @@ -1,6 +1,6 @@ from typing import Type -from .general import generic_flatten, generic_unflatten, register_integrate_container +from .general import generic_flatten, generic_unflatten, register_integrate_container, generic_mapping from .jax import register_for_jax from .torch import register_for_torch from ..tree import TreeValue diff --git a/treevalue/tree/integration/general.pxd b/treevalue/tree/integration/general.pxd index d79327a851..dfb18b5c44 100644 --- a/treevalue/tree/integration/general.pxd +++ b/treevalue/tree/integration/general.pxd @@ -12,9 +12,6 @@ cdef object _list_and_tuple_unflatten(list values, object spec) cdef tuple _namedtuple_flatten(object l) cdef object _namedtuple_unflatten(list values, object spec) -cdef tuple _dataclass_flatten(object l) -cdef object _dataclass_unflatten(list values, tuple spec) - cdef tuple _treevalue_flatten(object l) cdef object _treevalue_unflatten(list values, tuple spec) @@ -22,5 +19,9 @@ cdef bool _is_namedtuple_instance(pytree) except* cpdef void register_integrate_container(object type_, object flatten_func, object unflatten_func) except* +cdef tuple _c_get_flatted_values_and_spec(object v) +cdef object _c_get_object_from_flatted(object values, object type_, object spec) + cpdef object generic_flatten(object v) cpdef object generic_unflatten(object v, tuple gspec) +cpdef object generic_mapping(object v, object func) diff --git a/treevalue/tree/integration/general.pyx b/treevalue/tree/integration/general.pyx index 3a593d633f..6338539cc2 100644 --- a/treevalue/tree/integration/general.pyx +++ b/treevalue/tree/integration/general.pyx @@ -2,7 +2,6 @@ # cython:language_level=3 from collections import namedtuple -from dataclasses import dataclass, is_dataclass import cython from libcpp cimport bool @@ -46,23 +45,6 @@ cdef inline tuple _namedtuple_flatten(object l): cdef inline object _namedtuple_unflatten(list values, object spec): return spec(*values) -cdef inline tuple _dataclass_flatten(object l): - cdef object type_ = type(l) - cdef list keys = [] - cdef list values = [] - for key in type_.__dataclass_fields__.keys(): - keys.append(key) - values.append(getattr(l, key)) - - return values, (type_, keys) - -cdef inline object _dataclass_unflatten(list values, tuple spec): - cdef object type_ - cdef list keys - type_, keys = spec - - return type_(**{key: value for key, value in zip(keys, values)}) - cdef inline tuple _treevalue_flatten(object l): return _c_flatten_for_integration(l) @@ -132,40 +114,73 @@ cpdef inline void register_integrate_container(object type_, object flatten_func """ _REGISTERED_CONTAINERS[type_] = (flatten_func, unflatten_func) +cdef inline tuple _c_get_flatted_values_and_spec(object v): + cdef list values + cdef object spec, type_ + cdef object flatten_func + if isinstance(v, dict): + values, spec = _dict_flatten(v) + type_ = dict + elif _is_namedtuple_instance(v): + values, spec = _namedtuple_flatten(v) + type_ = namedtuple + elif isinstance(v, (list, tuple)): + values, spec = _list_and_tuple_flatten(v) + type_ = list + elif isinstance(v, TreeValue): + values, spec = _treevalue_flatten(v) + type_ = TreeValue + elif type(v) in _REGISTERED_CONTAINERS: + flatten_func, _ = _REGISTERED_CONTAINERS[type(v)] + values, spec = flatten_func(v) + type_ = type(v) + else: + return v, None, None + + return values, type_, spec + +cdef inline object _c_get_object_from_flatted(object values, object type_, object spec): + cdef object unflatten_func + if type_ is dict: + return _dict_unflatten(values, spec) + elif type_ is namedtuple: + return _namedtuple_unflatten(values, spec) + elif type_ is list: + return _list_and_tuple_unflatten(values, spec) + elif type_ is TreeValue: + return _treevalue_unflatten(values, spec) + elif type_ in _REGISTERED_CONTAINERS: + _, unflatten_func = _REGISTERED_CONTAINERS[type_] + return unflatten_func(values, spec) + else: + raise TypeError(f'Unknown type for unflatten - {values!r}, {spec!r}.') # pragma: no cover + @cython.binding(True) cpdef inline object generic_flatten(object v): """ Overview: - Flatten generic data, including native objects, ``TreeValue``, namedtuples and dataclasses. + Flatten generic data, including native objects, ``TreeValue``, namedtuples. :param v: Value to be flatted. :return: Flatted value. Examples:: + >>> from collections import namedtuple - >>> from dataclasses import dataclass >>> from easydict import EasyDict >>> from treevalue import FastTreeValue, generic_flatten, generic_unflatten >>> >>> class MyTreeValue(FastTreeValue): ... pass >>> - >>> @dataclass - ... class DC: - ... x: int - ... y: float - ... - ... def __repr__(self): - ... return f'DC({self.x}, {self.y})' - >>> >>> nt = namedtuple('nt', ['a', 'b']) >>> >>> origin = { ... 'a': 1, - ... 'b': [2, 3, 'f', ], + ... 'b': (2, 3, 'f',), ... 'c': (2, 5, 'ds', EasyDict({ # dict's child class ... 'x': None, - ... 'z': DC(34, '1.2'), # dataclass + ... 'z': [34, '1.2'], # dataclass ... })), ... 'd': nt('f', 100), # namedtuple ... 'e': MyTreeValue({'x': 1, 'y': 'dsfljk'}) # treevalue @@ -175,38 +190,17 @@ cpdef inline object generic_flatten(object v): [1, [2, 3, 'f'], [2, 5, 'ds', [None, [34, '1.2']]], ['f', 100], [1, 'dsfljk']] >>> >>> rv = generic_unflatten(v, spec) - >>> rv # all the data, including types, are recovered - {'a': 1, 'b': [2, 3, 'f'], 'c': (2, 5, 'ds', {'x': None, 'z': DC(34, 1.2)}), 'd': nt(a='f', b=100), 'e': + >>> rv + {'a': 1, 'b': (2, 3, 'f'), 'c': (2, 5, 'ds', {'x': None, 'z': [34, '1.2']}), 'd': nt(a='f', b=100), 'e': ├── 'x' --> 1 └── 'y' --> 'dsfljk' } >>> type(rv['c'][-1]) """ - cdef list values - cdef object spec, type_ - cdef object flatten_func - if isinstance(v, dict): - values, spec = _dict_flatten(v) - type_ = dict - elif _is_namedtuple_instance(v): - values, spec = _namedtuple_flatten(v) - type_ = namedtuple - elif isinstance(v, (list, tuple)): - values, spec = _list_and_tuple_flatten(v) - type_ = list - elif is_dataclass(v): - values, spec = _dataclass_flatten(v) - type_ = dataclass - elif isinstance(v, TreeValue): - values, spec = _treevalue_flatten(v) - type_ = TreeValue - elif type(v) in _REGISTERED_CONTAINERS: - flatten_func, _ = _REGISTERED_CONTAINERS[type(v)] - values, spec = flatten_func(v) - type_ = type(v) - else: - return v, (None, None, None) + values, type_, spec = _c_get_flatted_values_and_spec(v) + if type_ is None: + return values, (None, None, None) cdef list child_values = [] cdef list child_specs = [] @@ -241,19 +235,17 @@ cpdef inline object generic_unflatten(object v, tuple gspec): for _i_value, _i_spec in zip(v, child_specs): values.append(generic_unflatten(_i_value, _i_spec)) - cdef object unflatten_func - if type_ is dict: - return _dict_unflatten(values, spec) - elif type_ is namedtuple: - return _namedtuple_unflatten(values, spec) - elif type_ is list: - return _list_and_tuple_unflatten(values, spec) - elif type_ is dataclass: - return _dataclass_unflatten(values, spec) - elif type_ is TreeValue: - return _treevalue_unflatten(values, spec) - elif type_ in _REGISTERED_CONTAINERS: - _, unflatten_func = _REGISTERED_CONTAINERS[type_] - return unflatten_func(values, spec) - else: - raise TypeError(f'Unknown type for unflatten - {values!r}, {gspec!r}.') # pragma: no cover + return _c_get_object_from_flatted(values, type_, spec) + +@cython.binding(True) +cpdef inline object generic_mapping(object v, object func): + values, type_, spec = _c_get_flatted_values_and_spec(v) + if type_ is None: + return func(values) + + cdef list retvals = [] + cdef object value + for value in values: + retvals.append(generic_mapping(value, func)) + + return _c_get_object_from_flatted(retvals, type_, spec) From dd9ff51038b3b06f3fdce24e0d0fd3534cdb4ca0 Mon Sep 17 00:00:00 2001 From: HansBug Date: Mon, 27 Feb 2023 19:05:57 +0800 Subject: [PATCH 3/3] dev(hansbug): remove support for dataclasses --- docs/source/api_doc/tree/integration.rst | 8 +++++ treevalue/tree/integration/general.pyx | 39 ++++++++++++++++++++++-- 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/docs/source/api_doc/tree/integration.rst b/docs/source/api_doc/tree/integration.rst index 46d8377eed..72afb9f89d 100644 --- a/docs/source/api_doc/tree/integration.rst +++ b/docs/source/api_doc/tree/integration.rst @@ -52,3 +52,11 @@ generic_unflatten .. autofunction:: generic_unflatten +.. _apidoc_tree_integration_generic_mapping: + +generic_mapping +-------------------------------- + +.. autofunction:: generic_mapping + + diff --git a/treevalue/tree/integration/general.pyx b/treevalue/tree/integration/general.pyx index 6338539cc2..a9467c774c 100644 --- a/treevalue/tree/integration/general.pyx +++ b/treevalue/tree/integration/general.pyx @@ -74,7 +74,6 @@ cpdef inline void register_integrate_container(object type_, object flatten_func :param unflatten_func: Function for unflattening. Examples:: - >>> from treevalue import register_integrate_container, generic_flatten, FastTreeValue, generic_unflatten >>> >>> class MyDC: @@ -159,13 +158,13 @@ cdef inline object _c_get_object_from_flatted(object values, object type_, objec cpdef inline object generic_flatten(object v): """ Overview: - Flatten generic data, including native objects, ``TreeValue``, namedtuples. + Flatten generic data, including native objects, ``TreeValue``, namedtuples and custom classes \ + (see :func:`register_integrate_container`). :param v: Value to be flatted. :return: Flatted value. Examples:: - >>> from collections import namedtuple >>> from easydict import EasyDict >>> from treevalue import FastTreeValue, generic_flatten, generic_unflatten @@ -239,6 +238,40 @@ cpdef inline object generic_unflatten(object v, tuple gspec): @cython.binding(True) cpdef inline object generic_mapping(object v, object func): + """ + Overview: + Generic map all the values, including native objects, ``TreeValue``, namedtuples and custom classes \ + (see :func:`register_integrate_container`) + + :param v: Original value, nested structure is supported. + :param func: Function to operate. + + Examples:: + >>> from collections import namedtuple + >>> from easydict import EasyDict + >>> from treevalue import FastTreeValue, generic_mapping + >>> + >>> class MyTreeValue(FastTreeValue): + ... pass + >>> + >>> nt = namedtuple('nt', ['a', 'b']) + >>> + >>> origin = { + ... 'a': 1, + ... 'b': (2, 3, 'f',), + ... 'c': (2, 5, 'ds', EasyDict({ # dict's child class + ... 'x': None, + ... 'z': [34, '1.2'], # dataclass + ... })), + ... 'd': nt('f', 100), # namedtuple + ... 'e': MyTreeValue({'x': 1, 'y': 'dsfljk'}) # treevalue + ... } + >>> generic_mapping(origin, str) + {'a': '1', 'b': ('2', '3', 'f'), 'c': ('2', '5', 'ds', {'x': 'None', 'z': ['34', '1.2']}), 'd': nt(a='f', b='100'), 'e': + ├── 'x' --> '1' + └── 'y' --> 'dsfljk' + } + """ values, type_, spec = _c_get_flatted_values_and_spec(v) if type_ is None: return func(values)