Description
Array.at[arr1].add(arr2) crashes
import brainpy.math as bp
x = bp.arange(10)
x.at[x].add(x)
Expected output:
Array([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 18], dtype=int32)
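For reference, the expected value follows from scatter-add semantics: position x[i] is incremented by x[i], and because arange(10) used as an index is the identity mapping, every element is doubled. The minimal sketch below uses NumPy's np.add.at, which performs the same unbuffered scatter-add (repeated indices accumulate, as with JAX's .at[...].add()). NumPy is only a stand-in for checking the expected result here; it is not part of the BrainPy code path.

```python
import numpy as np

x = np.arange(10)
out = x.copy()
# Unbuffered in-place scatter-add: out[x[i]] += x[i] for every i.
np.add.at(out, x, x)
print(out)  # [ 0  2  4  6  8 10 12 14 16 18]

# Repeated indices accumulate rather than overwrite.
z = np.zeros(3, dtype=int)
np.add.at(z, [0, 0, 2], [1, 1, 5])
print(z)  # [2 0 5]
```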
Actual output:
File ~/.local/lib/python3.10/site-packages/jax/_src/numpy/indexing.py:817, in index_to_gather(x_shape, idx, normalize_indices)
    814 if normalize_indices:
    815   advanced_pairs = ((_normalize_index(e, x_shape[j]), i, j)
    816                     for e, i, j in advanced_pairs)
--> 817 advanced_indexes, idx_advanced_axes, x_advanced_axes = zip(*advanced_pairs)
    819 x_axis = 0  # Current axis in x.
    820 y_axis = 0  # Current axis in y, before collapsing. See below.
ValueError: not enough values to unpack (expected 3, got 0)
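The error seems to originate from the index array being interpreted as a single value instead of an array: zip(*advanced_pairs) receives an empty sequence, i.e. no advanced (array) index was detected.
Casting the index to a JAX array works as an intermediate workaround, but is of course not desirable: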
import brainpy.math as bp
import jax.numpy as jnp
x = bp.arange(10)
x.at[jnp.array(x)].add(x)
# Out: Array([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 18], dtype=int32)
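Until the index handling is fixed, the cast can be wrapped in a small helper so call sites do not have to remember it. This is a sketch only, assuming (as the report shows for jnp.array) that jnp.asarray turns a brainpy.math Array into a plain jax.Array; the name scatter_add_safe is hypothetical and not part of BrainPy or JAX.

```python
import jax.numpy as jnp


def scatter_add_safe(x, idx, vals):
    """Hypothetical helper: equivalent of x.at[idx].add(vals).

    Every argument is coerced with jnp.asarray first, so JAX's indexing
    code only sees plain jax.Array objects. Assumption: this unwraps
    brainpy.math Array objects, mirroring the jnp.array cast used in the
    workaround above.
    """
    x = jnp.asarray(x)
    idx = jnp.asarray(idx)
    vals = jnp.asarray(vals)
    # Functional update: returns a new array; x itself is unchanged.
    return x.at[idx].add(vals)


x = jnp.arange(10)
print(scatter_add_safe(x, x, x))
```

Note that the helper returns a plain jax.Array rather than a BrainPy Array, so code that relies on BrainPy-specific methods of the result would need to wrap it again.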
Versions:
pip freeze | grep -E 'jax|brainpy|taichi'
brainpy==2.6.0.post20250420
jax==0.5.3
jaxlib==0.5.3
jaxtyping==0.3.1
taichi==1.7.3
Upgrading to jax==0.6 does not help.