Update JVP rule for abs to fix behavior for complex infinite inputs #26086
Conversation
Note that the parameters to the test function are not being used. Otherwise, looks good to me! Thanks, @dfm !
"Note that this isn't quite the same result as what @pearu reports in #25681, but it's what I found when working it through myself, and it has the correct behavior."
I believe we still get the same results. The difference is only between `atan2(y, x)` (this PR) and `atan2(x, y)` (unconventional usage, per my comment in the issue), which leads to a different final formula, but the results ought to be the same.
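To make this concrete, here's a small check (mine, not from the thread) of the identity behind the two conventions: swapping the arguments of atan2 maps the angle θ to π/2 − θ (up to a multiple of 2π), so cos and sin simply trade places:

```python
import numpy as np
import jax.numpy as jnp

rng = np.random.default_rng(0)
x = jnp.array(rng.normal(0, 10, 1000), dtype='float32')
y = jnp.array(rng.normal(0, 10, 1000), dtype='float32')

# cos(atan2(y, x)) == x / hypot(x, y) == sin(atan2(x, y)), and likewise
# with cos and sin exchanged, so either argument order yields the same
# JVP values in the end.
assert jnp.allclose(jnp.cos(jnp.arctan2(y, x)), jnp.sin(jnp.arctan2(x, y)))
assert jnp.allclose(jnp.sin(jnp.arctan2(y, x)), jnp.cos(jnp.arctan2(x, y)))
```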
tests/lax_numpy_test.py (Outdated)

```python
x = jax.lax.complex(jnp.inf, 0.0).astype(dtype)
expected = jax.lax.complex(1.0, 0.0).astype(dtype)
```
Should this read

```python
x = input_parts
expected = grad_parts
```

or similar?
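For illustration, a hedged sketch of what wiring the parameters through might look like; the class name, decorator, and parameter values here are my own scaffolding around the suggestion, not the PR's actual test:

```python
import jax
import jax.numpy as jnp
from absl.testing import parameterized

class AbsGradTest(parameterized.TestCase):

  @parameterized.parameters(
      dict(input_parts=(jnp.inf, 0.0), grad_parts=(1.0, 0.0)),
  )
  def test_abs_grad_at_complex_infinity(self, input_parts, grad_parts,
                                        dtype=jnp.complex64):
    # Build the test values from the parameters instead of hard-coding
    # them, which is the point of the review comment above.
    x = jax.lax.complex(*input_parts).astype(dtype)
    expected = jax.lax.complex(*grad_parts).astype(dtype)
    self.assertEqual(jax.grad(jnp.abs)(x), expected)
```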
Doh! Good catch. Thank you!
I worry a bit about the performance hit here. I wonder if we could replace …

I also wonder if using multiple sinusoidal evaluations to essentially recover …
Sorry - I should have been clearer! What you're suggesting here is actually what the existing implementation does, and it causes the problems seen in #25681 because at infinite inputs it produces nans. I couldn't think of any other ways to get numerically stable gradients in the limits without this change, unless we special cased for infinite inputs. But, in that case, I think we'd need to add quite a few cases for all the possible permutations...
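For anyone skimming, a minimal illustration (mine) of the failure mode being described, assuming the existing rule divides by the absolute value:

```python
import numpy as np

x, y = np.inf, 0.0
r = np.hypot(x, y)               # inf
print(x / r)                     # nan: the division-based factor x / |z| hits inf / inf
print(np.cos(np.arctan2(y, x)))  # 1.0: the trig formulation stays finite
```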
I see – that makes sense. I still worry about potential accuracy issues, especially since we're computing in float32 most of the time, and trig rounding errors could compound. What if we use a …

I think the …
Just a quick accuracy check:

```python
import numpy as np
import jax.numpy as jnp

rng = np.random.default_rng(0)
x = jnp.array(rng.normal(0, 10, 10000), dtype='float32')
y = jnp.array(rng.normal(0, 10, 10000), dtype='float32')

# The two candidate formulations, evaluated in float32:
val_trig = jnp.cos(jnp.arctan2(y, x))
val_quad = x / jnp.hypot(x, y)

# Reference values computed in float64:
x = np.array(x, dtype='float64')
y = np.array(y, dtype='float64')
val_true = np.cos(np.arctan2(y, x))

print("trig approach: max rtol=", max(abs(val_trig - val_true) / abs(val_true)))
print("quadratic approach: max rtol=", max(abs(val_quad - val_true) / abs(val_true)))
```
The relative accuracy degrades from …
Good point about the accuracy, @jakevdp! It's worth noting that this degradation happens only close to the origin, so one option would be to switch when either the real or imag part of the input passes some minimum threshold. But, I'll also take a look at explicitly special casing the infinities. @pearu's point about symmetries is a good one. I'll give this another go next week. Thanks both!!
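A hedged sketch of what that switch could look like; the threshold value, the branch assignment, and the `jnp.where`-based plumbing are all my assumptions, not code from this PR:

```python
import jax.numpy as jnp

def cos_angle(x, y, tiny=1e-3):
  # Use the trig form where the plain division misbehaves: near the
  # origin (0 / 0 -> nan) and at infinities (inf / inf -> nan); use the
  # cheaper x / hypot(x, y) form everywhere else.
  use_trig = ((jnp.abs(x) < tiny) & (jnp.abs(y) < tiny)) | ~(
      jnp.isfinite(x) & jnp.isfinite(y))
  # Guard the denominator: jnp.where evaluates both branches, so the
  # untaken division must not itself produce a nan.
  safe_r = jnp.where(use_trig, 1.0, jnp.hypot(x, y))
  return jnp.where(use_trig, jnp.cos(jnp.arctan2(y, x)), x / safe_r)
```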
Another question worth thinking about: how does this affect the second derivative at 0 and infinity? Does the trig version result in correct higher-order derivatives at these values?
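One way to probe that question empirically (my sketch; the expected values aren't settled in this thread, so this only shows the check, not the answers):

```python
import jax
import jax.numpy as jnp

def grad_abs_real(z):
  # Real part of the (non-holomorphic) gradient of abs at a complex point.
  return jax.grad(jnp.abs)(z).real

d2 = jax.grad(grad_abs_real)
for z in [jnp.complex64(0.0), jnp.complex64(jnp.inf)]:
  print(z, d2(z))
```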
As discussed in #25681, the gradients of `abs` don't have the correct behavior at complex infinities. As discussed in that issue, and using the notation from here, the JVP rule can be re-written as:

$$\mathrm{jvp}(\mathrm{abs})(z; t) = \Re(t)\,\cos(\operatorname{atan2}(\Im z, \Re z)) + \Im(t)\,\sin(\operatorname{atan2}(\Im z, \Re z))$$

(Note that this isn't quite the same result as what @pearu reports in #25681, but it's what I found when working it through myself, and it has the correct behavior.)
This is straightforward to implement, and at the cost of a performance hit (3 new trig functions), we get stable JVPs throughout the complex plane. I don't expect this is a performance-critical computation in many applications, so I think it's probably worth updating the implementation here, but I'd love to hear otherwise if people disagree!
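As a standalone sketch of the reformulation (using `jax.custom_jvp` to stand in for the actual edit to the `abs` primitive's JVP rule inside JAX; the names and plumbing here are illustrative, not the PR's diff):

```python
import jax
import jax.numpy as jnp
from jax import custom_jvp, lax

@custom_jvp
def my_abs(z):
  return jnp.abs(z)

@my_abs.defjvp
def my_abs_jvp(primals, tangents):
  (z,), (t,) = primals, tangents
  # atan2, cos, and sin are the "3 new trig functions" mentioned above.
  theta = lax.atan2(z.imag, z.real)
  # d|z| = Re(t) cos(theta) + Im(t) sin(theta): finite at infinities,
  # unlike the division-based Re(conj(z) * t) / |z|.
  return jnp.abs(z), t.real * jnp.cos(theta) + t.imag * jnp.sin(theta)

# At z = inf + 0j the old form hits inf / inf = nan; this one gives 1.0:
_, tangent = jax.jvp(my_abs, (jnp.complex64(jnp.inf),), (jnp.complex64(1.0),))
print(tangent)  # 1.0
```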
I should note that this also changes the gradient at complex zero (as discussed in #10515 (comment)) to give `grad(abs)(0+0j) = 1+0j`. I think this is sensible behavior, but there's some chance that this breaks some downstream behavior. (@mattjj may want to comment, having thought about this before!)
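For completeness, a tiny check (mine, not from the PR) of the zero-input arithmetic behind this claim:

```python
import jax.numpy as jnp

# Under the trig formulation, theta = atan2(0, 0) = 0, so the real part of
# the gradient at 0 + 0j is cos(0) = 1 and the imaginary part is sin(0) = 0,
# i.e. grad(abs)(0+0j) = 1+0j. The old division-based form gives 0 / 0 = nan.
print(jnp.cos(jnp.arctan2(0.0, 0.0)), jnp.sin(jnp.arctan2(0.0, 0.0)))  # 1.0 0.0
print(jnp.float32(0.0) / jnp.abs(jnp.complex64(0.0)))                  # nan
```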