Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow ForwardDiff in BatchNorm's track_stats #2127

Merged
merged 4 commits into from
Dec 8, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Flux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import Optimisers: Optimisers, trainable, destructure # before v0.13, Flux owne

using Zygote, ChainRulesCore
using Zygote: Params, @adjoint, gradient, pullback, @nograd
using Zygote.ForwardDiff: value
export gradient

# Pirate error to catch a common mistake. (Internal function `base` because overloading `update!` is more likely to give ambiguities.)
Expand Down
4 changes: 2 additions & 2 deletions src/layers/normalise.jl
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,8 @@ function _track_stats!(
μnew = vec(N ∈ reduce_dims ? μ : mean(μ, dims=N))
σ²new = vec(N ∈ reduce_dims ? σ² : mean(σ², dims=N))

bn.μ = res_mtm .* bn.μ .+ mtm .* μnew
bn.σ² = res_mtm .* bn.σ² .+ mtm .* (m / (m - one(V))) .* σ²new
bn.μ .= value.(res_mtm .* bn.μ .+ mtm .* μnew)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also updates bn.μ in-place, rather than replacing the field. We know it's mutable as the constructor made it so.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we leave a comment explaining why value is there?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, done.

bn.σ² .= value.(res_mtm .* bn.σ² .+ mtm .* (m / (m - one(V))) .* σ²new)
return nothing
end

Expand Down
12 changes: 11 additions & 1 deletion test/layers/normalisation.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Flux, Test, Statistics
using Zygote: pullback
using Zygote: pullback, ForwardDiff

evalwgrad(f, x...) = pullback(f, x...)[1]

Expand Down Expand Up @@ -463,3 +463,13 @@ end
m1 = Dropout(0.5)
@test Zygote.hessian_reverse(sum∘m1, [1.0,2.0,3.0]) == zeros(3, 3)
end

@testset "ForwardDiff" begin
bn = BatchNorm(3)
@test ForwardDiff.jacobian(bn, rand(Float32, 3, 4)) isa Matrix{Float32}
iszero(bn.μ) # true, but ideally it would automatically choose trainmode
mcabbott marked this conversation as resolved.
Show resolved Hide resolved
Flux.trainmode!(bn)
@test ForwardDiff.jacobian(bn, rand(Float32, 3, 4)) isa Matrix{Float32}
@test !iszero(bn.μ)
end