Closed
Description
I have found some odd behaviour in the `StepTo` controller for small time steps, and the result also depends on floating-point precision.
Below is a minimal working example. In single precision, the time jumps straight to `t1` after one step and the final integration result is incorrect.
```python
import diffrax
import jax.numpy as jnp
import jax

t0 = 0.0
t1 = 1e-9
dt = 2e-12
dt_max = 1e-11

step_t = t0
step_ts = [step_t]
while(step_t < t1):
    step_t += dt
    step_ts.append(step_t)
    dt = min(2*dt,dt_max)
step_ts[-1] = t1
print(step_ts)

stepsize_controller = diffrax.StepTo(ts=step_ts)

def f(t,y,args):
    jax.debug.print('{t}',t=t)
    return -y/t1

sol = diffrax.diffeqsolve(terms=diffrax.ODETerm(f),
                          solver=diffrax.Euler(),
                          t0=t0,
                          t1=t1,
                          dt0=None,
                          y0=jnp.ones(1),
                          stepsize_controller=stepsize_controller)
print(sol.ts,sol.ys)
```

If I run in double precision using:
```python
from jax import config
config.update("jax_enable_x64", True)
```

Things are marginally better, but the `jax.debug.print` output still suggests that the final few times are clipped to `t1`.
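For comparison, here is a NumPy sketch of what the Euler iterates should look like on this grid. It is an illustration, not diffrax's own code: `euler_on_grid` is a hypothetical helper and NumPy is assumed to be available. It rebuilds the same grid, checks that rounding it to float32 keeps every step strictly positive (so the requested times stay distinct in single precision), and runs explicit Euler on it in both precisions. It also takes a single Euler step from `t0` straight to `t1`, which for `dy/dt = -y / t1` gives exactly `1 * (1 - t1 / t1) = 0`. A final value near 0 rather than roughly `exp(-1) ≈ 0.37` would therefore point to the step grid collapsing to `[t0, t1]`.

```python
import numpy as np

# Rebuild the step grid from the example above (plain Python floats, i.e. float64).
t0, t1 = 0.0, 1e-9
dt, dt_max = 2e-12, 1e-11
step_t = t0
step_ts = [step_t]
while step_t < t1:
    step_t += dt
    step_ts.append(step_t)
    dt = min(2 * dt, dt_max)
step_ts[-1] = t1  # snap the last grid point exactly onto t1

ts64 = np.asarray(step_ts, dtype=np.float64)
ts32 = ts64.astype(np.float32)

# Rounding the grid to single precision still leaves every step strictly positive,
# so the requested step times remain distinct in float32.
assert np.all(np.diff(ts32) > 0)


def euler_on_grid(ts, y0=1.0):
    """Hypothetical helper: explicit Euler for dy/dt = -y / t1 on the grid ts,
    using y_{n+1} = y_n * (1 - h_n / t1) with all arithmetic in ts.dtype."""
    cast = ts.dtype.type
    y, one, scale = cast(y0), cast(1.0), cast(t1)
    for ta, tb in zip(ts[:-1], ts[1:]):
        y = y * (one - (tb - ta) / scale)
    return y


y_grid64 = euler_on_grid(ts64)  # stepwise Euler in double precision
y_grid32 = euler_on_grid(ts32)  # stepwise Euler in single precision
# One Euler step straight from t0 to t1 gives exactly 1 * (1 - t1 / t1) = 0.
y_one_step = euler_on_grid(np.array([t0, t1]))
print(y_grid64, y_grid32, y_one_step, np.exp(-1.0))
```

The multiplicative form `y * (1 - h / t1)` is algebraically the same Euler update as `y + h * (-y / t1)`. It is used here so that the single-step case evaluates to exactly zero.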
If I change `t1` to 1.0 and scale `dt` and `dt_max` accordingly, the correct behaviour is restored. However, I would like to use time scales of this order (picosecond steps over a nanosecond interval) for a physics simulation. Cheers!
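Because the example behaves correctly with `t1 = 1.0`, one possible workaround is to integrate in dimensionless time `tau = t / T` with `T = t1`, so that the solver only ever sees times in `[0, 1]`. The vector field then becomes `dy/dtau = T * f(T * tau, y)`. The sketch below is an assumption-laden illustration rather than a confirmed fix. It uses a hand-written NumPy Euler loop instead of diffrax, and `rescale` and `euler` are hypothetical helpers. It checks that stepping on the rescaled grid reproduces the Euler iterates obtained in physical time.

```python
import numpy as np

t1 = 1e-9  # end time from the report
T = t1     # assumed time scale, so tau = t / T runs over [0, 1]


def f(t, y, args):
    # Right-hand side from the report: dy/dt = -y / t1
    return -y / t1


def rescale(f, T):
    """Hypothetical helper: vector field in tau = t / T, i.e. dy/dtau = T * f(T * tau, y, args)."""
    def g(tau, y, args):
        return T * f(T * tau, y, args)
    return g


def euler(vf, ts, y0, args=None):
    """Hypothetical helper: fixed-grid explicit Euler, y_{n+1} = y_n + h_n * vf(t_n, y_n, args)."""
    y = np.asarray(y0, dtype=np.float64)
    for ta, tb in zip(ts[:-1], ts[1:]):
        y = y + (tb - ta) * vf(ta, y, args)
    return y


ts = np.linspace(0.0, t1, 101)  # any increasing grid on [0, t1], e.g. the report's step_ts
taus = ts / T                   # the same grid in dimensionless time, on [0, 1]

y_phys = euler(f, ts, np.ones(1))
y_tau = euler(rescale(f, T), taus, np.ones(1))
print(y_phys, y_tau)  # equal up to floating-point rounding
```

To try this with diffrax, one would pass `rescale(f, T)` to `diffrax.ODETerm`, give `diffrax.StepTo` the rescaled grid `[t / T for t in step_ts]`, set `t0=0.0` and `t1=1.0`, and multiply the returned `sol.ts` by `T`. Whether this avoids the clipping described above is inferred from the `t1 = 1.0` observation, not verified.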