Skip to content

Commit

Permalink
Merge pull request #84 from opendilab/fix/integration
Browse files Browse the repository at this point in the history
dev(hansbug): fix bug of jax integration
  • Loading branch information
HansBug authored Mar 16, 2023
2 parents 134e6b7 + 5be0032 commit 131d60c
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 6 deletions.
13 changes: 13 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ jobs:
- '3.9'
- '3.10'
- '3.11'
torch-version:
- '1.7.1'
- '1.13.1'
- 'latest'

steps:
- name: Get system version for Linux
Expand Down Expand Up @@ -94,6 +98,15 @@ jobs:
pip install -r requirements.txt
pip install -r requirements-build.txt
pip install -r requirements-test.txt
- name: Install Torch
shell: bash
if: ${{ matrix.torch-version != 'latest' }}
continue-on-error: true
run: |
pip install torch==${{ matrix.torch-version }}
- name: Install Extra Test Requirements
shell: bash
run: |
pip install -r requirements-test-extra.txt
- name: Test the basic environment
shell: bash
Expand Down
9 changes: 5 additions & 4 deletions test/tree/integration/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

try:
import jax
from jax.tree_util import register_pytree_node
except (ModuleNotFoundError, ImportError):
jax = None

Expand All @@ -20,10 +21,10 @@ def double(x):
return x * 2 + 1.5

t1 = FastTreeValue({
'a': np.random.randint(0, 10, (2, 3)),
'a': np.random.randint(0, 10, (2, 3)) + 1,
'b': {
'x': np.asarray(233.0),
'y': np.random.randn(2, 3) + 1,
'y': np.random.randn(2, 3) + 100,
}
})
r1 = double(t1)
Expand All @@ -37,10 +38,10 @@ class MyTreeValue(FastTreeValue):
register_for_jax(MyTreeValue)

t2 = MyTreeValue({
'a': np.random.randint(0, 10, (2, 3)),
'a': np.random.randint(0, 10, (2, 3)) + 1,
'b': {
'x': np.asarray(233.0),
'y': np.random.randn(2, 3) + 1,
'y': np.random.randn(2, 3) + 100,
}
})
r2 = double(t2)
Expand Down
1 change: 1 addition & 0 deletions test/tree/integration/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

try:
import torch
from torch.utils._pytree import _register_pytree_node
except (ImportError, ModuleNotFoundError):
torch = None

Expand Down
5 changes: 4 additions & 1 deletion treevalue/tree/integration/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@

try:
import jax
from jax.tree_util import register_pytree_node
except (ModuleNotFoundError, ImportError):
from .cjax import register_for_jax as _original_register_for_jax


@wraps(_original_register_for_jax)
def register_for_jax(cls):
warnings.warn(f'Jax is not installed, registration of {cls!r} will be ignored.')
warnings.warn(f'Jax doesn\'t have tree_util module due to either not installed '
f'or the installed version is too low, '
f'so the registration of {cls!r} will be ignored.')
else:
from .cjax import register_for_jax
from ..tree import TreeValue
Expand Down
5 changes: 4 additions & 1 deletion treevalue/tree/integration/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@

try:
import torch
from torch.utils._pytree import _register_pytree_node
except (ModuleNotFoundError, ImportError):
from .ctorch import register_for_torch as _original_register_for_torch


@wraps(_original_register_for_torch)
def register_for_torch(cls):
warnings.warn(f'Torch is not installed, registration of {cls!r} will be ignored.')
warnings.warn(f'Pytree module is not included in the Torch installation '
f'or the installed version is too low, '
f'so the registration of {cls!r} will be ignored.')
else:
from .ctorch import register_for_torch
from ..tree import TreeValue
Expand Down

0 comments on commit 131d60c

Please sign in to comment.