Skip to content

Commit

Permalink
Merge pull request #79 from opendilab/dev/torch
Browse files Browse the repository at this point in the history
dev(hansug): add support for torch integration
  • Loading branch information
HansBug authored Feb 27, 2023
2 parents ac90658 + adc645e commit 704a830
Show file tree
Hide file tree
Showing 11 changed files with 258 additions and 18 deletions.
18 changes: 17 additions & 1 deletion docs/source/api_doc/tree/integration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,23 @@ treevalue.tree.integration
.. _apidoc_tree_integration_register_for_jax:

register_for_jax
------------------------
---------------------------

.. autofunction:: register_for_jax


.. _apidoc_tree_integration_register_for_torch:

register_for_torch
---------------------------

.. autofunction:: register_for_torch


.. _apidoc_tree_integration_register_treevalue_class:

register_treevalue_class
---------------------------

.. autofunction:: register_treevalue_class

34 changes: 34 additions & 0 deletions test/tree/integration/test_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from unittest import skipUnless

import pytest

from treevalue import register_treevalue_class, FastTreeValue

try:
import torch
except (ImportError, ModuleNotFoundError):
torch = None

try:
import jax
except (ModuleNotFoundError, ImportError):
jax = None


@pytest.mark.unittest
class TestTreeIntegrationInit:
    @skipUnless(torch and jax, 'Torch and jax required.')
    def test_register_custom_class_all(self):
        """With both torch and jax installed, registration must emit no warning."""
        import warnings

        class MyTreeValue(FastTreeValue):
            pass

        # ``pytest.warns(None)`` is deprecated and raises in pytest >= 7.0;
        # assert "no warnings" explicitly by escalating any warning to an error.
        with warnings.catch_warnings():
            warnings.simplefilter('error')
            register_treevalue_class(MyTreeValue)

    @skipUnless(not torch or not jax, 'Not all torch and jax required.')
    def test_register_custom_class_some(self):
        """With at least one of torch/jax missing, a UserWarning must be emitted."""
        class MyTreeValue(FastTreeValue):
            pass

        with pytest.warns(UserWarning):
            register_treevalue_class(MyTreeValue)
38 changes: 36 additions & 2 deletions test/tree/integration/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import pytest

from treevalue import FastTreeValue
from treevalue import FastTreeValue, register_for_jax

try:
import jax
Expand All @@ -26,5 +26,39 @@ def double(x):
'y': np.random.randn(2, 3)
}
})
assert FastTreeValue.func()(np.isclose)(double(t1), t1 * 2 + 1.5).all() == \
r1 = double(t1)
assert type(r1) is FastTreeValue
assert FastTreeValue.func()(np.isclose)(r1, t1 * 2 + 1.5).all() == \
FastTreeValue({'a': True, 'b': {'x': True, 'y': True}})

class MyTreeValue(FastTreeValue):
pass

register_for_jax(MyTreeValue)

t2 = MyTreeValue({
'a': np.random.randint(0, 10, (2, 3)),
'b': {
'x': np.asarray(233.0),
'y': np.random.randn(2, 3)
}
})
r2 = double(t2)
assert type(r2) is MyTreeValue
assert MyTreeValue.func()(np.isclose)(r2, t2 * 2 + 1.5).all() == \
MyTreeValue({'a': True, 'b': {'x': True, 'y': True}})

@skipUnless(jax, 'Jax required.')
def test_error_register(self):
    """Non-TreeValue registrations must be rejected with ``TypeError``."""
    for invalid in (None, list):
        with pytest.raises(TypeError):
            register_for_jax(invalid)

@skipUnless(not jax, 'No jax required')
def test_ignored_register(self):
    """Without jax installed, registration degrades to a ``UserWarning`` no-op."""
    # Build a throwaway subclass dynamically instead of a class statement.
    dummy_cls = type('MyTreeValueX', (FastTreeValue,), {})
    with pytest.warns(UserWarning):
        register_for_jax(dummy_cls)
61 changes: 61 additions & 0 deletions test/tree/integration/test_torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from unittest import skipUnless

import pytest

from treevalue import FastTreeValue, register_for_torch

try:
import torch
except (ImportError, ModuleNotFoundError):
torch = None


