Skip to content

Commit 5f33a66

Browse files
Fix surrogate gradient function and numpy 2.0 compatibility (#679)
* fix surrogate batching * fix numpy 2.0 compatible issue * fix numpy 2.0 compatible issue * updates * fix numpy2.0 compatible issue * Skip the operators tests for GitHub action server * Update * Update * Update * Update test_taichi_based.py * Update test_get_weight_matrix.py --------- Co-authored-by: He Sichao <1310722434@qq.com>
1 parent b9461eb commit 5f33a66

File tree

19 files changed

+124
-88
lines changed

19 files changed

+124
-88
lines changed

.github/workflows/CI.yml

+3-3
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ jobs:
5151
- name: Test with pytest
5252
run: |
5353
cd brainpy
54-
pytest _src/
54+
export IS_GITHUB_ACTIONS=1 && pytest _src/
5555
5656
5757
test_macos:
@@ -82,7 +82,7 @@ jobs:
8282
- name: Test with pytest
8383
run: |
8484
cd brainpy
85-
pytest _src/
85+
export IS_GITHUB_ACTIONS=1 && pytest _src/
8686
8787
8888
test_windows:
@@ -113,4 +113,4 @@ jobs:
113113
- name: Test with pytest
114114
run: |
115115
cd brainpy
116-
pytest _src/ -p no:faulthandler
116+
set IS_GITHUB_ACTIONS=1 && pytest _src/

brainpy/_src/losses/comparison.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,8 @@ def update(self, input, target):
376376

377377

378378
def nll_loss(input, target, reduction: str = 'mean'):
379-
r"""The negative log likelihood loss.
379+
r"""
380+
The negative log likelihood loss.
380381
381382
The negative log likelihood loss. It is useful to train a classification
382383
problem with `C` classes.

brainpy/_src/math/compat_numpy.py

+6-29
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from .interoperability import *
1111
from .ndarray import Array
1212

13-
1413
__all__ = [
1514
'full', 'full_like', 'eye', 'identity', 'diag', 'tri', 'tril', 'triu',
1615
'empty', 'empty_like', 'ones', 'ones_like', 'zeros', 'zeros_like',
@@ -92,9 +91,8 @@
9291
'ravel_multi_index', 'result_type', 'sort_complex', 'unpackbits', 'delete',
9392

9493
# unique
95-
'add_docstring', 'add_newdoc', 'add_newdoc_ufunc', 'array2string', 'asanyarray',
96-
'ascontiguousarray', 'asfarray', 'asscalar', 'common_type', 'disp', 'genfromtxt',
97-
'loadtxt', 'info', 'issubclass_', 'place', 'polydiv', 'put', 'putmask', 'safe_eval',
94+
'asanyarray', 'ascontiguousarray', 'asfarray', 'asscalar', 'common_type', 'genfromtxt',
95+
'loadtxt', 'info', 'place', 'polydiv', 'put', 'putmask', 'safe_eval',
9896
'savetxt', 'savez_compressed', 'show_config', 'typename', 'copyto', 'matrix', 'asmatrix', 'mat',
9997

10098
]
@@ -204,11 +202,12 @@ def ascontiguousarray(a, dtype=None, order=None):
204202
return asarray(a, dtype=dtype, order=order)
205203

206204

207-
def asfarray(a, dtype=np.float_):
205+
def asfarray(a, dtype=None):
208206
if not np.issubdtype(dtype, np.inexact):
209-
dtype = np.float_
207+
dtype = np.float64
210208
return asarray(a, dtype=dtype)
211209

210+
212211
def in1d(ar1, ar2, assume_unique: bool = False, invert: bool = False) -> Array:
213212
del assume_unique
214213
ar1_flat = ravel(ar1)
@@ -227,6 +226,7 @@ def in1d(ar1, ar2, assume_unique: bool = False, invert: bool = False) -> Array:
227226
else:
228227
return asarray((ar1_flat[:, None] == ar2_flat[None, :]).any(-1))
229228

229+
230230
# Others
231231
# ------
232232
meshgrid = _compatible_with_brainpy_array(jnp.meshgrid)
@@ -454,7 +454,6 @@ def msort(a):
454454
sometrue = any
455455

456456

457-
458457
def shape(a):
459458
"""
460459
Return the shape of an array.
@@ -648,7 +647,6 @@ def size(a, axis=None):
648647
finfo = jnp.finfo
649648
iinfo = jnp.iinfo
650649

651-
652650
can_cast = _compatible_with_brainpy_array(jnp.can_cast)
653651
choose = _compatible_with_brainpy_array(jnp.choose)
654652
copy = _compatible_with_brainpy_array(jnp.copy)
@@ -678,23 +676,6 @@ def size(a, axis=None):
678676
# Unique APIs
679677
# -----------
680678

681-
add_docstring = np.add_docstring
682-
add_newdoc = np.add_newdoc
683-
add_newdoc_ufunc = np.add_newdoc_ufunc
684-
685-
686-
def array2string(a, max_line_width=None, precision=None,
687-
suppress_small=None, separator=' ', prefix="",
688-
style=np._NoValue, formatter=None, threshold=None,
689-
edgeitems=None, sign=None, floatmode=None, suffix="",
690-
legacy=None):
691-
a = as_numpy(a)
692-
return array2string(a, max_line_width=max_line_width, precision=precision,
693-
suppress_small=suppress_small, separator=separator, prefix=prefix,
694-
style=style, formatter=formatter, threshold=threshold,
695-
edgeitems=edgeitems, sign=sign, floatmode=floatmode, suffix=suffix,
696-
legacy=legacy)
697-
698679

699680
def asscalar(a):
700681
return a.item()
@@ -731,13 +712,9 @@ def common_type(*arrays):
731712
return array_type[0][precision]
732713

733714

734-
disp = np.disp
735-
736715
genfromtxt = lambda *args, **kwargs: asarray(np.genfromtxt(*args, **kwargs))
737716
loadtxt = lambda *args, **kwargs: asarray(np.loadtxt(*args, **kwargs))
738-
739717
info = np.info
740-
issubclass_ = np.issubclass_
741718

742719

743720
def place(arr, mask, vals):

brainpy/_src/math/event/tests/test_event_csrmm.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,27 @@
11
# -*- coding: utf-8 -*-
2-
2+
import os
33
from functools import partial
44

55
import jax
6+
import pytest
67
from absl.testing import parameterized
78

89
import brainpy as bp
910
import brainpy.math as bm
1011

1112
# bm.set_platform('gpu')
1213

14+
import platform
15+
force_test = False # turn on to force test on windows locally
16+
if platform.system() == 'Windows' and not force_test:
17+
pytest.skip('skip windows', allow_module_level=True)
18+
19+
20+
# Skip the test in Github Actions
21+
IS_GITHUB_ACTIONS = os.getenv('IS_GITHUB_ACTIONS', '0')
22+
if IS_GITHUB_ACTIONS == '1':
23+
pytest.skip('Skip the test in Github Actions', allow_module_level=True)
24+
1325
seed = 1234
1426

1527

brainpy/_src/math/event/tests/test_event_csrmv.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# -*- coding: utf-8 -*-
2-
3-
2+
import os
43
from functools import partial
54

65
import jax
@@ -19,6 +18,10 @@
1918
if platform.system() == 'Windows' and not force_test:
2019
pytest.skip('skip windows', allow_module_level=True)
2120

21+
# Skip the test in Github Actions
22+
IS_GITHUB_ACTIONS = os.getenv('IS_GITHUB_ACTIONS', '0')
23+
if IS_GITHUB_ACTIONS == '1':
24+
pytest.skip('Skip the test in Github Actions', allow_module_level=True)
2225

2326
seed = 1234
2427

brainpy/_src/math/jitconn/tests/test_event_matvec.py

+5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# -*- coding: utf-8 -*-
2+
import os
23

34
import jax
45
import jax.numpy as jnp
@@ -16,6 +17,10 @@
1617
if platform.system() == 'Windows' and not force_test:
1718
pytest.skip('skip windows', allow_module_level=True)
1819

20+
# Skip the test in Github Actions
21+
IS_GITHUB_ACTIONS = os.getenv('IS_GITHUB_ACTIONS', '0')
22+
if IS_GITHUB_ACTIONS == '1':
23+
pytest.skip('Skip the test in Github Actions', allow_module_level=True)
1924

2025
shapes = [(100, 200), (1000, 10)]
2126

brainpy/_src/math/jitconn/tests/test_get_weight_matrix.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# -*- coding: utf-8 -*-
2+
import os
3+
24
import jax.numpy as jnp
35
import pytest
46
from absl.testing import parameterized
@@ -12,8 +14,14 @@
1214
import platform
1315

1416
force_test = False # turn on to force test on windows locally
15-
# if platform.system() == 'Windows' and not force_test:
16-
# pytest.skip('skip windows', allow_module_level=True)
17+
if platform.system() == 'Windows' and not force_test:
18+
pytest.skip('skip windows', allow_module_level=True)
19+
20+
# Skip the test in Github Actions
21+
IS_GITHUB_ACTIONS = os.getenv('IS_GITHUB_ACTIONS', '0')
22+
if IS_GITHUB_ACTIONS == '1':
23+
pytest.skip('Skip the test in Github Actions', allow_module_level=True)
24+
1725

1826
shapes = [
1927
(2, 2),

brainpy/_src/math/jitconn/tests/test_matvec.py

+5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# -*- coding: utf-8 -*-
2+
import os
23

34
import jax
45
import jax.numpy as jnp
@@ -16,6 +17,10 @@
1617
if platform.system() == 'Windows' and not force_test:
1718
pytest.skip('skip windows', allow_module_level=True)
1819

20+
# Skip the test in Github Actions
21+
IS_GITHUB_ACTIONS = os.getenv('IS_GITHUB_ACTIONS', '0')
22+
if IS_GITHUB_ACTIONS == '1':
23+
pytest.skip('Skip the test in Github Actions', allow_module_level=True)
1924

2025
shapes = [(100, 200), (1000, 10)]
2126

brainpy/_src/math/op_register/tests/test_taichi_based.py

+5
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@
55
import brainpy.math as bm
66
from brainpy._src.dependency_check import import_taichi
77

8+
import platform
9+
force_test = False # turn on to force test on windows locally
10+
if platform.system() == 'Windows' and not force_test:
11+
pytest.skip('skip windows', allow_module_level=True)
12+
813
ti = import_taichi(error_if_not_found=False)
914
if ti is None:
1015
pytest.skip('no taichi', allow_module_level=True)

brainpy/_src/math/sparse/tests/test_csrmm.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,28 @@
11
# -*- coding: utf-8 -*-
22

3+
4+
import os
35
from functools import partial
46

57
import jax
8+
import pytest
69
from absl.testing import parameterized
710

811
import brainpy as bp
912
import brainpy.math as bm
1013

1114
# bm.set_platform('gpu')
1215

16+
import platform
17+
force_test = False # turn on to force test on windows locally
18+
if platform.system() == 'Windows' and not force_test:
19+
pytest.skip('skip windows', allow_module_level=True)
20+
21+
# Skip the test in Github Actions
22+
IS_GITHUB_ACTIONS = os.getenv('IS_GITHUB_ACTIONS', '0')
23+
if IS_GITHUB_ACTIONS == '1':
24+
pytest.skip('Skip the test in Github Actions', allow_module_level=True)
25+
1326
seed = 1234
1427

1528

@@ -133,7 +146,8 @@ def test_homo_grad(self, transpose, shape, homo_data):
133146
argnums=0)
134147
r1 = dense_f1(homo_data)
135148
r2 = jax.grad(sum_op(bm.sparse.csrmm))(
136-
bm.asarray([homo_data]), indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]),
149+
bm.asarray([homo_data]), indices, indptr, matrix,
150+
shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]),
137151
transpose=transpose)
138152

139153
self.assertTrue(bm.allclose(r1, r2))

brainpy/_src/math/sparse/tests/test_csrmv.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# -*- coding: utf-8 -*-
2-
2+
import os
33
from functools import partial
44

55
import jax
@@ -17,6 +17,10 @@
1717
if platform.system() == 'Windows' and not force_test:
1818
pytest.skip('skip windows', allow_module_level=True)
1919

20+
# Skip the test in Github Actions
21+
IS_GITHUB_ACTIONS = os.getenv('IS_GITHUB_ACTIONS', '0')
22+
if IS_GITHUB_ACTIONS == '1':
23+
pytest.skip('Skip the test in Github Actions', allow_module_level=True)
2024

2125
seed = 1234
2226

-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
# -*- coding: utf-8 -*-
22

33

4-
from .base import *
54
from ._one_input_new import *
65
from ._two_inputs import *

brainpy/_src/math/surrogate/_one_input.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from brainpy._src.math.interoperability import as_jax
1010
from brainpy._src.math.ndarray import Array
11-
from .base import Surrogate
1211

1312
__all__ = [
1413
'sigmoid',
@@ -32,6 +31,16 @@
3231
]
3332

3433

34+
class Surrogate(object):
35+
"""The base surrograte gradient function."""
36+
37+
def __call__(self, *args, **kwargs):
38+
raise NotImplementedError
39+
40+
def __repr__(self):
41+
return f'{self.__class__.__name__}()'
42+
43+
3544
class _OneInpSurrogate(Surrogate):
3645
def __init__(self, forward_use_surrogate=False):
3746
self.forward_use_surrogate = forward_use_surrogate

brainpy/_src/math/surrogate/_one_input_new.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from brainpy._src.math.ndarray import Array
1313

1414
__all__ = [
15+
'Surrogate',
1516
'Sigmoid',
1617
'sigmoid',
1718
'PiecewiseQuadratic',
@@ -61,7 +62,7 @@ def _heaviside_imp(x, dx):
6162

6263

6364
def _heaviside_batching(args, axes):
64-
return heaviside_p.bind(*args), axes
65+
return heaviside_p.bind(*args), [axes[0]]
6566

6667

6768
def _heaviside_jvp(primals, tangents):

brainpy/_src/math/surrogate/base.py

-19
This file was deleted.

brainpy/math/compat_numpy.py

-6
Original file line numberDiff line numberDiff line change
@@ -327,20 +327,14 @@
327327
sort_complex as sort_complex,
328328
unpackbits as unpackbits,
329329
delete as delete,
330-
add_docstring as add_docstring,
331-
add_newdoc as add_newdoc,
332-
add_newdoc_ufunc as add_newdoc_ufunc,
333-
array2string as array2string,
334330
asanyarray as asanyarray,
335331
ascontiguousarray as ascontiguousarray,
336332
asfarray as asfarray,
337333
asscalar as asscalar,
338334
common_type as common_type,
339-
disp as disp,
340335
genfromtxt as genfromtxt,
341336
loadtxt as loadtxt,
342337
info as info,
343-
issubclass_ as issubclass_,
344338
place as place,
345339
polydiv as polydiv,
346340
put as put,

0 commit comments

Comments
 (0)