
Update JVP rule for abs to fix behavior for complex infinite inputs #26086

Open · wants to merge 1 commit into base: main

Conversation

@dfm (Collaborator) commented Jan 24, 2025

As discussed in #25681, the gradients of abs don't have the correct behavior at complex infinities. Combining the discussion in that issue with the notation from here, the JVP rule can be re-written as:

```
f(z) = abs(z) = abs(x + j y)
t = atan2(y, x)
df = Re[(cos(t) - j sin(t)) * (dx + j dy)]
```

(Note that this isn't quite the same result as what @pearu reports in #25681, but it's what I found when working it through myself, and it has the correct behavior.)
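For concreteness, the rule above can be sketched outside of JAX (a NumPy illustration with a hypothetical helper name, not the actual implementation in lax):

```python
import numpy as np

def abs_jvp(z, dz):
    # t = atan2(y, x); df = Re[(cos(t) - j sin(t)) * dz]
    t = np.arctan2(np.imag(z), np.real(z))
    return np.cos(t) * np.real(dz) + np.sin(t) * np.imag(dz)

# Finite at a complex infinity, where x / |z| would evaluate to inf / inf:
print(abs_jvp(complex(np.inf, 0.0), 1.0 + 0.0j))  # -> 1.0
```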

This is straightforward to implement, and at the cost of a performance hit (3 new trig functions), we get stable JVPs throughout the complex plane. I don't expect this is a performance critical computation in many applications, so I think it's probably worth updating the implementation here, but I'd love to hear otherwise if people disagree!

I should note that this also changes the gradient at complex zero (as discussed in #10515 (comment)) to give grad(abs)(0+0j) = 1+0j. I think this is sensible behavior, but there's some chance that this breaks some downstream behavior. (@mattjj may want to comment, having thought about this before!)

@dfm dfm self-assigned this Jan 24, 2025
@dfm dfm requested review from pearu and jakevdp January 24, 2025 18:06
@dfm dfm added the pull ready Ready for copybara import and testing label Jan 24, 2025
@pearu (Collaborator) left a comment

Note that the parameters to the test function are not being used. Otherwise, looks good to me! Thanks, @dfm !

"Note that this isn't quite the same result as what @pearu reports in #25681, but it's what I found when working it through myself, and it has the correct behavior."

I believe we still get the same results. The only difference is atan2(y, x) (this PR) versus atan2(x, y) (the unconventional usage in my comment on the issue); that changes the form of the final formula, but the results ought to be the same.
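The equivalence is easy to check numerically (a quick illustration, not part of the PR): cos(atan2(y, x)) and sin(atan2(x, y)) both reduce to x / hypot(x, y), so the two formulations agree despite the swapped arguments.

```python
import numpy as np

# cos(atan2(y, x)) and sin(atan2(x, y)) both reduce to x / hypot(x, y),
# so the two formulations agree in every quadrant.
for x, y in [(3.0, 4.0), (-3.0, 4.0), (3.0, -4.0), (-1.0, -1.0)]:
    assert np.isclose(np.cos(np.arctan2(y, x)), np.sin(np.arctan2(x, y)))
    assert np.isclose(np.cos(np.arctan2(y, x)), x / np.hypot(x, y))
```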

Comment on lines 6307 to 6308:

```python
x = jax.lax.complex(jnp.inf, 0.0).astype(dtype)
expected = jax.lax.complex(1.0, 0.0).astype(dtype)
```

Should this read

Suggested change:

```python
x = input_parts
expected = grad_parts
```

or similar?

@dfm (Collaborator, Author) replied:

Doh! Good catch. Thank you!

@dfm dfm force-pushed the abs-complex-grad branch from 786c82e to 3bb69bd Compare January 24, 2025 19:32
@jakevdp (Collaborator) commented Jan 24, 2025

I worry a bit about the performance hit here. I wonder if we could replace cos(atan2(y, x)) with x / sqrt(x ** 2 + y ** 2) and sin(atan2(y, x)) with y / sqrt(x ** 2 + y ** 2) and use custom_jvp or something similar to handle the (0, 0) case?

@jakevdp (Collaborator) commented Jan 24, 2025

I also wonder if using multiple sinusoidal evaluations to essentially recover abs(x) at x != 0 would cause accuracy issues in some domains.

@dfm (Collaborator, Author) commented Jan 24, 2025

"I wonder if we could replace cos(atan2(y, x)) with x / sqrt(x ** 2 + y ** 2) and sin(atan2(y, x)) with y / sqrt(x ** 2 + y ** 2)"

Sorry - I should have been clearer! What you're suggesting here is actually what the existing implementation does, and it causes the problems seen in #25681 because at x+iy = inf + i0 you'll get x / sqrt(x ** 2 + y ** 2) = inf / inf, even though the JVP is well defined. The behavior at x+iy = 0 was not the target of this PR, just a side effect!

I couldn't think of any other way to get numerically stable gradients in these limits without this change, unless we special-cased infinite inputs. But in that case, I think we'd need to add quite a few branches to cover all the possible permutations...
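The failure mode is easy to reproduce with NumPy (a standalone repro, not the lax code itself):

```python
import numpy as np

x, y = np.float64(np.inf), np.float64(0.0)

# Existing quotient form: inf / inf -> nan, even though the limit is 1.
print(x / np.sqrt(x**2 + y**2))  # nan

# Trig form from this PR: atan2(0, inf) = 0, so cos(...) gives 1 directly.
print(np.cos(np.arctan2(y, x)))  # 1.0
```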

@dfm dfm changed the title Update JVP rule for abs Update JVP rule for abs to fix behavior for complex infinite inputs Jan 24, 2025
@jakevdp (Collaborator) commented Jan 24, 2025

I see – that makes sense.

I still worry about potential accuracy issues, especially since we're computing in float32 most of the time, and trig rounding errors could compound.

What if we use a lax.select that chooses between the two approaches depending on the domain of the inputs?
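The select idea could look something like the following sketch (hypothetical helper name, with np.where standing in for lax.select; the (0, 0) case would still need separate handling):

```python
import numpy as np

def cos_theta_select(x, y):
    # Hypothetical hybrid: quotient form on the finite domain, trig form
    # only where an input is infinite. The safe_* values keep nans out of
    # the branch that gets discarded.
    finite = np.isfinite(x) & np.isfinite(y)
    safe_x = np.where(finite, x, 1.0)
    safe_y = np.where(finite, y, 0.0)
    quad = safe_x / np.hypot(safe_x, safe_y)
    trig = np.cos(np.arctan2(y, x))
    return np.where(finite, quad, trig)
```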

@pearu (Collaborator) commented Jan 24, 2025

I think the select alternative to atan2/sin/cos could be reasonable: when either real(z) or imag(z) is infinite, the result depends only on the signs of the real and imaginary parts, so the corresponding values can be tabulated. There are 8 values (corresponding to the 8 infinity cases) plus one for the finite case, for each of the real and imaginary parts of the result; so in total there would be 16 select expressions. Most likely some of these expressions could be combined using symmetries.

Benchmarks would tell which approach is going to be better performance-wise.
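A rough sketch of the sign-tabulation idea (hypothetical helper, ignoring the symmetry merging mentioned above; the origin would still need separate handling):

```python
import numpy as np

def unit_direction(x, y):
    # When either part is infinite, only the signs matter:
    #   (+/-inf, finite) -> (+/-1, 0)
    #   (finite, +/-inf) -> (0, +/-1)
    #   (+/-inf, +/-inf) -> (+/-1, +/-1) / sqrt(2)
    xi, yi = np.isinf(x), np.isinf(y)
    sx, sy = np.sign(x), np.sign(y)
    # Replace infinite inputs with finite placeholders so the quotient
    # branch never produces inf / inf.
    fx = np.where(xi | yi, 1.0, x)
    fy = np.where(xi | yi, 0.0, y)
    r = np.hypot(fx, fy)
    c = np.where(xi & yi, sx / np.sqrt(2),
                 np.where(xi, sx, np.where(yi, 0.0, fx / r)))
    s = np.where(xi & yi, sy / np.sqrt(2),
                 np.where(yi, sy, np.where(xi, 0.0, fy / r)))
    return c, s
```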

@jakevdp (Collaborator) commented Jan 24, 2025

Just a quick accuracy check:

```python
import numpy as np
import jax.numpy as jnp

rng = np.random.default_rng(0)

x = jnp.array(rng.normal(0, 10, 10000), dtype='float32')
y = jnp.array(rng.normal(0, 10, 10000), dtype='float32')

val_trig = jnp.cos(jnp.arctan2(y, x))
val_quad = x / jnp.hypot(x, y)

x = np.array(x, dtype='float64')
y = np.array(y, dtype='float64')
val_true = np.cos(np.arctan2(y, x))

print("trig approach:      max rtol=", max(abs(val_trig - val_true) / val_true))
print("quadratic approach: max rtol=", max(abs(val_quad - val_true) / val_true))
```

```
trig approach:      max rtol= 0.0002238464
quadratic approach: max rtol= 1.9903023e-07
```

The relative accuracy degrades from 2E-7 to 2E-4. I think that's bad enough that we'll want to avoid using the trig approach alone across the whole domain.

@dfm (Collaborator, Author) commented Jan 24, 2025

Good point about the accuracy, @jakevdp! It's worth noting that this degradation happens only close to the origin, so one option would be to switch when either the real or imag part of the input passes some minimum threshold. But, I'll also take a look at explicitly special casing the infinities. @pearu's point about symmetries is a good one. I'll give this another go next week. Thanks both!!
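The threshold switch could be sketched like this (hypothetical helper; the cutoff value is purely illustrative, and a real implementation would presumably need a dtype-dependent choice):

```python
import numpy as np

def cos_theta_thresh(x, y, cutoff=1e-3):
    # Hypothetical: quotient form away from the origin and infinities,
    # trig form otherwise. The safe_r placeholder keeps the discarded
    # branch free of divide-by-zero.
    r = np.hypot(x, y)
    use_quad = np.isfinite(r) & (r > cutoff)
    safe_r = np.where(use_quad, r, 1.0)
    return np.where(use_quad, x / safe_r, np.cos(np.arctan2(y, x)))
```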

@jakevdp (Collaborator) commented Jan 24, 2025

Another question worth thinking about: how does this affect the second derivative at 0 and infinity? Does the trig version result in correct higher-order derivatives at these values?
