
.at[] indexing doesn't work with brainpy array #743

Open
@llandsmeer

Description


Array.at[arr1].add(arr2) crashes when the index arr1 is a brainpy array:

import brainpy.math as bp
x = bp.arange(10)
x.at[x].add(x)

Expected output:

Array([ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18], dtype=int32)
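The expected value follows from the scatter-add semantics of .at[idx].add(vals): for each position i, vals[i] is added to out[idx[i]], and repeated indices accumulate instead of overwriting. Because x = arange(10) used as an index is the identity permutation, every element is simply doubled. The helper scatter_add below is a hypothetical pure-Python reference for 1-D inputs, not part of JAX or BrainPy, and it ignores JAX's handling of out-of-bounds indices.

```python
# Hypothetical reference helper illustrating the semantics of x.at[idx].add(vals)
# for 1-D inputs; it is not part of JAX or BrainPy.
def scatter_add(x, idx, vals):
    out = list(x)            # .at[] is functional: x itself is left unchanged
    for i, v in zip(idx, vals):
        out[i] += v          # repeated indices accumulate rather than overwrite
    return out


x = list(range(10))
print(scatter_add(x, x, x))                          # [0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
print(scatter_add([0, 0, 0], [1, 1, 2], [5, 5, 1]))  # [0, 10, 1]
```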

Actual output:

File ~/.local/lib/python3.10/site-packages/jax/_src/numpy/indexing.py:817, in index_to_gather(x_shape, idx, normalize_indices)
    814   if normalize_indices:
    815     advanced_pairs = ((_normalize_index(e, x_shape[j]), i, j)
    816                       for e, i, j in advanced_pairs)
--> 817   advanced_indexes, idx_advanced_axes, x_advanced_axes = zip(*advanced_pairs)
    819 x_axis = 0  # Current axis in x.
    820 y_axis = 0  # Current axis in y, before collapsing. See below.

ValueError: not enough values to unpack (expected 3, got 0)

The error seems to originate from the index array being interpreted as a single value instead of an array.

Casting the index to a JAX array works as an interim workaround, but of course this should not be necessary:

import brainpy.math as bp
import jax.numpy as jnp
x = bp.arange(10)
x.at[jnp.array(x)].add(x)
# Array([ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18], dtype=int32)
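One plausible mechanism consistent with the traceback is that two checks disagree about the wrapper type: one decides the index is an integer array (so the advanced-indexing branch is taken), while the filter that collects advanced indices only accepts recognized array types and silently drops the wrapper. zip(*advanced_pairs) over an empty generator then raises exactly the ValueError shown above. The toy model below is a pure-Python sketch of that failure mode under this assumption; it does not reproduce JAX's actual index_to_gather code, and Wrapper, looks_like_int_array, gather_axes, and unwrap are hypothetical names. It also assumes the wrapper exposes its raw data as .value, which is how brainpy.math.Array is assumed to behave here.

```python
# Toy model (not JAX source) of two index checks that disagree about a wrapper type.

class Wrapper:
    """Hypothetical stand-in for an array wrapper such as brainpy.math.Array:
    it looks array-like (has ndim) but is not one of the types collected below."""
    def __init__(self, data):
        self.value = list(data)  # assumed: underlying data exposed as .value
        self.ndim = 1


def looks_like_int_array(e):
    # Check 1 (duck-typed): anything with ndim > 0, or a list/tuple, counts.
    return getattr(e, "ndim", 0) > 0 or isinstance(e, (list, tuple))


def gather_axes(idx):
    if not any(looks_like_int_array(e) for e in idx):
        return (), (), ()
    # Check 2 (nominal): only concrete list/tuple indices are collected.
    pairs = ((e, i, i) for i, e in enumerate(idx) if isinstance(e, (list, tuple)))
    # If check 2 drops everything check 1 accepted, zip(*pairs) is empty and this
    # unpack raises "not enough values to unpack (expected 3, got 0)".
    indexes, idx_axes, x_axes = zip(*pairs)
    return indexes, idx_axes, x_axes


def unwrap(idx):
    # Analogue of the jnp.array(x) workaround: replace wrappers by their raw data.
    return tuple(getattr(e, "value", e) for e in idx)


w = Wrapper(range(3))
print(gather_axes(unwrap((w,))))  # (([0, 1, 2],), (0,), (0,))
try:
    gather_axes((w,))
except ValueError as err:
    print("ValueError:", err)     # ValueError: not enough values to unpack (expected 3, got 0)
```

Unwrapping before indexing sidesteps the mismatch in the toy, mirroring why casting with jnp.array(x) avoids the crash; the real fix presumably belongs in how the wrapper is recognized as an array index.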

Versions:

pip freeze | grep -E 'jax|brainpy|taichi'
brainpy==2.6.0.post20250420
jax==0.5.3
jaxlib==0.5.3
jaxtyping==0.3.1
taichi==1.7.3

Upgrading to jax==0.6 does not help.

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
