@jpbrodrick89
Contributor

This allows timesteps within 100 ULPs of t1 for ConstantStepSize, a tolerance that can be smaller than the previously fixed absolute tolerance of 1e-10 for double precision. This fixes #632 and #657, as discussed.

Using for _ in range(100) rather than multiplying by 100.0 takes about 7µs rather than 3µs, which would be undesirable if we had to perform the computation at every timestep. Fortunately, this is not necessary: we just compute t1_clip_floor at the beginning of _integrate.loop.
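For illustration, a minimal pure-Python sketch of stepping 100 ULPs below t1 in double precision. `clip_floor` is a hypothetical helper built on the standard library's `math.nextafter` (Python 3.9+); it is not the code in this PR, which works on JAX arrays.

```python
import math

def clip_floor(t1: float, n_ulps: int = 100) -> float:
    """Hypothetical helper: step t1 down by n_ulps representable doubles."""
    t = t1
    for _ in range(n_ulps):
        # Move to the next representable double towards -infinity.
        t = math.nextafter(t, -math.inf)
    return t

t1 = 1.0
gap = t1 - clip_floor(t1)  # width of the clipping window just below t1
print(gap, gap < 1e-10)
```

For t1 = 1.0 the window is 100 * 2**-53, several orders of magnitude tighter than the old absolute tolerance of 1e-10. The window scales with the magnitude of t1, so it shrinks further for small t1.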

I am ambivalent about whether a multiplier of 100 is optimal here. 1000 would be more than fine for double precision but would limit the total number of steps to about 1000 in single precision, which seems overly restrictive.

Happy to follow up with a ConstantStepSize PR once this is merged.

@jpbrodrick89
Contributor Author

Open to better ideas for error checking of traced values, but hopefully the possibility of having t0 and t1 is so vanishingly rare we don't have to test for it.

@patrick-kidger
Owner

Nice, this LGTM!

> I am ambivalent about whether a multiplier of 100 is optimal here. 1000 would be more than fine for double precision but would limit the total number of steps to about 1000 in single precision, which seems overly restrictive.

I'm not sure I understand the concern in single precision here?

> but hopefully the possibility of having t0 and t1 is so vanishingly rare we don't have to test for it.

I'm also not sure I understand this concern, can you expand?


CC @aidancrilly @varchasgopalaswamy could you both give this PR a go? It'd be good to know if this fixes your problems!

@jpbrodrick89
Contributor Author

Sorry, I got the single precision comment a bit wrong. 1 ULP is essentially an rtol of about 6e-8 (I was thinking 6e-6 as I misread my own comments), meaning that going with 1000 ULP would limit us to 16000 timesteps, which is probably as good as anyone could expect for single precision. Do you think I should bump up to 1000 ULPs?
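As a quick check of the "about 6e-8" figure, here is a small sketch that measures the float32 spacing on either side of 1.0 using only the standard library. The helper names `f32_to_bits` and `bits_to_f32` are made up for this example.

```python
import struct

def f32_to_bits(x: float) -> int:
    """Bit pattern of x rounded to float32, as an unsigned 32-bit integer."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

def bits_to_f32(b: int) -> float:
    """float32 value with bit pattern b (returned as a Python float)."""
    return struct.unpack("<f", struct.pack("<I", b))[0]

one = f32_to_bits(1.0)
# Adjacent positive floats have adjacent bit patterns.
ulp_below = 1.0 - bits_to_f32(one - 1)  # spacing just below 1.0: 2**-24
ulp_above = bits_to_f32(one + 1) - 1.0  # spacing just above 1.0: 2**-23
print(ulp_below, ulp_above)
```

The spacing just below 1.0 is 2**-24 (roughly 6e-8) and doubles to 2**-23 just above it, so "1 ULP as an rtol" varies by up to a factor of two depending on where the value sits within a power-of-two interval.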

My second comment was that if t1-t0 were within tolerance, then the clipped t1 would be on the other side of t0 (i.e. in most cases below) and the results would be very confusing. Just as it would have been nice to raise an error or warning if dt0 was less than 1e-10 with the old implementation, I thought I'd try warn here, but that doesn't seem possible. However, as our tolerance (at least with double precision) is incredibly low, I don't think it's really worth worrying about.

@patrick-kidger
Owner

> meaning that going with 1000 ULP would limit us to 16000 timesteps

Hmm, it sounds like you're computing (1 / 6e-8) / 1000, i.e. you're considering specifically a solve over the time interval [0, 1]? If so then bear in mind that (a) it's only the last 1000 ULPs that matter (so we need to subtract that, not divide by it) and that (b) 1 ULP is not a uniform size over the time interval. Brute-forcing the total number of ULPs between 0 and 1 (and entirely ignoring subnormal numbers):

import equinox.internal as eqxi
import jax
import jax.lax as lax
import jax.numpy as jnp

# Start from 1.0 (float32 unless JAX's 64-bit mode is enabled).
x = jnp.array(1.)

@jax.jit
def run(start):
    # Keep stepping down until the value reaches zero.
    def cond(carry):
        _, value = carry
        return value > 0
    # Move to the previous representable float and count the step.
    def body(carry):
        counter, value = carry
        return counter + 1, eqxi.prevbefore(value)
    counter, _ = lax.while_loop(cond, body, (0, start))
    return counter

print(run(x))  # 1056964609
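The brute-force count can be cross-checked without a loop. This is a sketch under two assumptions: the values are float32, and subnormals are flushed to zero so the step below the smallest normal number lands on 0. The helper name `f32_to_bits` is made up for this example.

```python
import struct

def f32_to_bits(x: float) -> int:
    """Bit pattern of x rounded to float32, as an unsigned 32-bit integer."""
    return struct.unpack("<I", struct.pack("<f", x))[0]

bits_one = f32_to_bits(1.0)   # 0x3F800000
bits_min_normal = 0x00800000  # smallest positive normal float32

# Adjacent positive floats have adjacent bit patterns, so the number of
# steps from 1.0 down to the smallest normal is the difference in patterns.
# One further step (assumed to flush straight to zero) ends the loop.
count = (bits_one - bits_min_normal) + 1
print(count)  # 1056964609
```

The agreement with the brute-force result suggests the loop never visits subnormal values, consistent with "entirely ignoring subnormal numbers".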

> I thought I'd try warn

Right, so if we did want this then I think this actually is possible. We could definitely precompute the t-clip value at the start of diffeqsolve, and then use an eqx.error_if to handle the case that (t0 > tclip) & (t0 == t1).

But as a practical matter I think I like the current approach which basically just enforces that we take a single step. (We could add a test for this case, perhaps? Not super important.)

@aidancrilly

Hi guys,

I can confirm that this fix works in my MWE in #657 as well as in the main code of interest.

Plotting out the results from MWE as a function of t1 from this PR vs 0.7.0 release:

[Two images: MWE results as a function of t1, this PR vs the 0.7.0 release]

Cheers Jonathan!

@jpbrodrick89
Contributor Author

jpbrodrick89 commented Jul 1, 2025

With regards to the warning: yes, I think the current behaviour (taking a single step to tclip) should be fine, so we can leave that.

With regards to the maximum number of steps, I meant maximum number of steps with a constant step size such that there is never any clipping. Clipping is a sometimes surprising effect when using a constant step size (although this might be avoided by the constant step size PR I'll draft after this), and it would be good to think carefully about when the surprise is warranted. To avoid clipping, dt0 must be larger than t1 - t1_clip_floor, giving the maximum number of such steps as (t1 - t0)/(t1 - t1_clip_floor), which would be approximately 160k for 100 ULPs and 16k for 1000 ULPs.
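The estimate above can be reproduced with a short sketch for single precision over [0, 1]. It assumes t1_clip_floor is t1 stepped down by a fixed number of float32 ULPs; `prev_f32` and `max_unclipped_steps` are hypothetical helpers, not functions from the library.

```python
import struct

def prev_f32(x: float, n: int) -> float:
    """x stepped down by n float32 ULPs (valid for positive normal x)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits - n))[0]

def max_unclipped_steps(t0: float, t1: float, n_ulps: int) -> float:
    """(t1 - t0) / (t1 - t1_clip_floor), the bound quoted above."""
    t1_clip_floor = prev_f32(t1, n_ulps)
    return (t1 - t0) / (t1 - t1_clip_floor)

steps_100 = max_unclipped_steps(0.0, 1.0, 100)    # about 1.68e5
steps_1000 = max_unclipped_steps(0.0, 1.0, 1000)  # about 1.68e4
print(steps_100, steps_1000)
```

Both values are 2**24 divided by the ULP count, because the float32 spacing just below 1.0 is 2**-24. Other choices of t0 and t1 give different bounds.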

I've plotted the maximum number of constant steps for 100 ULP's and 1000 ULP's as a function of t1-t0 for t0 = 0.0 and t0 = 1.0:

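[Two images: maximum number of constant steps for 100 and 1000 ULPs as a function of t1-t0, for t0 = 0.0 and t0 = 1.0]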

My gut feel is that expecting more than 10k steps with single precision without any surprises is asking too much and that we can probably bump to 1000 ULP's. WDYT?

@patrick-kidger
Owner

patrick-kidger commented Jul 6, 2025

> With regards to the maximum number of steps, I meant maximum number of steps with a constant step size such that there is never any clipping.

Ah, right, I get you now.

I think my feeling is that ideally we'd determine the number based on what gives us enough range to avoid instability in dense interpolation calculations. I could see 100 being too small but perhaps 1000 is too large?

Actually, that does make me now think: maybe the 'correct' fix would be to (a) adjust ConstantStepSize as discussed; (b) adjust the dense interpolation routines to be robust to short intervals; (c) remove clipping altogether? (This is certainly more work though.)

WDYT?

@jpbrodrick89
Contributor Author

Would StepTo also need a dedicated fix with that proposal?

Do you feel we are introducing any additional instability compared to the status quo with 1000 ULP's? If not, could we implement this (or 200/500 ULP's if you prefer) as a temporary fix and then work on the long term fix after that? As I said I would be happy to help with part (a).

@patrick-kidger
Owner

> Would StepTo also need a dedicated fix with that proposal?

I don't think so. This already steps to precise locations, so for this the clipping is honestly a bit weird.

> Do you feel we are introducing any additional instability compared to the status quo with 1000 ULP's? If not, could we implement this (or 200/500 ULP's if you prefer) as a temporary fix and then work on the long term fix after that?

I don't think so. I agree, let's merge this (with 100 ULPs) as it's certainly at least an improvement on what we had before! I've just triggered CI, to merge once it (hopefully) passes.

@jpbrodrick89
Contributor Author

Looks like the root finder no longer requires pyright ignores for some reason. I have removed the suggested ignore comments as pyright passes on my side. (Pre-commit did run fine, so I'm not sure whether this is because it's checking more files or due to an upstream change.)

@jpbrodrick89
Contributor Author

jpbrodrick89 commented Jul 15, 2025

I'm guessing the failing test is a compounded rounding error from the + dt with ConstantStepSize. Increasing max_steps by 1 or the number of ULPs to 500 DOES NOT remove the failure.

I'm going to see if I can push the constant step size fix first in a separate PR.

jpbrodrick89 changed the base branch from main to dev July 30, 2025 21:18
@jpbrodrick89
Contributor Author

jpbrodrick89 commented Jul 30, 2025

Changed to target dev and merged recent changes. Tests seem to pass except for a JAX bug. Hopefully good to go!

patrick-kidger merged commit 2709081 into patrick-kidger:dev Aug 2, 2025
1 of 2 checks passed
@patrick-kidger
Owner

Awesome! LGTM, and merged :D

patrick-kidger pushed a commit that referenced this pull request Aug 31, 2025
* Use 100 ULP's to clip timesteps close to t1

* test that t1-t0 > 100 ULP's

* revert testing as t1 is traced

* remove unnecessary pyright ignores
patrick-kidger pushed a commit that referenced this pull request Feb 1, 2026
* Use 100 ULP's to clip timesteps close to t1

* test that t1-t0 > 100 ULP's

* revert testing as t1 is traced

* remove unnecessary pyright ignores

Development

Successfully merging this pull request may close these issues.

Issue with small time steps

3 participants