-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #79 from opendilab/dev/torch
dev(hansug): add support for torch integration
- Loading branch information
Showing
11 changed files
with
258 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from unittest import skipUnless | ||
|
||
import pytest | ||
|
||
from treevalue import register_treevalue_class, FastTreeValue | ||
|
||
try: | ||
import torch | ||
except (ImportError, ModuleNotFoundError): | ||
torch = None | ||
|
||
try: | ||
import jax | ||
except (ModuleNotFoundError, ImportError): | ||
jax = None | ||
|
||
|
||
@pytest.mark.unittest | ||
class TestTreeIntegrationInit: | ||
@skipUnless(torch and jax, 'Torch and jax required.') | ||
def test_register_custom_class_all(self): | ||
class MyTreeValue(FastTreeValue): | ||
pass | ||
|
||
with pytest.warns(None): | ||
register_treevalue_class(MyTreeValue) | ||
|
||
@skipUnless(not torch or not jax, 'Not all torch and jax required.') | ||
def test_register_custom_class_some(self): | ||
class MyTreeValue(FastTreeValue): | ||
pass | ||
|
||
with pytest.warns(UserWarning): | ||
register_treevalue_class(MyTreeValue) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from unittest import skipUnless | ||
|
||
import pytest | ||
|
||
from treevalue import FastTreeValue, register_for_torch | ||
|
||
try: | ||
import torch | ||
except (ImportError, ModuleNotFoundError): | ||
torch = None | ||
|
||
|
||
@pytest.mark.unittest | ||
class TestTreeIntegrationTorch: | ||
@skipUnless(torch, 'Torch required.') | ||
def test_flatten_and_unflatten(self): | ||
arr1 = torch.randint(0, 10, (2, 3)) | ||
arr2 = torch.randn(2, 3) | ||
t1 = FastTreeValue({'a': arr1, 'b': {'x': torch.asarray(233.0), 'y': arr2}}) | ||
|
||
flatted, spec = torch.utils._pytree.tree_flatten(t1) | ||
assert isinstance(flatted, list) | ||
assert len(flatted) == 3 | ||
assert torch.isclose(flatted[0], arr1).all() | ||
assert torch.isclose(flatted[1], torch.asarray(233.0)).all() | ||
assert torch.isclose(flatted[2], arr2).all() | ||
|
||
newt = torch.utils._pytree.tree_unflatten(flatted, spec) | ||
assert type(newt) == FastTreeValue | ||
assert FastTreeValue.func()(torch.isclose)(t1, newt).all() | ||
|
||
class MyTreeValue(FastTreeValue): | ||
pass | ||
|
||
register_for_torch(MyTreeValue) | ||
t2 = MyTreeValue({'a': arr1, 'b': {'x': torch.asarray(233.0), 'y': arr2}}) | ||
flatted, spec = torch.utils._pytree.tree_flatten(t2) | ||
assert isinstance(flatted, list) | ||
assert len(flatted) == 3 | ||
assert torch.isclose(flatted[0], arr1).all() | ||
assert torch.isclose(flatted[1], torch.asarray(233.0)).all() | ||
assert torch.isclose(flatted[2], arr2).all() | ||
|
||
newt2 = torch.utils._pytree.tree_unflatten(flatted, spec) | ||
assert type(newt2) == MyTreeValue | ||
assert MyTreeValue.func()(torch.isclose)(t2, newt2).all() | ||
|
||
@skipUnless(torch, 'Torch required.') | ||
def test_error_register(self): | ||
with pytest.raises(TypeError): | ||
register_for_torch(None) | ||
with pytest.raises(TypeError): | ||
register_for_torch(list) | ||
|
||
@skipUnless(not torch, 'No torch required') | ||
def test_ignored_register(self): | ||
class MyTreeValueX(FastTreeValue): | ||
pass | ||
|
||
with pytest.warns(UserWarning): | ||
register_for_torch(MyTreeValueX) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,20 @@ | ||
from typing import Type | ||
|
||
from .jax import register_for_jax | ||
from .torch import register_for_torch | ||
from ..tree import TreeValue | ||
|
||
|
||
def register_treevalue_class(cls: Type[TreeValue], r_jax: bool = True, r_torch: bool = True): | ||
""" | ||
Overview: | ||
Register treevalue class into all existing types. | ||
:param cls: TreeValue class. | ||
:param r_jax: Register for jax, default is `True`. | ||
:param r_torch: Register for torch, default is `True`. | ||
""" | ||
if r_jax: | ||
register_for_jax(cls) | ||
if r_torch: | ||
register_for_torch(cls) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# distutils:language=c++ | ||
# cython:language_level=3 | ||
|
||
cdef tuple _c_flatten_for_integration(object tv) | ||
cdef object _c_unflatten_for_integration(object values, tuple spec) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# distutils:language=c++ | ||
# cython:language_level=3 | ||
|
||
from ..tree.flatten cimport _c_flatten, _c_unflatten | ||
|
||
cdef inline tuple _c_flatten_for_integration(object tv): | ||
cdef list result = [] | ||
_c_flatten(tv._detach(), (), result) | ||
|
||
cdef list paths = [] | ||
cdef list values = [] | ||
for path, value in result: | ||
paths.append(path) | ||
values.append(value) | ||
|
||
return values, (type(tv), paths) | ||
pass | ||
|
||
cdef inline object _c_unflatten_for_integration(object values, tuple spec): | ||
cdef object type_ | ||
cdef list paths | ||
type_, paths = spec | ||
return type_(_c_unflatten(zip(paths, values))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# distutils:language=c++ | ||
# cython:language_level=3 | ||
|
||
cdef tuple _c_flatten_for_torch(object tv) | ||
cdef object _c_unflatten_for_torch(list values, tuple context) | ||
cpdef void register_for_torch(object cls) except* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
# distutils:language=c++ | ||
# cython:language_level=3 | ||
|
||
import cython | ||
|
||
from .base cimport _c_flatten_for_integration, _c_unflatten_for_integration | ||
from ..tree.tree cimport TreeValue | ||
|
||
cdef inline tuple _c_flatten_for_torch(object tv): | ||
return _c_flatten_for_integration(tv) | ||
|
||
cdef inline object _c_unflatten_for_torch(list values, tuple context): | ||
return _c_unflatten_for_integration(values, context) | ||
|
||
@cython.binding(True) | ||
cpdef void register_for_torch(object cls) except*: | ||
""" | ||
Overview: | ||
Register treevalue class for torch's pytree library. | ||
:param cls: TreeValue class. | ||
Examples:: | ||
>>> from treevalue import FastTreeValue, TreeValue, register_for_torch | ||
>>> register_for_torch(TreeValue) | ||
>>> register_for_torch(FastTreeValue) | ||
.. warning:: | ||
This method will put a warning message and then do nothing when torch is not installed. | ||
""" | ||
if isinstance(cls, type) and issubclass(cls, TreeValue): | ||
import torch | ||
torch.utils._pytree._register_pytree_node(cls, _c_flatten_for_torch, _c_unflatten_for_torch) | ||
else: | ||
raise TypeError(f'Registered class should be a subclass of TreeValue, but {cls!r} found.') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import warnings | ||
from functools import wraps | ||
|
||
try: | ||
import torch | ||
except (ModuleNotFoundError, ImportError): | ||
from .ctorch import register_for_torch as _original_register_for_torch | ||
|
||
|
||
@wraps(_original_register_for_torch) | ||
def register_for_torch(cls): | ||
warnings.warn(f'Torch is not installed, registration of {cls!r} will be ignored.') | ||
else: | ||
from .ctorch import register_for_torch | ||
from ..tree import TreeValue | ||
from ..general import FastTreeValue | ||
|
||
register_for_torch(TreeValue) | ||
register_for_torch(FastTreeValue) |