Use conjugates in optimizers to better learn on complex-valued inputs #1776

Merged: staticfloat merged 3 commits into master from sf/momentum_is_complex on Nov 30, 2021

Conversation

@staticfloat (Contributor) commented Nov 23, 2021:

When weights are complex, the deltas to them will also be complex. In all optimizers that need a second-order estimate of gradient statistics, we generally want to use the `x * conj(x)` pattern, rather than `x^2`.
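For a single complex delta the difference is easy to see (a quick illustration, not part of the PR): squaring rotates the value in the complex plane and its real part can even go negative, while multiplying by the conjugate always yields the squared magnitude, which is what a variance-like accumulator needs.

Δ = 3.0 + 4.0im
Δ^2           # -7.0 + 24.0im : rotated, real part is negative
Δ * conj(Δ)   # 25.0 + 0.0im  : |Δ|^2, real and non-negative (zero imaginary part)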

We can see the effect this has on ADAM with the following test:

using Flux
using CairoMakie   # plotting uses Makie; the backend is not specified in the PR, CairoMakie assumed

begin
    # This model will learn `W = I` and `bias = 0`
    complex_init(dims...) = Flux.glorot_uniform(dims...) .+ 1im .* Flux.glorot_uniform(dims...)
    model = Chain(
        Dense(4, 4, tanh; init=complex_init),
        Dense(4, 16, tanh; init=complex_init),
        Dense(16, 4, tanh; init=complex_init),
        Dense(4, 4, tanh; init=complex_init),
    )

    # Loss function; note we don't need the `abs()` if we update `Flux.Losses.mse()` as below
    function loss(x)
        return abs(Flux.Losses.mse(model(x), x))
    end

    # Keep track of loss from epoch to epoch
    losses = Float64[]
    dataset = [(randn(ComplexF32, 4, 10),)]
    params = Flux.params(model)
    opt = Flux.Optimise.ADAM(0.001)
    for epoch_idx in 1:10000
        Flux.train!(loss, params, dataset, opt)
        epoch_loss = loss(dataset[1][1])
        push!(losses, epoch_loss)
        if epoch_idx % 100 == 0
            @info("epoch done", epoch_idx, epoch_loss)
        end
    end

    # Plot the training loss on a log scale
    fig = Figure()
    meta_ax = Axis(fig[1,1])
    lines!(meta_ax, log.(losses); label="Training loss")
    fig[1,2] = Legend(fig, meta_ax, "Learning Stats")
    fig
end

The training loss before the fix looks like this:

[figure: without_workaround, training loss before the fix]

Whereas after both of these commits, it looks like this:

[figure: with_workaround, training loss after the fix]

Note that while the absolute value of the loss is comparable in this simple example, the loss landscape without the fix is significantly more chaotic. With a higher learning rate, the fixed version is able to learn much faster:

[figure: download-1, training loss with the fix at a higher learning rate]

Whereas the unfixed version simply diverges:

[figure: download-2, training loss without the fix at a higher learning rate, diverging]

Commit messages:
- "When weights are complex, the deltas to them will also be complex. In all optimizers that need a second-order estimate of gradient statistics, we generally want to use the `x * conj(x)` pattern, rather than `x^2`."
- "This improves the calculation of error for complex-valued targets."
@staticfloat (Contributor, Author):

X-ref: pytorch/pytorch#59998

@ToucheSir (Member):

There was a discussion about exactly this some weeks ago, but for the life of me I can't remember where or on what platform. @darsnack do you recall this?

@darsnack (Member) previously approved these changes Nov 23, 2021:

Looks good to me, thanks for this.

@ToucheSir Perhaps this Zulip stream. But I recall a longer discussion linking to the PyTorch issue above. I feel like it was #autodiff or #math-optimization on Slack, but those messages are long gone.

@darsnack (Member):

Is there a good test we could add here?

@DhairyaLGandhi (Member):

This is great, thanks @staticfloat! A good test might be to train with complex-valued weights and inputs and check for smoothness (see the sketch below). We might need to do a similar pass over the other optimisers as well.
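A rough sketch of what such a smoke test could look like (hypothetical: the single-weight problem, step count, and tolerance below are illustrative, not the test that was eventually added to the PR):

using Flux, Test

@testset "ADAM with complex-valued weights" begin
    # Learn a single complex weight w ≈ 2 + 1im from random complex inputs
    w_true = ComplexF32(2 + 1im)
    w = [zero(ComplexF32)]
    ps = Flux.params(w)
    opt = Flux.Optimise.ADAM(0.01)

    losses = Float32[]
    for _ in 1:1000
        x = randn(ComplexF32)
        gs = Flux.gradient(() -> abs2(w[1] * x - w_true * x), ps)
        Flux.Optimise.update!(opt, ps, gs)
        push!(losses, abs2(w[1] - w_true))
    end

    # Training should make steady progress and end up near the target
    @test losses[end] < losses[1]
    @test losses[end] < 1f-2
end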

@@ -141,7 +141,7 @@ RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
 function apply!(o::RMSProp, x, Δ)
   η, ρ = o.eta, o.rho
   acc = get!(() -> zero(x), o.acc, x)::typeof(x)
-  @. acc = ρ * acc + (1 - ρ) * Δ^2
+  @. acc = ρ * acc + (1 - ρ) * Δ * conj(Δ)
Contributor:

Isn't Δ * conj(Δ) the same as abs2(Δ)?

Contributor:

No, they are not the same: `abs2` produces a real value, while `Δ * conj(Δ)` produces a complex value with zero imaginary part.
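A quick illustration of the type difference (not from the PR discussion itself):

Δ = 3.0 + 4.0im
abs2(Δ)        # 25.0          :: Float64
Δ * conj(Δ)    # 25.0 + 0.0im  :: Complex{Float64}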

@staticfloat (Author):

Yeah, I didn't look deeply enough into these to figure out whether returning a real value would cause type instability or the like, so I just left it in this form.

Member:

Should be fine to use `abs2`, I think. These are all writing into existing arrays like `acc`, so that fixes the overall type.

Surely not a huge effect though. I don't think these broadcasts will ever be seen by Zygote, for instance.
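Concretely, because the broadcast assigns into a pre-allocated accumulator, the element type of `acc` is preserved even though `abs2` returns a real scalar. A small check under that assumption (illustrative, not code from the PR):

Δ   = ComplexF32[3 + 4im, 1 - 2im]
acc = zeros(ComplexF32, 2)               # pre-allocated accumulator, analogous to `o.acc`
ρ   = 0.9f0
@. acc = ρ * acc + (1 - ρ) * abs2(Δ)     # abs2 yields Real values, but acc stays ComplexF32
eltype(acc)                              # ComplexF32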

Contributor:

Isn't `x * conj(x)` slower (albeit slightly) than `abs2(x)`?

@staticfloat (Author):

Added some tests, and confirmed that this makes a big difference when training actual models.

@DhairyaLGandhi (Member) left a review:

LGTM. We need to be careful to preserve the behaviour for complex models going forward, and to review the other losses for the same treatment as `mse`. I'm not a huge fan of the increased complexity in `mse`, but we definitely need to be correct first.

bors r+

@@ -44,7 +44,8 @@ julia> Flux.mse(y_model, y_true)
 """
 function mse(ŷ, y; agg = mean)
   _check_sizes(ŷ, y)
-  agg((ŷ .- y) .^ 2)
+  error = ŷ .- y
+  real(agg(error .* conj(error)))
@staticfloat (Author):

@DhairyaLGandhi We can simplify this by using `agg(abs2.(ŷ .- y))` instead, since `abs2()` (as noted elsewhere) always returns a `Real`. Would you prefer that?
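A sketch of the simplified version being proposed, assuming nothing else in the function changes (not necessarily the exact code that was merged):

function mse(ŷ, y; agg = mean)
    _check_sizes(ŷ, y)
    # abs2 returns a Real for both real and complex inputs,
    # so no explicit real(...) or conj(...) is needed
    agg(abs2.(ŷ .- y))
end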

Member:

Yeah I think that would be good.

bors bot commented Nov 30, 2021:

Build succeeded:

bors bot merged commit cbc1275 into master on Nov 30, 2021.
@mcabbott deleted the sf/momentum_is_complex branch on January 30, 2022.
@mcabbott (Member):

While the test passes, something is wrong with the ADADelta case here. I think it's passing because the last digits decline a bit, but overall the loss is very close to constant.

Graph here: FluxML/Optimisers.jl#47 but also applies to Flux:

julia> opt_ctor = ADADelta;

julia> LOG = [];

julia> for idx in 1:10
                grads = Flux.gradient(loss, params)
                push!(LOG, loss()) 
                last_loss = loss()
                Flux.update!(opt, params, grads)
            end

julia> LOG
10-element Vector{Any}:
 2.0000000993025404
 2.0000000537095355
 2.000000116407697
 2.000000066966267
 2.000000134152019
 2.000000081155827
 2.000000152564042
 2.000000096253135
 2.000000171664382
 2.0000001122397846

end

params = Flux.Params([w])
opt = opt_ctor(1e-2)
Member, commenting on the quoted test lines above:

This uses the same parameter for all. But ADADelta's first parameter wants to be close to 1, not 0.
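That is, ADADelta's first positional argument is the decay ρ (default 0.9), not a learning rate, so passing 1e-2 puts it far outside its intended regime. An illustration, assuming the Flux.Optimise constructors of that era:

opt_misused  = Flux.Optimise.ADADelta(1e-2)  # ρ = 0.01, not what the test intends
opt_intended = Flux.Optimise.ADADelta(0.9)   # ρ close to 1, the usual setting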
