Closed
Description
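I can backpropagate through `diffeqsolve` without any issue. However, when I `jax.vmap` a function that calls `diffeqsolve` with an `Event`, I get the following error. It is raised by the forward `batched_loss_fn` call, before any `jax.grad` is applied: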
```
Traceback (most recent call last):
  File "/home/luggistruggi/Documents/work/test_issue.py", line 45, in <module>
    losses = batched_loss_fn(batch_params)
  File "/home/luggistruggi/Documents/work/test_issue.py", line 36, in batched_loss_fn
    return jax.vmap(single_loss_fn)(params)
  File "/home/luggistruggi/Documents/work/test_issue.py", line 20, in single_loss_fn
    sol = diffrax.diffeqsolve(
  File "/home/luggistruggi/miniconda3/envs/spiking/lib/python3.12/site-packages/diffrax/_integrate.py", line 1401, in diffeqsolve
    final_state, aux_stats = adjoint.loop(
  File "/home/luggistruggi/miniconda3/envs/spiking/lib/python3.12/site-packages/diffrax/_adjoint.py", line 294, in loop
    final_state = self._loop(
  File "/home/luggistruggi/miniconda3/envs/spiking/lib/python3.12/site-packages/diffrax/_integrate.py", line 640, in loop
    event_happened = jnp.any(jnp.stack(flat_mask))
  File "/home/luggistruggi/miniconda3/envs/spiking/lib/python3.12/site-packages/jax/_src/numpy/reductions.py", line 681, in any
    return _reduce_any(a, axis=_ensure_optional_axes(axis), out=out,
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: Differentiation rule for 'reduce_or' not implemented

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
```
Here is the code that produced the error:
```python
import jax
import jax.numpy as jnp
import diffrax
import optimistix as optx


def dynamics(t, y, args):
    # Linear relaxation towards `param`: dy/dt = param - y
    param = args
    return param - y


def event_fn(t, y, args, **kwargs):
    # Event triggers when y crosses 1.5 (sign change of this value)
    return y - 1.5


def single_loss_fn(param):
    solver = diffrax.Euler()
    root_finder = optx.Newton(1e-2, 1e-2, optx.rms_norm)
    event = diffrax.Event(event_fn, root_finder)
    term = diffrax.ODETerm(dynamics)
    sol = diffrax.diffeqsolve(
        term,
        solver=solver,
        t0=0.0,
        t1=2.0,
        dt0=0.1,
        y0=0.0,
        args=param,
        event=event,
        max_steps=1000,
    )
    final_y = sol.ys[-1]
    return param**2 + final_y**2


def batched_loss_fn(params: jnp.ndarray) -> jnp.ndarray:
    # Evaluate the per-example loss over the batch dimension
    return jax.vmap(single_loss_fn)(params)


def grad_fn(params: jnp.ndarray) -> jnp.ndarray:
    # Gradient of the summed batch loss with respect to every param
    return jax.grad(lambda p: jnp.sum(batched_loss_fn(p)))(params)


if __name__ == "__main__":
    batch_params = jnp.array([1.0, 2.0, 3.0])
    losses = batched_loss_fn(batch_params)  # the error above is raised here
    print("batched_loss_fn =", losses)
    grads = grad_fn(batch_params)
    print("grad_fn =", grads)
```
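For comparison only (this is not a fix for the error): the toy problem above has a closed form. With `y0 = 0`, `dy/dt = param - y` integrates to `y(t) = param * (1 - exp(-t))`, and the event at `y = 1.5` fires before `t1 = 2` exactly when `param * (1 - exp(-2)) > 1.5`, so the final state is `min(param * (1 - exp(-2)), 1.5)`. The sketch below assumes that exact solution in place of the `Euler` steps and root find, so its numbers will differ slightly from what `diffeqsolve` returns. It reuses the same `vmap`-then-`grad` composition without `diffrax`, which could serve as a reference when checking the batched gradients. The names `analytic_loss`, `analytic_batched_loss`, and `analytic_grad` are hypothetical.

```python
import jax
import jax.numpy as jnp

T1 = 2.0         # same end time as the diffeqsolve call
THRESHOLD = 1.5  # same level as event_fn


def analytic_loss(param):
    # Hypothetical closed-form stand-in for single_loss_fn (exact ODE solution, not Euler):
    # y(t) = param * (1 - exp(-t)); the event caps the final state at THRESHOLD.
    y_at_t1 = param * (1.0 - jnp.exp(-T1))
    final_y = jnp.minimum(y_at_t1, THRESHOLD)
    return param**2 + final_y**2


def analytic_batched_loss(params):
    # Same vmap-over-examples structure as batched_loss_fn.
    return jax.vmap(analytic_loss)(params)


def analytic_grad(params):
    # Same grad-of-summed-batch structure as grad_fn.
    return jax.grad(lambda p: jnp.sum(analytic_batched_loss(p)))(params)


batch_params = jnp.array([1.0, 2.0, 3.0])
losses = analytic_batched_loss(batch_params)
grads = analytic_grad(batch_params)
print("analytic losses =", losses)
print("analytic grads  =", grads)
```

With the exact solution, `param = 2.0` and `param = 3.0` both hit the event, so `final_y` is pinned at 1.5 and only the `param**2` term contributes to their gradients (4 and 6).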
Any suggestions on how to avoid this or implement this myself? Thank you so much :)