Use 100 ULP's to clip timesteps close to t1 #660
|
Open to better ideas for error checking of traced values, but hopefully the possibility of having |
|
Nice, this LGTM!
I'm not sure I understand the concern in single precision here?
I'm also not sure I understand this concern, can you expand? CC @aidancrilly @varchasgopalaswamy could you both give this PR a go? It'd be good to know if this fixes your problems! |
|
Sorry, I got the single precision comment a bit wrong. 1 ULP is essentially an rtol of about 6e-8 (I was thinking 6e-6 as I misread my own comments), meaning that going with 1000 ULPs would limit us to about 16000 timesteps, which seems as good as anyone could expect for single precision. Do you think I should bump up to 1000 ULPs? My second comment was that if t1-t0 is within the tolerance, then the clipped t1 would be on the other side of t0 (i.e. in most cases below it) and the results would be very confusing. Just as it would have been nice to raise an error or warning when dt0 was less than 1e-10 with the old implementation, I thought I'd try to warn here, but that doesn't seem possible. However, as our tolerance (at least with double precision) is incredibly low, I don't think it's really worth worrying about.
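To make the arithmetic above concrete, here is a rough sketch of the step-count bound. It assumes t0 = 0 and that a constant step must stay larger than `n_ulps` ULPs of t1 to avoid clipping; the helper names `relative_ulp` and `max_unclipped_steps` are hypothetical and not part of the library.

```python
import numpy as np


def relative_ulp(dtype):
    """Gap between 1.0 and the next representable value below it."""
    one = dtype(1.0)
    return float(one - np.nextafter(one, dtype(0.0)))


def max_unclipped_steps(dtype, n_ulps):
    """Rough upper bound on the number of constant steps over [0, t1]
    before a single step becomes smaller than n_ulps ULPs of t1.
    Assumption: t0 = 0, so the bound is 1 / (n_ulps * relative ULP)."""
    return int(1.0 / (n_ulps * relative_ulp(dtype)))


for n_ulps in (100, 1000):
    print(n_ulps, max_unclipped_steps(np.float32, n_ulps),
          max_unclipped_steps(np.float64, n_ulps))
```

In single precision the relative ULP just below 1.0 is 2**-24 (about 6e-8), so 1000 ULPs gives a bound of roughly 16.8k steps and 100 ULPs roughly 168k, in line with the figures quoted above.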
Hmm, it sounds like you're computing

    import equinox.internal as eqxi
    import jax
    import jax.lax as lax
    import jax.numpy as jnp

    x = jnp.array(1.)

    @jax.jit
    def run(start):
        def cond(carry):
            _, value = carry
            return value > 0

        def body(carry):
            counter, value = carry
            return counter + 1, eqxi.prevbefore(value)

        counter, _ = lax.while_loop(cond, body, (0, start))
        return counter

    print(run(x))  # 1056964609

Right, so if we did want this then I think this actually is possible. We could definitely precompute the t-clip value at the start of But as a practical matter I think I like the current approach which basically just enforces that we take a single step. (We could add a test for this case, perhaps? Not super important.) |
|
Hi guys, I can confirm that this fix works in my MWE in #657 as well as in the main code of interest. Plotting out the results from MWE as a function of t1 from this PR vs 0.7.0 release: Cheers Jonathan! |
|
With regards to the warning, yes, I think the current behaviour should be fine (taking a single step to With regards to the maximum number of steps, I meant maximum number of steps with a constant step size such that there is never any clipping. Clipping is a sometimes surprising effect when using a constant step size (although this might be avoided by the constant step size PR I'll draft after this) and it would be good to think carefully when the surprise is warranted. In such a case, I've plotted the maximum number of constant steps for 100 ULP's and 1000 ULP's as a function of My gut feel is that expecting more than 10k steps with single precision without any surprises is asking too much and that we can probably bump to 1000 ULP's. WDYT? |
Ah, right, I get you now. I think my feeling is that ideally we'd determine the number based on what gives us enough range to avoid instability in dense interpolation calculations. I could see 100 being too small but perhaps 1000 is too large? Actually, that does make me now think: maybe the 'correct' fix would be to (a) adjust WDYT? |
|
Would Do you feel we are introducing any additional instability compared to the status quo with 1000 ULP's? If not, could we implement this (or 200/500 ULP's if you prefer) as a temporary fix and then work on the long term fix after that? As I said I would be happy to help with part (a). |
I don't think so. This already steps to precise locations, so for this the clipping is honestly a bit weird.
I don't think so. I agree, let's merge this (with 100 ULPs) as it's certainly at least an improvement on what we had before! I've just triggered CI, to merge once it (hopefully) passes. |
|
Looks like root finder no longer requires pyright ignores for some reason, I have removed the suggested ignore comments as pyright passes on my side. (Pre-commit did run fine so not sure whether this is because it's checking more files or due to an upstream change.) |
|
I'm guessing, the failing test is a compounded rounding error from the I'm going to see if I can push the constant step size fix first in a separate PR. |
|
Changed to target dev and merged recent changes. Tests seem to pass except for Jax bug. Hopefully good to go! |
|
Awesome! LGTM, and merged :D |
* Use 100 ULP's to clip timesteps close to t1
* test that t1-t0 > 100 ULP's
* revert testing as t1 is traced
* remove unnecessary pyright ignores

This allows timesteps within 100 ULP's of t1 for `ConstantStepSize`, which can be smaller than the previously fixed absolute tolerance of `1e-10` for double precision. This fixes #632 and #657 as discussed.

Using `for _ in range(100)` rather than multiplying by `100.0` takes about 7µs rather than 3µs, which would be undesirable if we had to perform the computation at every timestep. Fortunately, this is not necessary and we just compute `t1_clip_floor` at the beginning of `_integrate.loop`.

I am ambivalent about whether a multiplier of 100 is optimal here: 1000 would be more than fine for double precision but would limit the total number of steps to about 1000 in single precision, which seems overly restrictive.

Happy to follow up with a `ConstantStepSize` PR once this is merged.
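
For illustration only, the idea of stepping 100 ULPs below t1 once and then clipping against that floor might look roughly like the NumPy sketch below. The name `t1_clip_floor` comes from the description above; the helpers `clip_floor` and `clip_step` are hypothetical, and the way the floor is applied to a proposed step is an assumption rather than the code in this PR.

```python
import numpy as np


def clip_floor(t1, n_ulps=100):
    """Walk n_ulps representable values below t1 (hypothetical helper).
    Computed once up front, so the loop cost is not paid per timestep."""
    floor = t1
    lower = t1.dtype.type(-np.inf)
    for _ in range(n_ulps):
        floor = np.nextafter(floor, lower)
    return floor


def clip_step(tnext, t1, t1_clip_floor):
    """Assumed usage: snap a proposed step end onto t1 when it lands
    above the precomputed floor, otherwise leave it unchanged."""
    return np.where(tnext > t1_clip_floor, t1, tnext)


t1 = np.float32(1.0)
t1_clip_floor = clip_floor(t1)
print(t1_clip_floor, clip_step(np.float32(0.9999999), t1, t1_clip_floor))
```

Because the floor is defined in ULPs of t1 rather than as an absolute `1e-10`, it scales with both the magnitude of t1 and the floating-point precision in use.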