From 874bd09c1193792eec4037f93e8807d3e1d2d424 Mon Sep 17 00:00:00 2001 From: HansBug Date: Mon, 27 Feb 2023 15:17:29 +0800 Subject: [PATCH 1/2] dev(hansug): add support for torch integration --- docs/source/api_doc/tree/integration.rst | 18 ++++++- test/tree/integration/test_init.py | 34 +++++++++++++ test/tree/integration/test_jax.py | 38 ++++++++++++++- test/tree/integration/test_torch.py | 61 ++++++++++++++++++++++++ treevalue/tree/integration/__init__.py | 19 ++++++++ treevalue/tree/integration/base.pxd | 5 ++ treevalue/tree/integration/base.pyx | 23 +++++++++ treevalue/tree/integration/cjax.pyx | 18 ++----- treevalue/tree/integration/ctorch.pxd | 6 +++ treevalue/tree/integration/ctorch.pyx | 35 ++++++++++++++ treevalue/tree/integration/torch.py | 19 ++++++++ 11 files changed, 258 insertions(+), 18 deletions(-) create mode 100644 test/tree/integration/test_init.py create mode 100644 test/tree/integration/test_torch.py create mode 100644 treevalue/tree/integration/base.pxd create mode 100644 treevalue/tree/integration/base.pyx create mode 100644 treevalue/tree/integration/ctorch.pxd create mode 100644 treevalue/tree/integration/ctorch.pyx create mode 100644 treevalue/tree/integration/torch.py diff --git a/docs/source/api_doc/tree/integration.rst b/docs/source/api_doc/tree/integration.rst index 093c27ee5a..1e8bebf6af 100644 --- a/docs/source/api_doc/tree/integration.rst +++ b/docs/source/api_doc/tree/integration.rst @@ -7,7 +7,23 @@ treevalue.tree.integration .. _apidoc_tree_integration_register_for_jax: register_for_jax ------------------------- +--------------------------- .. autofunction:: register_for_jax + +.. _apidoc_tree_integration_register_for_torch: + +register_for_torch +--------------------------- + +.. autofunction:: register_for_torch + + +.. _apidoc_tree_integration_register_treevalue_class: + +register_treevalue_class +--------------------------- + +.. autofunction:: register_treevalue_class + diff --git a/test/tree/integration/test_init.py b/test/tree/integration/test_init.py new file mode 100644 index 0000000000..c2359b58a5 --- /dev/null +++ b/test/tree/integration/test_init.py @@ -0,0 +1,34 @@ +from unittest import skipUnless + +import pytest + +from treevalue import register_treevalue_class, FastTreeValue + +try: + import torch +except (ImportError, ModuleNotFoundError): + torch = None + +try: + import jax +except (ModuleNotFoundError, ImportError): + jax = None + + +@pytest.mark.unittest +class TestTreeIntegrationInit: + @skipUnless(torch and jax, 'Torch and jax required.') + def test_register_custom_class_all(self): + class MyTreeValue(FastTreeValue): + pass + + with pytest.warns(None): + register_treevalue_class(MyTreeValue) + + @skipUnless(not torch or not jax, 'Not all torch and jax required.') + def test_register_custom_class_some(self): + class MyTreeValue(FastTreeValue): + pass + + with pytest.warns(UserWarning): + register_treevalue_class(MyTreeValue) diff --git a/test/tree/integration/test_jax.py b/test/tree/integration/test_jax.py index 97e6e6745f..53d2d13a8f 100644 --- a/test/tree/integration/test_jax.py +++ b/test/tree/integration/test_jax.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from treevalue import FastTreeValue +from treevalue import FastTreeValue, register_for_jax try: import jax @@ -26,5 +26,39 @@ def double(x): 'y': np.random.randn(2, 3) } }) - assert FastTreeValue.func()(np.isclose)(double(t1), t1 * 2 + 1.5).all() == \ + r1 = double(t1) + assert type(r1) is FastTreeValue + assert FastTreeValue.func()(np.isclose)(r1, t1 * 2 + 1.5).all() == \ FastTreeValue({'a': True, 'b': {'x': True, 'y': True}}) + + class MyTreeValue(FastTreeValue): + pass + + register_for_jax(MyTreeValue) + + t2 = MyTreeValue({ + 'a': np.random.randint(0, 10, (2, 3)), + 'b': { + 'x': np.asarray(233.0), + 'y': np.random.randn(2, 3) + } + }) + r2 = double(t2) + assert type(r2) is MyTreeValue + assert MyTreeValue.func()(np.isclose)(r2, t2 * 2 + 1.5).all() == \ + MyTreeValue({'a': True, 'b': {'x': True, 'y': True}}) + + @skipUnless(jax, 'Jax required.') + def test_error_register(self): + with pytest.raises(TypeError): + register_for_jax(None) + with pytest.raises(TypeError): + register_for_jax(list) + + @skipUnless(not jax, 'No jax required') + def test_ignored_register(self): + class MyTreeValueX(FastTreeValue): + pass + + with pytest.warns(UserWarning): + register_for_jax(MyTreeValueX) diff --git a/test/tree/integration/test_torch.py b/test/tree/integration/test_torch.py new file mode 100644 index 0000000000..2537f83ea4 --- /dev/null +++ b/test/tree/integration/test_torch.py @@ -0,0 +1,61 @@ +from unittest import skipUnless + +import pytest + +from treevalue import FastTreeValue, register_for_torch + +try: + import torch +except (ImportError, ModuleNotFoundError): + torch = None + + +@pytest.mark.unittest +class TestTreeIntegrationTorch: + @skipUnless(torch, 'Torch required.') + def test_flatten_and_unflatten(self): + arr1 = torch.randint(0, 10, (2, 3)) + arr2 = torch.randn(2, 3) + t1 = FastTreeValue({'a': arr1, 'b': {'x': torch.asarray(233.0), 'y': arr2}}) + + flatted, spec = torch.utils._pytree.tree_flatten(t1) + assert isinstance(flatted, list) + assert len(flatted) == 3 + assert torch.isclose(flatted[0], arr1).all() + assert torch.isclose(flatted[1], torch.asarray(233.0)).all() + assert torch.isclose(flatted[2], arr2).all() + + newt = torch.utils._pytree.tree_unflatten(flatted, spec) + assert type(newt) == FastTreeValue + assert FastTreeValue.func()(torch.isclose)(t1, newt).all() + + class MyTreeValue(FastTreeValue): + pass + + register_for_torch(MyTreeValue) + t2 = MyTreeValue({'a': arr1, 'b': {'x': torch.asarray(233.0), 'y': arr2}}) + flatted, spec = torch.utils._pytree.tree_flatten(t2) + assert isinstance(flatted, list) + assert len(flatted) == 3 + assert torch.isclose(flatted[0], arr1).all() + assert torch.isclose(flatted[1], torch.asarray(233.0)).all() + assert torch.isclose(flatted[2], arr2).all() + + newt2 = torch.utils._pytree.tree_unflatten(flatted, spec) + assert type(newt2) == MyTreeValue + assert MyTreeValue.func()(torch.isclose)(t2, newt2).all() + + @skipUnless(torch, 'Torch required.') + def test_error_register(self): + with pytest.raises(TypeError): + register_for_torch(None) + with pytest.raises(TypeError): + register_for_torch(list) + + @skipUnless(not torch, 'No torch required') + def test_ignored_register(self): + class MyTreeValueX(FastTreeValue): + pass + + with pytest.warns(UserWarning): + register_for_torch(MyTreeValueX) diff --git a/treevalue/tree/integration/__init__.py b/treevalue/tree/integration/__init__.py index 81a5f62d25..11231ab66b 100644 --- a/treevalue/tree/integration/__init__.py +++ b/treevalue/tree/integration/__init__.py @@ -1 +1,20 @@ +from typing import Type + from .jax import register_for_jax +from .torch import register_for_torch +from ..tree import TreeValue + + +def register_treevalue_class(cls: Type[TreeValue], r_jax: bool = True, r_torch: bool = True): + """ + Overview: + Register treevalue class into all existing types. + + :param cls: TreeValue class. + :param r_jax: Register for jax, default is `True`. + :param r_torch: Register for torch, default is `True`. + """ + if r_jax: + register_for_torch(cls) + if r_torch: + register_for_jax(cls) diff --git a/treevalue/tree/integration/base.pxd b/treevalue/tree/integration/base.pxd new file mode 100644 index 0000000000..bb119a7d9c --- /dev/null +++ b/treevalue/tree/integration/base.pxd @@ -0,0 +1,5 @@ +# distutils:language=c++ +# cython:language_level=3 + +cdef tuple _c_flatten_for_integration(object tv) +cdef object _c_unflatten_for_integration(object values, tuple spec) diff --git a/treevalue/tree/integration/base.pyx b/treevalue/tree/integration/base.pyx new file mode 100644 index 0000000000..43bb761eb6 --- /dev/null +++ b/treevalue/tree/integration/base.pyx @@ -0,0 +1,23 @@ +# distutils:language=c++ +# cython:language_level=3 + +from ..tree.flatten cimport _c_flatten, _c_unflatten + +cdef inline tuple _c_flatten_for_integration(object tv): + cdef list result = [] + _c_flatten(tv._detach(), (), result) + + cdef list paths = [] + cdef list values = [] + for path, value in result: + paths.append(path) + values.append(value) + + return values, (type(tv), paths) + pass + +cdef inline object _c_unflatten_for_integration(object values, tuple spec): + cdef object type_ + cdef list paths + type_, paths = spec + return type_(_c_unflatten(zip(paths, values))) diff --git a/treevalue/tree/integration/cjax.pyx b/treevalue/tree/integration/cjax.pyx index a170dbd97c..0a289f21bd 100644 --- a/treevalue/tree/integration/cjax.pyx +++ b/treevalue/tree/integration/cjax.pyx @@ -3,26 +3,14 @@ import cython -from ..tree.flatten cimport _c_flatten, _c_unflatten +from .base cimport _c_flatten_for_integration, _c_unflatten_for_integration from ..tree.tree cimport TreeValue cdef inline tuple _c_flatten_for_jax(object tv): - cdef list result = [] - _c_flatten(tv._detach(), (), result) - - cdef list paths = [] - cdef list values = [] - for path, value in result: - paths.append(path) - values.append(value) - - return values, (type(tv), paths) + return _c_flatten_for_integration(tv) cdef inline object _c_unflatten_for_jax(tuple aux, tuple values): - cdef object type_ - cdef list paths - type_, paths = aux - return type_(_c_unflatten(zip(paths, values))) + return _c_unflatten_for_integration(values, aux) @cython.binding(True) cpdef void register_for_jax(object cls) except*: diff --git a/treevalue/tree/integration/ctorch.pxd b/treevalue/tree/integration/ctorch.pxd new file mode 100644 index 0000000000..016bcac644 --- /dev/null +++ b/treevalue/tree/integration/ctorch.pxd @@ -0,0 +1,6 @@ +# distutils:language=c++ +# cython:language_level=3 + +cdef tuple _c_flatten_for_torch(object tv) +cdef object _c_unflatten_for_torch(list values, tuple context) +cpdef void register_for_torch(object cls) except* \ No newline at end of file diff --git a/treevalue/tree/integration/ctorch.pyx b/treevalue/tree/integration/ctorch.pyx new file mode 100644 index 0000000000..1d86924f52 --- /dev/null +++ b/treevalue/tree/integration/ctorch.pyx @@ -0,0 +1,35 @@ +# distutils:language=c++ +# cython:language_level=3 + +import cython + +from .base cimport _c_flatten_for_integration, _c_unflatten_for_integration +from ..tree.tree cimport TreeValue + +cdef inline tuple _c_flatten_for_torch(object tv): + return _c_flatten_for_integration(tv) + +cdef inline object _c_unflatten_for_torch(list values, tuple context): + return _c_unflatten_for_integration(values, context) + +@cython.binding(True) +cpdef void register_for_torch(object cls) except*: + """ + Overview: + Register treevalue class for torch's pytree library. + + :param cls: TreeValue class. + + Examples:: + >>> from treevalue import FastTreeValue, TreeValue, register_for_torch + >>> register_for_torch(TreeValue) + >>> register_for_torch(FastTreeValue) + + .. warning:: + This method will put a warning message and then do nothing when torch is not installed. + """ + if isinstance(cls, type) and issubclass(cls, TreeValue): + import torch + torch.utils._pytree._register_pytree_node(cls, _c_flatten_for_torch, _c_unflatten_for_torch) + else: + raise TypeError(f'Registered class should be a subclass of TreeValue, but {cls!r} found.') diff --git a/treevalue/tree/integration/torch.py b/treevalue/tree/integration/torch.py new file mode 100644 index 0000000000..cfbb764c17 --- /dev/null +++ b/treevalue/tree/integration/torch.py @@ -0,0 +1,19 @@ +import warnings +from functools import wraps + +try: + import torch +except (ModuleNotFoundError, ImportError): + from .ctorch import register_for_torch as _original_register_for_torch + + + @wraps(_original_register_for_torch) + def register_for_torch(cls): + warnings.warn(f'Torch is not installed, registration of {cls!r} will be ignored.') +else: + from .ctorch import register_for_torch + from ..tree import TreeValue + from ..general import FastTreeValue + + register_for_torch(TreeValue) + register_for_torch(FastTreeValue) From b2b01bf338e9233f86661a9c2600a5bb0e4a8758 Mon Sep 17 00:00:00 2001 From: HansBug Date: Mon, 27 Feb 2023 18:10:39 +0800 Subject: [PATCH 2/2] dev(hansbug): fix this bug --- treevalue/tree/integration/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/treevalue/tree/integration/__init__.py b/treevalue/tree/integration/__init__.py index 11231ab66b..7bf1079447 100644 --- a/treevalue/tree/integration/__init__.py +++ b/treevalue/tree/integration/__init__.py @@ -15,6 +15,6 @@ def register_treevalue_class(cls: Type[TreeValue], r_jax: bool = True, r_torch: :param r_torch: Register for torch, default is `True`. """ if r_jax: - register_for_torch(cls) - if r_torch: register_for_jax(cls) + if r_torch: + register_for_torch(cls)