@pytest.mark.unittest
class TestTreeIntegrationTorch:
    def _check_flatten_roundtrip(self, cls, arr1, arr2):
        """Flatten a tree of class ``cls`` with torch's pytree, check its leaves,
        then unflatten and verify an exact-class, value-equal round trip."""
        # NOTE(review): torch.utils._pytree is a private torch API — verify on upgrades.
        pytree = torch.utils._pytree
        tree = cls({'a': arr1, 'b': {'x': torch.asarray(233.0), 'y': arr2}})

        flatted, spec = pytree.tree_flatten(tree)
        assert isinstance(flatted, list)
        assert len(flatted) == 3
        # Leaves come out in key order: a, b.x, b.y.
        assert torch.isclose(flatted[0], arr1).all()
        assert torch.isclose(flatted[1], torch.asarray(233.0)).all()
        assert torch.isclose(flatted[2], arr2).all()

        newt = pytree.tree_unflatten(flatted, spec)
        assert type(newt) is cls  # exact class, not merely an instance
        assert cls.func()(torch.isclose)(tree, newt).all()

    @skipUnless(torch, 'Torch required.')
    def test_flatten_and_unflatten(self):
        arr1 = torch.randint(0, 10, (2, 3))
        arr2 = torch.randn(2, 3)

        # FastTreeValue is registered at treevalue import time.
        self._check_flatten_roundtrip(FastTreeValue, arr1, arr2)

        class MyTreeValue(FastTreeValue):
            pass

        # A custom subclass must behave identically once registered explicitly.
        register_for_torch(MyTreeValue)
        self._check_flatten_roundtrip(MyTreeValue, arr1, arr2)

    @skipUnless(torch, 'Torch required.')
    def test_error_register(self):
        """Only TreeValue subclasses are accepted for registration."""
        with pytest.raises(TypeError):
            register_for_torch(None)
        with pytest.raises(TypeError):
            register_for_torch(list)

    @skipUnless(not torch, 'No torch required')
    def test_ignored_register(self):
        """Without torch installed, registration only warns instead of failing."""
        class MyTreeValueX(FastTreeValue):
            pass

        with pytest.warns(UserWarning):
            register_for_torch(MyTreeValueX)
19 changes: 19 additions & 0 deletions treevalue/tree/integration/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,20 @@
from typing import Type

from .jax import register_for_jax
from .torch import register_for_torch
from ..tree import TreeValue


def register_treevalue_class(cls: Type[TreeValue], r_jax: bool = True, r_torch: bool = True):
    """
    Overview:
        Register treevalue class into all existing types.

    :param cls: TreeValue class.
    :param r_jax: Register for jax, default is `True`.
    :param r_torch: Register for torch, default is `True`.
    """
    # Dispatch to each enabled backend-specific registration hook, jax first.
    backends = ((r_jax, register_for_jax), (r_torch, register_for_torch))
    for enabled, register in backends:
        if enabled:
            register(cls)
5 changes: 5 additions & 0 deletions treevalue/tree/integration/base.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# distutils:language=c++
# cython:language_level=3

# Shared flatten/unflatten helpers used by the jax and torch pytree bridges.
# _c_flatten_for_integration: tree -> (leaf values, (tree class, leaf paths)).
cdef tuple _c_flatten_for_integration(object tv)
# _c_unflatten_for_integration: inverse — rebuild the tree from values + spec.
cdef object _c_unflatten_for_integration(object values, tuple spec)
23 changes: 23 additions & 0 deletions treevalue/tree/integration/base.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# distutils:language=c++
# cython:language_level=3

from ..tree.flatten cimport _c_flatten, _c_unflatten

cdef inline tuple _c_flatten_for_integration(object tv):
    # Flatten a TreeValue into parallel (paths, values) lists; the returned
    # spec ``(type(tv), paths)`` is everything needed to rebuild an equal tree.
    cdef list result = []
    _c_flatten(tv._detach(), (), result)

    cdef list paths = []
    cdef list values = []
    for path, value in result:
        paths.append(path)
        values.append(value)

    # Removed a stray unreachable ``pass`` that followed this return.
    return values, (type(tv), paths)

