Use conjugates in optimizers to better learn on complex-valued inputs #1776
Conversation
When weights are complex, the deltas to them will also be complex. In all optimizers that need a second-order estimate of gradient statistics, we generally want to use the `x * conj(x)` pattern, rather than `x^2`.
This improves the calculation of error for complex-valued targets.

X-ref: pytorch/pytorch#59998
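To make the motivation concrete, here is a small illustrative snippet (not code from this PR) showing what goes wrong with `Δ^2` for a complex gradient, and what the `Δ * conj(Δ)` pattern computes instead:

```julia
# Illustration only: second-moment accumulation for a single complex gradient value.
Δ = 3.0 - 4.0im

Δ^2           # -7.0 - 24.0im : not a magnitude at all, and can point anywhere
Δ * conj(Δ)   # 25.0 + 0.0im  : |Δ|², which is what the second-moment estimate needs

# RMSProp-style accumulator update with the corrected pattern:
ρ   = 0.9
acc = zero(Δ)
acc = ρ * acc + (1 - ρ) * Δ * conj(Δ)   # 2.5 + 0.0im after one step
```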
There was a discussion about exactly this some weeks ago, but for the life of me I can't remember where or on what platform. @darsnack do you recall this?
Looks good to me, thanks for this.
@ToucheSir Perhaps this Zulip stream. But I recall a longer discussion linking to the PyTorch issue above. I feel like it was #autodiff or #math-optimization on Slack, but those messages are long gone.
Is there a good test we could add here?
This is great, thanks @staticfloat! A good test might be to train with complex-valued weights and inputs and check for smoothness. We might need to do a similar pass over the other optimisers as well.
@@ -141,7 +141,7 @@ RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
 function apply!(o::RMSProp, x, Δ)
   η, ρ = o.eta, o.rho
   acc = get!(() -> zero(x), o.acc, x)::typeof(x)
-  @. acc = ρ * acc + (1 - ρ) * Δ^2
+  @. acc = ρ * acc + (1 - ρ) * Δ * conj(Δ)
Isn't `Δ * conj(Δ)` the same as `abs2(Δ)`?
No, they are not the same: `abs2` produces a real value, while `Δ * conj(Δ)` produces a complex type with zero imaginary part.
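For concreteness, a quick REPL check of the two forms (illustrative only):

```julia
julia> z = 1.0 + 2.0im;

julia> abs2(z)                    # real result
5.0

julia> z * conj(z)                # complex result with zero imaginary part
5.0 + 0.0im

julia> typeof(abs2(z)), typeof(z * conj(z))
(Float64, ComplexF64)
```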
Yeah, I didn't look into these deeply enough to figure out whether returning a real value would cause type instability or whatnot, so I just left it in this form.
Should be fine to use `abs2`, I think. These are all writing into existing arrays like `acc`, so that fixes the overall type. Surely not a huge effect though; I don't think these broadcasts will ever be seen by Zygote, for instance.
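A tiny sketch of that point (illustrative, not from the PR): broadcasting the real result of `abs2` into an existing complex accumulator simply converts it on assignment, so the array's element type is unchanged:

```julia
# Hypothetical accumulator update; `acc` plays the role of the optimiser state.
acc = zeros(ComplexF64, 3)
Δ   = [1.0 + 2.0im, 3.0 - 1.0im, 0.5im]
ρ   = 0.9

@. acc = ρ * acc + (1 - ρ) * abs2(Δ)   # real values are converted on assignment

eltype(acc)   # ComplexF64 -- the destination array fixes the overall type
```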
Isn't `x * conj(x)` slower (albeit slightly) than `abs2(x)`?
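If the difference ever mattered, a quick micro-benchmark along these lines would settle it (illustrative sketch; results depend on the machine and array size):

```julia
using BenchmarkTools

z   = randn(ComplexF64, 10_000)
out = similar(z)                  # complex destination, like the optimiser's acc

@btime $out .= $z .* conj.($z);   # x * conj(x) form
@btime $out .= abs2.($z);         # abs2 form, real result converted on assignment
```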
The branch was force-pushed from 42bf0dd to 8c3d852.
Added some tests, and confirmed that this makes a big difference when training actual models.
lgtm. We need to be careful to preserve the behaviour for complex models going forward, and to review other losses for the same treatment as `mse`. Not a huge fan of the increased complexity in `mse`, but we definitely need to be correct first.
bors r+
@@ -44,7 +44,8 @@ julia> Flux.mse(y_model, y_true)
 """
 function mse(ŷ, y; agg = mean)
   _check_sizes(ŷ, y)
-  agg((ŷ .- y) .^ 2)
+  error = ŷ .- y
+  real(agg(error .* conj(error)))
@DhairyaLGandhi We can simplify this by using `agg(abs2.(ŷ .- y))` instead, as `abs2()` (as noted elsewhere) always returns a `Real`. Would you prefer that?
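For reference, the simplified definition would look roughly like this (a sketch based on the suggestion above, assuming the same `_check_sizes` helper and `agg` keyword):

```julia
function mse(ŷ, y; agg = mean)
    _check_sizes(ŷ, y)
    agg(abs2.(ŷ .- y))   # abs2 returns a Real for both real and complex residuals
end
```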
Yeah I think that would be good.
Build succeeded.
While the test passes, something is wrong with the ADADelta case here. I think it's passing because the last digits decline a bit, but overall the loss is very close to constant. Graph here: FluxML/Optimisers.jl#47, but it also applies to Flux:

[plot: near-constant training loss for the ADADelta case]
end

params = Flux.Params([w])
opt = opt_ctor(1e-2)
This uses the same parameter for all optimisers, but ADADelta's first parameter wants to be close to 1, not 0.
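One way to handle that (purely illustrative, not the PR's test code) is to give each optimiser constructor a sensible first argument rather than one shared value, for example:

```julia
# Hypothetical per-optimiser settings for the test; constructors are Flux's.
opts = [
    ADAM(1e-2),      # first argument is a learning rate
    RMSProp(1e-2),   # likewise
    ADADelta(0.9),   # first argument is the decay ρ, which should be close to 1
]

for opt in opts
    # ... run the complex-valued training loop with `opt` ...
end
```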
We can see the effect this fix has on ADAM with the following test:
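A minimal sketch of such a test (a hypothetical reconstruction, not the exact snippet from the PR), fitting a complex weight matrix with `ADAM` via Flux's implicit-parameter API:

```julia
using Flux, Random

Random.seed!(0)

# Learn a complex-valued linear map; both weights and data are complex.
w_true = randn(ComplexF64, 4, 4)
x      = randn(ComplexF64, 4, 100)
y      = w_true * x

w      = randn(ComplexF64, 4, 4)             # complex weights to be learned
loss() = Flux.Losses.mse(w * x, y)

opt = ADAM(1e-2)
ps  = Flux.Params([w])
for step in 1:1000
    gs = Flux.gradient(loss, ps)
    Flux.Optimise.update!(opt, ps, gs)
    step % 100 == 0 && @show loss()
end
```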
The training loss before the fix looks like this:

[plot: training loss before the fix]

Whereas after both of these commits, it looks like this:

[plot: training loss after the fix]

Note that while the absolute value of the loss is actually comparable in this simple example, the loss landscape is significantly more chaotic. With a higher learning rate, the "fixed" version is able to learn much faster:

[plot: training loss at the higher learning rate, with the fix]

Whereas the unfixed version simply diverges:

[plot: training loss at the higher learning rate, without the fix]