Complex numbers alla Flux 1776 #47
Merged
Changes from 9 commits

Commits (10):
75f9b72 change broadcasting macro & remove bugs (mcabbott)
0429c4a fix OADAM (mcabbott)
3a8e291 log the loss during tests (mcabbott)
edcecc5 complex numbers alla Flux 1776 (mcabbott)
c1d5cbd fixed? (mcabbott)
688b74c fix Momentum, Nesterov (mcabbott)
886fe12 found the bug, fixed (mcabbott)
5e5ce16 rm plotting code (mcabbott)
9f97c5d uncomment one test (mcabbott)
a7c199e comments (mcabbott)
@@ -19,7 +19,16 @@ RULES = [
 name(o) = typeof(o).name.name
 name(o::OptimiserChain) = join(name.(o.opts), " → ")

+LOG = Dict()
+
+loggradient(o) = (f, xs...) -> begin
+  y, dxs = Zygote.withgradient(f, xs...)
+  push!(get!(() -> Float32[], LOG, name(o)), y)
+  dxs  # save the loss, return the gradient
+end
+
 @testset "independence" begin
+  empty!(LOG)
   @testset "$(name(o))" for o in RULES
     w = randn(10, 10)
     w′ = randn(10, 10)
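A quick, hypothetical illustration (not part of the diff) of how the new helper is meant to be called: loggradient(o) returns a closure with the same call shape and return value as Zygote.gradient, except that it also appends each loss to LOG under the rule's name.

```julia
o  = ADAM(1e-2)
w  = randn(10, 10)
dw = loggradient(o)(x -> sum(abs2, x), w)  # same gradient tuple that gradient(...) would return
LOG[name(o)]                               # Vector{Float32} of losses recorded under name(o) == :ADAM
```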
@@ -28,22 +37,23 @@ name(o::OptimiserChain) = join(name.(o.opts), " → ")
     st = Optimisers.setup(o, w)
     for t = 1:10^5
       x = rand(10)
-      gs = gradient(w -> iloss(x, w, w′), w)
+      gs = loggradient(o)(w -> iloss(x, w, w′), w)
       st, w = Optimisers.update!(st, w, gs...)
     end
     @test iloss(rand(10, 10), w, w′) < 0.01
   end
 end

 @testset verbose=true "simple sum" begin
+  empty!(LOG)
   @testset "$(name(o))" for o in RULES
     m = shuffle!(reshape(1:64, 8, 8) .+ 0.0)
     s = Optimisers.setup(o, m)
     for _ in 1:10^5
-      g = gradient(x -> sum(abs2, x + x'), m)[1]
+      g = loggradient(o)(x -> sum(abs2, x + x'), m)[1]
       s, m = Optimisers.update!(s, m, g)
     end
-    # @test sum(m) < sum(1:64)
+    @test sum(m) < sum(1:64)
     if sum(m) < 1
       @test sum(m) < 1
     else
@@ -54,21 +64,23 @@ end
 end

 @testset "original" begin
+  empty!(LOG)
   @testset "$(name(o))" for o in RULES
     w′ = (α = rand(3, 3), β = rand(3, 3))
     w = (α = 5rand(3, 3), β = rand(3, 3))
     st = Optimisers.setup(o, w)
     loss(x, y) = mean((x.α .* x.β .- y.α .* y.β) .^ 2)
     @test loss(w, w′) > 1
     for i = 1:10^4
-      gs = gradient(x -> loss(x, w′), w)
+      gs = loggradient(o)(x -> loss(x, w′), w)
       st, w = Optimisers.update(st, w, gs...)
     end
     @test loss(w, w′) < 0.001
   end
 end

 @testset verbose=true "StaticArrays" begin
+  empty!(LOG)
   @testset "$(name(o))" for o in RULES
     W1 = @SMatrix randn(10, 10)
     b1 = @SVector randn(10)
@@ -82,7 +94,7 @@ end
     @test s_loss(model, x, y) > 10
     state = Optimisers.setup(o, model)
     for t = 1:10^3
-      g = gradient(m -> s_loss(m, x, y), model)[1]
+      g = loggradient(o)(m -> s_loss(m, x, y), model)[1]
       state, model = Optimisers.update!(state, model, g)
     end
     if o isa Descent
@@ -94,7 +106,7 @@ end
   end
 end

-@testset verbose=true "element types" begin
+@testset "element types" begin
   @testset "$(name(o))" for o in RULES
     marray = (Float32[1,2], Float64[3,4], Float16[5,6])
     types = map(eltype, marray)
@@ -166,3 +178,55 @@ end
   end
 end

+@testset "with complex numebers: Flux#1776" begin
+  empty!(LOG)
+  @testset "$(name(opt))" for opt in [
+    # The Flux PR had 1e-2 for all. But ADADelta(ρ) needs ρ≈0.9 not small. And it helps to make ε not too small too:
+    ADAM(1e-2), RMSProp(1e-2), RADAM(1e-2), OADAM(1e-2), ADAGrad(1e-2), ADADelta(0.9, 1e-5), NADAM(1e-2), AdaBelief(1e-2),

Review comment on lines +184 to +185: This is the fix, BTW.

+    # These weren't in Flux PR:
+    Descent(1e-2), Momentum(1e-2), Nesterov(1e-2), ADAMW(1e-2),
+  ]
+    # Our "model" is just a complex number
+    model = (w = zeros(ComplexF64, 1),)
+
+    # Our model attempts to learn `f(x) = conj(x)` where `f(x) = w*x`
+    function loss(m)
+      # Deterministic training data is the best training data
+      x = ones(1, 1) + 1im*ones(1, 1)
+      # Manually implement `mse()` to allow demonstration of brokenness
+      # on older Flux builds that don't have a fixed `mse()`
+      return sum(abs2.(m.w * x .- conj(x)))
+    end
+    @test loss(model) ≈ 2.0
+
+    state = Optimisers.setup(opt, model)
+
+    # Train for 10 iterations, enforcing that loss is monotonically decreasing
+    last_loss = Inf
+    for idx in 1:10
+      grads = loggradient(opt)(loss, model)
+      state, model = Optimisers.update!(state, model, grads...)
+      opt isa Union{Momentum, Nesterov} && idx > 8 && continue  # these are very flat at the end
+      @test loss(model) < last_loss
+      last_loss = loss(model)
+    end
+    @test loss(model) < 1.9
+
+    # Repeat with StaticArrays
+    static_model = (w = SA[0.0 + 0im],)
+    static_state = Optimisers.setup(opt, static_model)
+    function static_loss(m)
+      x = hcat(SA[1.0 + im])
+      sum(abs2.(m.w * x .- conj(x)))
+    end
+    last_loss = Inf
+    for idx in 1:10
+      grads = gradient(static_loss, static_model)
+      static_state, static_model = Optimisers.update!(static_state, static_model, grads...)
+      opt isa Union{Momentum, Nesterov} && idx > 8 && continue
+      @test static_loss(static_model) < last_loss
+      last_loss = static_loss(static_model)
+    end
+    @test static_loss(static_model) < 1.9
+  end
+end
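For context, a minimal standalone sketch of what the new testset exercises (the helper name train and its keyword are mine, not from the PR; it assumes Optimisers and Zygote as above). With w = 0 the first loss is |0 − conj(1+i)|² = 2, which is what `@test loss(model) ≈ 2.0` checks, and a few optimiser steps should pull it below that.

```julia
using Optimisers, Zygote

# The toy "model" is a single complex weight trying to learn f(x) = conj(x) via f(x) = w*x.
loss(m) = sum(abs2.(m.w * (ones(1, 1) .+ im) .- (ones(1, 1) .- im)))

function train(rule; steps = 10)
    model = (w = zeros(ComplexF64, 1),)
    state = Optimisers.setup(rule, model)
    for _ in 1:steps
        grads = gradient(loss, model)
        state, model = Optimisers.update!(state, model, grads...)
    end
    return loss(model)  # starts at |0 - (1 - im)|^2 == 2.0 and should decrease
end

train(ADAM(1e-2))  # below the initial 2.0; the testset asserts < 1.9 after ten steps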
Review comment: This (LOG) doesn't seem to be in use anywhere, is it still necessary?

Reply: It just makes debugging easier if you can plot things from the tests you just ran. It's not strictly necessary but also doesn't really get in the way, I think.

Reply: Fair enough. A comment with what you just described would help then.
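A hypothetical sketch of the debugging workflow described above (UnicodePlots is only an example; no plotting package is a test dependency): after the testsets have run, LOG maps each rule's name to the Float32 losses recorded by loggradient, so they can be plotted directly.

```julia
using UnicodePlots  # any plotting package would do

for (rule, losses) in LOG
    display(lineplot(losses; title = string(rule)))
end
```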