Description
First off, thanks for this (and your other) libraries! Very useful for scientific programming with JAX.
I'm having an odd issue with time stepping: if the time step is too small, some kind of numerical error seems to creep in. Here's a simple MWE, with the ODE being exponential decay:
```python
from diffrax import diffeqsolve, Dopri5, Euler, ODETerm, SaveAt, ConstantStepSize
import jax

# Enable double precision (float64) in JAX.
jax.config.update("jax_enable_x64", True)

# Exponential decay dy/dt = -y / tau, with tau passed in through args.
vector_field = lambda t, y, args: -y / args[0]
term = ODETerm(vector_field)
solver = Euler()
saveat = SaveAt(t0=True, t1=True)  # save only the initial and final states
stepsize_controller = ConstantStepSize()

tau = 1e-11
dt0 = tau / 1000  # 1000 fixed steps per decay time
sol = diffeqsolve(term, solver, t0=0, t1=tau, dt0=dt0, y0=1.0, saveat=saveat,
                  args=(tau,),
                  stepsize_controller=stepsize_controller, max_steps=10000)
print(sol.ts[0:10])
print(sol.ys[0:10])
```

It's set up so that, whatever decay time tau you choose, the solver takes 1000 steps over one decay time, so the final value should always be approximately 1/e.
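As a sanity check on what the answer should be, here is a plain-Python sketch (no diffrax or JAX involved; `euler_decay` is a hypothetical helper) of forward Euler with the same step count. Strictly, 1000 Euler steps give (1 - 1/1000)^1000, about 0.3677, rather than exactly 1/e (about 0.3679), but that value is the same for every tau, so any tau dependence should not come from the method itself.

```python
import math

N_STEPS = 1000  # steps per decay time, matching dt0 = tau / 1000 in the MWE


def euler_decay(tau: float, n_steps: int = N_STEPS) -> float:
    """Forward Euler for dy/dt = -y / tau from t = 0 to t = tau, with y(0) = 1.

    Uses a fixed step h = tau / n_steps and a fixed loop count, so the
    number of steps never depends on comparisons involving t.
    """
    h = tau / n_steps
    y = 1.0
    for _ in range(n_steps):
        y = y + h * (-y / tau)
    return y


# Each Euler step multiplies y by (1 - h / tau) = (1 - 1 / n_steps),
# so the final value should not depend on tau at all.
expected = (1.0 - 1.0 / N_STEPS) ** N_STEPS

for tau in [1e-6, 1e-9, 1e-10, 1e-11, 1e-12]:
    y_end = euler_decay(tau)
    print(f"tau={tau:.0e}  y(tau)={y_end:.12f}  expected={expected:.12f}")

print("exact solution exp(-1) =", math.exp(-1.0))
```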
For tau above about 1e-9, I get the right answer. Below 1e-9, some pretty bad numerical error starts to appear: for instance, tau = 1e-10 gives 1e-4 as the answer.
Switching to Bosh3 makes the issue smaller, but the error still depends on tau.
Running the same problem through scipy.integrate.solve_ivp shows no tau dependence in the solution error (though of course, diffrax is about two orders of magnitude faster!). I also tried jitting the input to scipy.integrate.solve_ivp, and it made no difference, so I guess the issue is somewhere in diffrax?
I could reparametrize the ODE I'm actually working with to avoid small numbers, but this behavior does make me worry about numerical stability. Any advice?
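For reference, the reparametrization in question would be a rescaling to dimensionless time s = t / tau, so that dy/ds = tau * dy/dt = -y on s in [0, 1] and tau drops out of the integration entirely. Below is a minimal plain-Python sketch of that idea (the generic `euler` helper is hypothetical, not a diffrax function), checking that both formulations give the same final value.

```python
from typing import Callable


def euler(f: Callable[[float, float], float], t0: float, t1: float,
          y0: float, n_steps: int) -> float:
    """Fixed-step forward Euler for dy/dt = f(t, y); returns y(t1)."""
    h = (t1 - t0) / n_steps
    t, y = t0, y0
    for _ in range(n_steps):
        y = y + h * f(t, y)
        t = t + h
    return y


tau = 1e-11  # physical decay time

# Original problem in physical time t: dy/dt = -y / tau on [0, tau].
y_physical = euler(lambda t, y: -y / tau, 0.0, tau, 1.0, 1000)

# Rescaled problem in dimensionless time s = t / tau:
# dy/ds = -y on [0, 1]; tau no longer appears in the vector field.
y_scaled = euler(lambda s, y: -y, 0.0, 1.0, 1.0, 1000)

# y at physical time t equals y at s = t / tau, so the endpoints should agree.
print(y_physical, y_scaled)
```

In terms of the MWE, this would correspond to calling `diffeqsolve` with `t0=0`, `t1=1`, `dt0=1/1000` and a vector field returning `-y`, then multiplying `sol.ts` by tau to recover physical times.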

