Odd behaviour of StepTo controller for small time steps #657

@aidancrilly

I have found some odd behaviour of the StepTo controller for small time steps. The result also depends on the floating-point precision.

Below is a minimal working example. In single precision, the time jumps straight to t1 after one time step and the final integration result is incorrect.

import diffrax
import jax.numpy as jnp
import jax

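t0 = 0.0
t1 = 1e-9

# Build the step grid: the step size doubles each step until it reaches dt_max.
dt = 2e-12
dt_max = 1e-11
step_t = t0
step_ts = [step_t]
while step_t < t1:
    step_t += dt
    step_ts.append(step_t)
    dt = min(2 * dt, dt_max)

# The loop overshoots t1, so make the final grid point land exactly on t1.
step_ts[-1] = t1
print(step_ts)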

stepsize_controller = diffrax.StepTo(ts=step_ts)

def f(t,y,args):
    jax.debug.print('{t}',t=t)
    return -y/t1

sol = diffrax.diffeqsolve(terms=diffrax.ODETerm(f),
                    solver=diffrax.Euler(),
                    t0=t0,
                    t1=t1,
                    dt0=None,
                    y0=jnp.ones(1),
                    stepsize_controller=stepsize_controller)

print(sol.ts,sol.ys)

If I run in double precision by enabling 64-bit mode:

from jax import config

config.update("jax_enable_x64", True)

Things are marginally better, but the jax.debug.print output still suggests that the final few times are clipped to t1.

If I change t1 to 1.0 and scale dt and dt_max accordingly, the correct behaviour is restored. However, I would like to run with time steps of this order (nanoseconds) for physics simulations. Cheers!
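For reference, the rescaling mentioned above can be sketched as follows. This is only a possible workaround under stated assumptions, not a fix for StepTo and not an explanation of the cause: time is expressed as the dimensionless variable tau = t / T1, so the solver would only ever see times of order one. The names T1, build_scaled_grid, f_scaled and euler_on_grid are hypothetical helpers introduced for this sketch, and the hand-written Euler loop merely stands in for the solver so the sketch runs without diffrax.

```python
import math

T1 = 1e-9  # physical end time in seconds, as in the example above


def build_scaled_grid(dt, dt_max, t_end=1.0):
    """Same doubling-then-capped step schedule, but in units of T1."""
    t = 0.0
    ts = [t]
    while t < t_end:
        t += dt
        ts.append(t)
        dt = min(2 * dt, dt_max)
    ts[-1] = t_end  # land exactly on the end point
    return ts


def f(t, y, args):
    """Original vector field, written in physical time."""
    return -y / T1


def f_scaled(tau, y, args):
    """Vector field in tau = t / T1; the chain rule gives dy/dtau = T1 * dy/dt."""
    return T1 * f(tau * T1, y, args)


def euler_on_grid(vf, ts, y0):
    """Plain explicit Euler over a fixed grid (stand-in for the solver)."""
    y = y0
    for t_prev, t_next in zip(ts[:-1], ts[1:]):
        y = y + (t_next - t_prev) * vf(t_prev, y, None)
    return y


# Step sizes are divided by T1 so the grid runs from 0 to 1.
scaled_ts = build_scaled_grid(dt=2e-12 / T1, dt_max=1e-11 / T1)
y_end = euler_on_grid(f_scaled, scaled_ts, 1.0)
print(len(scaled_ts), scaled_ts[-1], y_end, math.exp(-1.0))
```

With diffrax, the same idea would presumably mean passing scaled_ts to diffrax.StepTo, wrapping f_scaled in diffrax.ODETerm, integrating from t0=0.0 to t1=1.0, and multiplying sol.ts by T1 afterwards to recover physical times. That avoids the problem rather than resolving it, so the behaviour for nanosecond-scale t1 reported here still stands.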