cdef inline object _c_unflatten_for_integration(object values, tuple spec):
    # Rebuild the tree: ``spec`` carries (tree class, leaf paths) exactly as
    # produced by _c_flatten_for_integration; zip re-pairs each path with its value.
    cdef object type_
    cdef list paths
    type_, paths = spec
    return type_(_c_unflatten(zip(paths, values)))
18 changes: 3 additions & 15 deletions treevalue/tree/integration/cjax.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,14 @@

import cython

from ..tree.flatten cimport _c_flatten, _c_unflatten
from .base cimport _c_flatten_for_integration, _c_unflatten_for_integration
from ..tree.tree cimport TreeValue

cdef inline tuple _c_flatten_for_jax(object tv):
cdef list result = []
_c_flatten(tv._detach(), (), result)

cdef list paths = []
cdef list values = []
for path, value in result:
paths.append(path)
values.append(value)

return values, (type(tv), paths)
return _c_flatten_for_integration(tv)

cdef inline object _c_unflatten_for_jax(tuple aux, tuple values):
cdef object type_
cdef list paths
type_, paths = aux
return type_(_c_unflatten(zip(paths, values)))
return _c_unflatten_for_integration(values, aux)

@cython.binding(True)
cpdef void register_for_jax(object cls) except*:
Expand Down
6 changes: 6 additions & 0 deletions treevalue/tree/integration/ctorch.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# distutils:language=c++
# cython:language_level=3

# Torch pytree bridge: flatten/unflatten adapters plus the public registration hook.
cdef tuple _c_flatten_for_torch(object tv)
cdef object _c_unflatten_for_torch(list values, tuple context)
cpdef void register_for_torch(object cls) except*
35 changes: 35 additions & 0 deletions treevalue/tree/integration/ctorch.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# distutils:language=c++
# cython:language_level=3

import cython

from .base cimport _c_flatten_for_integration, _c_unflatten_for_integration
from ..tree.tree cimport TreeValue

cdef inline tuple _c_flatten_for_torch(object tv):
    # Torch-facing adapter: delegates to the shared helper, yielding
    # (leaf values, (tree class, leaf paths)).
    return _c_flatten_for_integration(tv)

cdef inline object _c_unflatten_for_torch(list values, tuple context):
    # Inverse of _c_flatten_for_torch: rebuild the tree from leaves + context.
    return _c_unflatten_for_integration(values, context)

@cython.binding(True)
cpdef void register_for_torch(object cls) except*:
    """
    Overview:
        Register treevalue class for torch's pytree library.

    :param cls: TreeValue class.

    Examples::
        >>> from treevalue import FastTreeValue, TreeValue, register_for_torch
        >>> register_for_torch(TreeValue)
        >>> register_for_torch(FastTreeValue)

    .. warning::
        This method will put a warning message and then do nothing when torch is not installed.
    """
    # NOTE(review): the warn-and-skip behavior described above is implemented by the
    # fallback wrapper in torch.py; this cython version assumes torch is importable.
    if isinstance(cls, type) and issubclass(cls, TreeValue):
        # Imported lazily so this module stays importable without torch installed.
        import torch
        # NOTE(review): _register_pytree_node is a private torch API — verify on torch upgrades.
        torch.utils._pytree._register_pytree_node(cls, _c_flatten_for_torch, _c_unflatten_for_torch)
    else:
        raise TypeError(f'Registered class should be a subclass of TreeValue, but {cls!r} found.')
19 changes: 19 additions & 0 deletions treevalue/tree/integration/torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import warnings
from functools import wraps

try:
    import torch
except ImportError:  # ModuleNotFoundError is a subclass of ImportError
    # Torch unavailable: expose a stub under the same public name so imports
    # keep working; it only emits a warning instead of registering anything.
    from .ctorch import register_for_torch as _original_register_for_torch


    @wraps(_original_register_for_torch)
    def register_for_torch(cls):
        warnings.warn(f'Torch is not installed, registration of {cls!r} will be ignored.')
else:
    # Torch available: use the real cython implementation and register the
    # built-in tree classes out of the box.
    from .ctorch import register_for_torch
    from ..tree import TreeValue
    from ..general import FastTreeValue

    register_for_torch(TreeValue)
    register_for_torch(FastTreeValue)

0 comments on commit 704a830

Please sign in to comment.