
Can't backprop over vmapped diffeqsolve (NotImplementedError: Differentiation rule for 'reduce_or' not implemented) #568

@LuggiStruggi

Description


I can backpropagate through diffeqsolve without any issue; however, when I vmap a function that includes diffeqsolve, I get the following error:

Traceback (most recent call last):
  File "/home/luggistruggi/Documents/work/test_issue.py", line 45, in <module>
    losses = batched_loss_fn(batch_params)
  File "/home/luggistruggi/Documents/work/test_issue.py", line 36, in batched_loss_fn
    return jax.vmap(single_loss_fn)(params)
  File "/home/luggistruggi/Documents/work/test_issue.py", line 20, in single_loss_fn
    sol = diffrax.diffeqsolve(
  File "/home/luggistruggi/miniconda3/envs/spiking/lib/python3.12/site-packages/diffrax/_integrate.py", line 1401, in diffeqsolve
    final_state, aux_stats = adjoint.loop(
  File "/home/luggistruggi/miniconda3/envs/spiking/lib/python3.12/site-packages/diffrax/_adjoint.py", line 294, in loop
    final_state = self._loop(
  File "/home/luggistruggi/miniconda3/envs/spiking/lib/python3.12/site-packages/diffrax/_integrate.py", line 640, in loop
    event_happened = jnp.any(jnp.stack(flat_mask))
  File "/home/luggistruggi/miniconda3/envs/spiking/lib/python3.12/site-packages/jax/_src/numpy/reductions.py", line 681, in any
    return _reduce_any(a, axis=_ensure_optional_axes(axis), out=out,
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: Differentiation rule for 'reduce_or' not implemented

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

Here is the code that produced the error:

import jax
import jax.numpy as jnp
import diffrax
import optimistix as optx

def dynamics(t, y, args):
    # Simple linear ODE: dy/dt = param - y.
    param = args
    return param - y


def event_fn(t, y, args, **kwargs):
    # Event condition: triggers when y crosses 1.5.
    return y - 1.5

def single_loss_fn(param):
    solver = diffrax.Euler()
    root_finder = optx.Newton(1e-2, 1e-2, optx.rms_norm)
    event = diffrax.Event(event_fn, root_finder)
    term = diffrax.ODETerm(dynamics)

    # Solve until t1, or stop early if the event fires.
    sol = diffrax.diffeqsolve(
        term,
        solver=solver,
        t0=0.0,
        t1=2.0,
        dt0=0.1,
        y0=0.0,
        args=param,
        event=event,
        max_steps=1000,
    )

    final_y = sol.ys[-1]
    return param**2 + final_y**2

def batched_loss_fn(params: jnp.ndarray) -> jnp.ndarray:
    return jax.vmap(single_loss_fn)(params)

def grad_fn(params: jnp.ndarray) -> jnp.ndarray:
    return jax.grad(lambda p: jnp.sum(batched_loss_fn(p)))(params)


if __name__ == "__main__":
    batch_params = jnp.array([1.0, 2.0, 3.0])

    losses = batched_loss_fn(batch_params)
    print("batched_loss_fn =", losses)

    grads = grad_fn(batch_params)
    print("grad_fn =", grads)

Any suggestions on how to avoid this, or how I could implement this myself? Thank you so much :)
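
In case it helps frame the question: a sketch of the kind of workaround I could imagine (untested, and it gives up real batching) would be to map over the batch with jax.lax.map instead of jax.vmap, so that each solve runs with unbatched inputs:

def batched_loss_fn_loop(params: jnp.ndarray) -> jnp.ndarray:
    # Untested sketch: scan over the leading axis instead of vmapping,
    # so the event / root-finding code never sees batch tracers.
    return jax.lax.map(single_loss_fn, params)

Ideally, though, I would like to keep using vmap.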
