Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,6 @@ jobs:
pip install -r requirements.txt
pip install -r requirements-build.txt
pip install -r requirements-test.txt
- name: Install extra PyPI dependencies
continue-on-error: true
shell: bash
run: |
pip install -r requirements-test-extra.txt
- name: Test the basic environment
shell: bash
Expand Down
1 change: 1 addition & 0 deletions docs/source/api_doc/tree/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ treevalue.tree
tree
func
general
integration
13 changes: 13 additions & 0 deletions docs/source/api_doc/tree/integration.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
treevalue.tree.integration
=======================================

.. py:currentmodule:: treevalue.tree.integration


.. _apidoc_tree_integration_register_for_jax:

register_for_jax
------------------------

.. autofunction:: register_for_jax

3 changes: 2 additions & 1 deletion requirements-test-extra.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
torch>=1.1.0
jax[cpu]>=0.3.25; platform_system != 'Windows'
torch>=1.1.0; python_version < '3.11'
Empty file.
30 changes: 30 additions & 0 deletions test/tree/integration/test_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from unittest import skipUnless

import numpy as np
import pytest

from treevalue import FastTreeValue

try:
import jax
except (ModuleNotFoundError, ImportError):
jax = None


@pytest.mark.unittest
class TestTreeTreeIntegration:
@skipUnless(jax, 'Jax required.')
def test_jax_double(self):
@jax.jit
def double(x):
return x * 2 + 1.5

t1 = FastTreeValue({
'a': np.random.randint(0, 10, (2, 3)),
'b': {
'x': np.asarray(233.0),
'y': np.random.randn(2, 3)
}
})
assert FastTreeValue.func()(np.isclose)(double(t1), t1 * 2 + 1.5).all() == \
FastTreeValue({'a': True, 'b': {'x': True, 'y': True}})
1 change: 1 addition & 0 deletions treevalue/tree/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .common import raw
from .func import *
from .general import *
from .integration import *
from .tree import *
1 change: 1 addition & 0 deletions treevalue/tree/integration/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .jax import register_for_jax
6 changes: 6 additions & 0 deletions treevalue/tree/integration/cjax.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# distutils:language=c++
# cython:language_level=3

cdef tuple _c_flatten_for_jax(object tv)
cdef object _c_unflatten_for_jax(tuple aux, tuple values)
cpdef void register_for_jax(object cls) except*
47 changes: 47 additions & 0 deletions treevalue/tree/integration/cjax.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# distutils:language=c++
# cython:language_level=3

import cython

from ..tree.flatten cimport _c_flatten, _c_unflatten
from ..tree.tree cimport TreeValue

cdef inline tuple _c_flatten_for_jax(object tv):
cdef list result = []
_c_flatten(tv._detach(), (), result)

cdef list paths = []
cdef list values = []
for path, value in result:
paths.append(path)
values.append(value)

return values, (type(tv), paths)

cdef inline object _c_unflatten_for_jax(tuple aux, tuple values):
cdef object type_
cdef list paths
type_, paths = aux
return type_(_c_unflatten(zip(paths, values)))

@cython.binding(True)
cpdef void register_for_jax(object cls) except*:
"""
Overview:
Register treevalue class for jax.

:param cls: TreeValue class.

Examples::
>>> from treevalue import FastTreeValue, TreeValue, register_for_jax
>>> register_for_jax(TreeValue)
>>> register_for_jax(FastTreeValue)

.. warning::
This method will put a warning message and then do nothing when jax is not installed.
"""
if isinstance(cls, type) and issubclass(cls, TreeValue):
import jax
jax.tree_util.register_pytree_node(cls, _c_flatten_for_jax, _c_unflatten_for_jax)
else:
raise TypeError(f'Registered class should be a subclass of TreeValue, but {cls!r} found.')
19 changes: 19 additions & 0 deletions treevalue/tree/integration/jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import warnings
from functools import wraps

try:
import jax
except (ModuleNotFoundError, ImportError):
from .cjax import register_for_jax as _original_register_for_jax


@wraps(_original_register_for_jax)
def register_for_jax(cls):
warnings.warn(f'Jax is not installed, registration of {cls!r} will be ignored.')
else:
from .cjax import register_for_jax
from ..tree import TreeValue
from ..general import FastTreeValue

register_for_jax(TreeValue)
register_for_jax(FastTreeValue)
8 changes: 4 additions & 4 deletions treevalue/tree/tree/flatten.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import cython
from .tree cimport TreeValue
from ..common.storage cimport TreeStorage, _c_undelay_data

cdef void _c_flatten(TreeStorage st, tuple path, list res) except *:
cdef inline void _c_flatten(TreeStorage st, tuple path, list res) except *:
cdef dict data = st.detach()
cdef tuple curpath

Expand Down Expand Up @@ -44,7 +44,7 @@ cpdef list flatten(TreeValue tree):
_c_flatten(tree._detach(), (), result)
return result

cdef void _c_flatten_values(TreeStorage st, list res) except *:
cdef inline void _c_flatten_values(TreeStorage st, list res) except *:
cdef dict data = st.detach()

cdef str k
Expand Down Expand Up @@ -72,7 +72,7 @@ cpdef list flatten_values(TreeValue tree):
_c_flatten_values(tree._detach(), result)
return result

cdef void _c_flatten_keys(TreeStorage st, tuple path, list res) except *:
cdef inline void _c_flatten_keys(TreeStorage st, tuple path, list res) except *:
cdef dict data = st.detach()
cdef tuple curpath

Expand Down Expand Up @@ -102,7 +102,7 @@ cpdef list flatten_keys(TreeValue tree):
_c_flatten_keys(tree._detach(), (), result)
return result

cdef TreeStorage _c_unflatten(object pairs):
cdef inline TreeStorage _c_unflatten(object pairs):
cdef dict raw_data = {}
cdef TreeStorage result = TreeStorage(raw_data)
cdef list stack = []
Expand Down