Correct chainrules for abs2, abs, conj and angle #196
Seth Axen pointed out that the Zygote chainrules PR does this: https://github.com/FluxML/Zygote.jl/blob/bf913a2a8ed616242e2f5378fbe598b289dd550a/src/lib/number.jl#L26-L30 to get correct answers. I think this is a reasonable way to go about it rather than using Wirtinger definitions. |
Looks good to me.
I think these functions are a case where for Complex we need to define frule and rrule explicitly instead of using @scalar_rule. The adjoint rules @oxinabox defined in that PR look right to me:
# we intentionally define these here rather than falling back on ChainRules.jl
# because ChainRules doesn't really handle nonanalytic complex functions
@adjoint abs(x::Real) = abs(x), Δ -> (real(Δ)*sign(x),)
@adjoint abs(x::Complex) = abs(x), Δ -> (real(Δ)*x/abs(x),)
@adjoint abs2(x::Number) = abs2(x), Δ -> (real(Δ)*(x + x),)
However, I'm pretty sure the frules for abs(x::Complex) and abs2(x::Complex) are
function frule((Δx,), abs2, x::ComplexF64)
return abs2(x), 2 * (real(x) * real(Δx) + imag(x) * imag(Δx))
end
function frule((Δx,), abs, x::ComplexF64)
Ω = abs(x)
return Ω, (real(x) * real(Δx) + imag(x) * imag(Δx)) / Ω
end
(confirmed by finite differences), and I don't see a good way to generate both these frules and the rrules from a single scalar rule.
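For readers who want to reproduce the finite-difference check mentioned above, here is a minimal sketch in plain Base Julia. It does not use the ChainRules API; the helper names (pushforward_abs2, pushforward_abs, pullback_abs2, pullback_abs, fd) are hypothetical and only restate the formulas quoted in this comment. It also checks that the quoted pullbacks and these pushforwards appear to be adjoint under the real inner product real(conj(a) * b).

```julia
# Hypothetical helper names; plain Base Julia, no ChainRules dependency.

# Pushforwards (directional derivatives), as written in the comment above.
pushforward_abs2(x::Complex, Δx::Complex) =
    2 * (real(x) * real(Δx) + imag(x) * imag(Δx))
pushforward_abs(x::Complex, Δx::Complex) =
    (real(x) * real(Δx) + imag(x) * imag(Δx)) / abs(x)

# Pullbacks, as in the Zygote snippet quoted above.
pullback_abs2(x::Complex, Δ) = real(Δ) * (x + x)
pullback_abs(x::Complex, Δ) = real(Δ) * x / abs(x)

# Central finite difference of f at x along the direction Δx.
fd(f, x, Δx; h=1e-6) = (f(x + h * Δx) - f(x - h * Δx)) / (2h)

x, Δx, Δ = 1.5 - 0.5im, 0.3 + 2.0im, 0.7

# The pushforwards agree with finite differences.
@assert isapprox(pushforward_abs2(x, Δx), fd(abs2, x, Δx); atol=1e-8)
@assert isapprox(pushforward_abs(x, Δx), fd(abs, x, Δx); atol=1e-8)

# Adjoint consistency between pullback and pushforward.
@assert isapprox(real(conj(pullback_abs2(x, Δ)) * Δx), Δ * pushforward_abs2(x, Δx))
@assert isapprox(real(conj(pullback_abs(x, Δ)) * Δx), Δ * pushforward_abs(x, Δx))
```

The test point and step size h are arbitrary choices; abs is not differentiable at zero, so the check is only meaningful away from x = 0.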
Co-authored-by: Seth Axen <seth.axen@gmail.com>
Could you add a comment on what it returns for the complex case?
Given that ChainRules.jl does not support complex rules, maybe it is better to remove or comment out the complex rules.
This is incorrect. We definitely support complex rules. |
I don't see |
We just don't use Wirtingers. There's nothing stopping you from writing rules that work with complex numbers though.
But are |
I don't believe so, but that's not the point. I'm not arguing that our |
Does anyone know what the corresponding |
They should be the ones you linked to above in Zygote. Assuming Zygote's conjugation conventions of course. |
I don't understand how you can define complex rules without using a 2x2 matrix or Wirtinger derivatives if the function is not complex analytic.
I believe all of these rules in ChainRules assume that the tangents and cotangents are derivatives of the primal with respect to a real scalar or a real scalar with respect to a primal, respectively. Or equivalently, which is why no Wirtinger is needed. If complex differentiation is what is needed, you just call the pushforward/pullback twice to fill the Jacobian. This is basically what Zygote does: https://fluxml.ai/Zygote.jl/latest/complex/ |
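To illustrate the "call the pushforward twice to fill the Jacobian" idea, here is a small sketch in plain Base Julia. It is not Zygote's or ChainRules' implementation; real_jacobian is a hypothetical helper, and pushforward_abs2 restates the abs2 formula quoted earlier in this thread. The function is viewed as a map from (real(x), imag(x)) to (real(f), imag(f)).

```julia
# Pushforward of abs2 at x along Δx (formula quoted earlier in this thread).
pushforward_abs2(x::Complex, Δx::Complex) =
    2 * (real(x) * real(Δx) + imag(x) * imag(Δx))

# Hypothetical helper: assemble the 2x2 real Jacobian of a map C -> C
# from two pushforward calls, one per real input direction.
function real_jacobian(pushforward, x::Complex)
    col_re = complex(pushforward(x, one(x)))       # derivative w.r.t. real(x)
    col_im = complex(pushforward(x, one(x) * im))  # derivative w.r.t. imag(x)
    return [real(col_re) real(col_im);
            imag(col_re) imag(col_im)]
end

x = 3.0 + 4.0im
J = real_jacobian(pushforward_abs2, x)
@assert J == [6.0 8.0; 0.0 0.0]
```

The second row is zero because abs2 is real-valued. A single pushforward call recovers only one column of this matrix, which is why one call alone does not describe a non-holomorphic function completely.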
I see. That makes sense. But then how can we warn the user that the function is not analytic? I don't think silently giving the wrong answer is a good idea. |
It's not the wrong answer. It's just that naively asking for It's like how in D(f)(v) * [1, 0] you wouldn't know the full Jacobian. You need In this case, the whole Jacobian can be obtained from It would however be a good idea to make this thing more clear in the docs somehow. |
It's not a documentation issue. The information that the function is not holomorphic is never forwarded. So an AD system doesn't know it needs special handling for the non-holomorphic case. Though, I definitely like this approach of handling complex AD. |
That's fair. In this approach, you need to check if the function is holomorphic by checking the Cauchy-Riemann equations. And it'll be a bit wasteful (although you can reuse the pullback when not mutating). But on the upside, the rules are simpler, and I don't think there's anything preventing future implementation of Wirtinger derivatives or something equivalent. In case you didn't see it, this discussion is relevant: https://discourse.julialang.org/t/taking-complex-autodiff-seriously-in-chainrules |
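A rough sketch of what such a Cauchy-Riemann check could look like, in plain Base Julia. The helper satisfies_cauchy_riemann and both pushforward functions are hypothetical names written for this illustration, not part of ChainRules. The check uses the fact that the Cauchy-Riemann equations hold at a point exactly when the derivative along im equals im times the derivative along 1, so it costs two pushforward calls.

```julia
# Pushforwards along a complex direction Δx.
pushforward_square(x, Δx) = 2 * x * Δx   # f(x) = x^2, holomorphic
pushforward_abs2(x, Δx) =                # f(x) = abs2(x), not holomorphic
    2 * (real(x) * real(Δx) + imag(x) * imag(Δx))

# Hypothetical helper: test the Cauchy-Riemann equations at the point x by
# comparing the derivative along `im` with `im` times the derivative along 1.
function satisfies_cauchy_riemann(pushforward, x; atol=1e-12)
    d_re = pushforward(x, one(x))
    d_im = pushforward(x, one(x) * im)
    return isapprox(d_im, im * d_re; atol=atol)
end

x = 3.0 + 4.0im
@assert satisfies_cauchy_riemann(pushforward_square, x)
@assert !satisfies_cauchy_riemann(pushforward_abs2, x)
```

This only tests a single point. A function can satisfy the equations at isolated points without being holomorphic (abs2 does so at x = 0), so a practical check would need more than one sample.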
👍
Co-authored-by: Nick Robinson <npr251@gmail.com>
Co-authored-by: Seth Axen <seth.axen@gmail.com>
Okay, so this PR now adopts the subgradient convention where in the situations that might cause functions like the gradient of It also has changed to the point of view for |
Co-authored-by: Simon Etter <ettersi@users.noreply.github.com>
LGTM! Can you increment the version number to v0.7.0-DEV? I think we can merge this soon but hold on a release until the coming PR with compatibility for ChainRulesCore v0.9 is merged.
Closes #195.
Pending some more thoughts in https://discourse.julialang.org/t/taking-complex-autodiff-seriously-in-chainrules/39317/49 and / or an issue here, we should consider adding something along the lines of
Current state of the PR is described here: #196 (comment)