Fix Norm Layers, Again #1509

Closed
wants to merge 49 commits
Changes from 14 commits

Commits (49)
2cd59b4  clean up the history  (Feb 16, 2021)
0ef033e  bn on cudnn  (Feb 16, 2021)
1e67700  rm extra utils  (Feb 16, 2021)
0dbeb39  use simpler test suite first  (Feb 16, 2021)
04ca0a8  git fixes  (Feb 16, 2021)
31f7000  refactor norm_forward  (Feb 18, 2021)
262112a  simplify tests  (Feb 20, 2021)
c61457c  typose  (Feb 20, 2021)
0432b82  check reduce for batch  (Feb 20, 2021)
77a8e87  use normconfig  (Feb 21, 2021)
d0961a7  use normconfig for other layers  (Feb 21, 2021)
b091a80  typo  (Feb 21, 2021)
680e64f  backwards  (Feb 21, 2021)
4a26559  first pass  (Mar 1, 2021)
1c38029  Update src/layers/normalise.jl  (DhairyaLGandhi, Mar 7, 2021)
137acf9  don't use single instance in bn  (Mar 8, 2021)
db559c5  use prev stats to track  (Mar 8, 2021)
f5c3641  Merge branch 'dg/instance2' of https://github.com/FluxML/Flux.jl into…  (Mar 8, 2021)
fcea841  track stats tests  (Mar 8, 2021)
647dcd9  fix BN(3)  (Mar 9, 2021)
561c94f  check instance norm tests  (Mar 10, 2021)
332b13b  clean up instance norm  (Mar 12, 2021)
30d6542  dont reshape eagerly  (Mar 12, 2021)
1d99fd6  use mean instead of channel  (Mar 15, 2021)
e3ae11d  unbreak a couple tests  (Mar 15, 2021)
c66202e  use non corrected variance  (Mar 16, 2021)
d6fac56  typo  (Mar 16, 2021)
16d0b96  use train time eval  (Mar 16, 2021)
e7fe00b  check for dims in getaffine  (Mar 16, 2021)
9c01dd2  use correct group dims  (Mar 22, 2021)
9abfe0c  typo  (Mar 22, 2021)
a621ef6  use trainmode groupnorm test  (Mar 22, 2021)
e174605  cleanup  (Mar 24, 2021)
99901a7  use bias and gamma for trainable  (Mar 24, 2021)
9f481e4  trainable  (Mar 24, 2021)
e9d89ab  test fixes  (Mar 26, 2021)
8f3844c  new constructor  (DhairyaLGandhi, Apr 19, 2021)
bf34b73  test conflicts  (DhairyaLGandhi, Apr 19, 2021)
14a6372  conflicts  (DhairyaLGandhi, Apr 19, 2021)
d82c3d3  conflicts  (DhairyaLGandhi, Apr 19, 2021)
8aadf1e  rebase  (DhairyaLGandhi, Jun 23, 2021)
28521c1  rebase  (DhairyaLGandhi, Jun 23, 2021)
19b91b2  size fix  (DhairyaLGandhi, Jun 24, 2021)
0d4605d  space cleanups + show  (DhairyaLGandhi, Jun 24, 2021)
36084e5  add layer norm show methods  (DhairyaLGandhi, Jun 24, 2021)
3c6f1ce  whitespace  (DhairyaLGandhi, Jun 24, 2021)
8f6de19  change some tests  (DhairyaLGandhi, Jun 25, 2021)
c525d4f  use affine as function  (DhairyaLGandhi, Jun 29, 2021)
aa39039  rebase  (DhairyaLGandhi, Aug 6, 2021)
13 changes: 6 additions & 7 deletions src/cuda/cudnn.jl
@@ -1,14 +1,13 @@
import CUDA.CUDNN: batchnorm, ∇batchnorm

function (BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}},
cache=nothing) where T<:Union{Float32, Float64}
function (BN::Flux.BatchNorm)(x::CuArray{T},
cache = nothing) where T<:Union{Float32, Float64}

@assert BN.affine "BatchNorm: only affine=true supported on gpu"
@assert BN.track_stats "BatchNorm: only track_stats=true supported on gpu"
@assert length(BN.β) == size(x, ndims(x)-1) "BatchNorm: input has wronng number of channels"
@assert BN.affine throw(ArgumentError("BatchNorm: only affine = true supported on gpu"))
@assert BN.track_stats throw(ArgumentError("BatchNorm: only track_stats = true supported on gpu"))
return BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum;
cache=cache, alpha=1, beta=0, eps=BN.ϵ,
training=Flux._isactive(BN)))
cache = cache, alpha = 1, beta = 0, eps = BN.ϵ,
training = Flux._isactive(BN)))
end

@adjoint function batchnorm(g, b, x, running_mean, running_var, momentum; kw...)
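For illustration, a minimal sketch of how this GPU path gets exercised (not part of the diff; it assumes a working CUDA.jl setup, and the array sizes are arbitrary):

using Flux, CUDA

bn = BatchNorm(3) |> gpu              # affine = true and track_stats = true by default
x  = CUDA.rand(Float32, 8, 8, 3, 4)   # WHCN input with 3 channels
y  = bn(x)                            # hits the CuArray method above, i.e. CUDNN batchnorm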
231 changes: 133 additions & 98 deletions src/layers/normalise.jl
@@ -51,6 +51,7 @@ end

Dropout layer. In the forward pass, apply the [`Flux.dropout`](@ref) function on the input.

Does nothing to the input once [`Flux.testmode!`](@ref) is set to `true`.
To apply dropout along certain dimension(s), specify the `dims` keyword.
e.g. `Dropout(p; dims = 3)` will randomly zero out entire channels on WHCN input
(also called 2D dropout).
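As a usage sketch of the `dims` keyword documented above (illustrative only; the rate and sizes are arbitrary):

using Flux

d = Dropout(0.5; dims = 3)
trainmode!(d)                      # force train-mode behaviour outside of a gradient call
x = rand(Float32, 28, 28, 16, 8)   # WHCN input
y = d(x)                           # entire channels are zeroed, the rest rescaled by 1/(1 - p)
testmode!(d)
d(x) == x                          # at test time the layer is a no-op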
@@ -118,7 +119,7 @@ testmode!(m::AlphaDropout, mode=true) =
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)

"""
LayerNorm(sz, λ=identity; affine=true, ϵ=1fe-5)
LayerNorm(sz, λ = identity; affine = true, ϵ = 1f-5)

A [normalisation layer](https://arxiv.org/abs/1607.06450) designed to be
used with recurrent hidden states.
@@ -129,77 +130,109 @@ The input is normalised along the first `length(sz)` dimensions
for tuple `sz`, along the first dimension for integer `sz`.
The input is expected to have first dimensions' size equal to `sz`.

If `affine=true` also applies a learnable shift and rescaling
If `affine = true` also applies a learnable shift and rescaling
as in the [`Diagonal`](@ref) layer.


See also [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`normalise`](@ref).
"""
struct LayerNorm{F,D,T,N}
struct LayerNorm{F,D,T,S}
λ::F
diag::D
ϵ::T
size::NTuple{N,Int}
affine::Bool
sz::S
end

function LayerNorm(sz, λ=identity; affine=true, ϵ=1f-5)
sz = sz isa Integer ? (sz,) : sz
diag = affine ? Diagonal(sz...) : nothing
return LayerNorm(λ, diag, ϵ, sz, affine)
function LayerNorm(sz, λ = identity; affine = true, ϵ = 1f-5)
diag = affine ? Diagonal(sz...) : identity
return LayerNorm(λ, diag, ϵ, sz)
end

@functor LayerNorm

function (a::LayerNorm)(x)
x = normalise(x, dims=1:length(a.size), ϵ=a.ϵ)
a.diag === nothing ? a.λ.(x) : a.λ.(a.diag(x))
x = normalise(x, dims = 1:length(a.sz), ϵ = a.ϵ)
a.λ.(a.diag(x))
end
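# Usage illustration (not part of this diff; sizes are arbitrary):
#   ln = LayerNorm(5)          # affine = true by default, so ln.diag is a Diagonal layer over 5 features
#   x  = rand(Float32, 5, 16)  # 5 features, batch of 16
#   y  = ln(x)                 # normalised along dims = 1, then passed through ln.diag and the activation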

function Base.show(io::IO, l::LayerNorm)
print(io, "LayerNorm($(l.size)")
a.λ == identity || print(io, ", $(a.λ)")
hasaffine(l) || print(io, ", affine=false")
print(io, ", $(l.λ)")
print(io, ", affine = $(l.diag)")
print(io, ")")
end

struct NormConfig{A,T}
dims::Vector{Int}
end

NormConfig(affine, track_stats, reduce_dims) = NormConfig{affine, track_stats}(reduce_dims)

function getaffine(nc::NormConfig{true}, sz_x)
n = length(sz_x)
ntuple(i -> i == n-1 ? sz_x[n-1] : 1, n)
end

getaffine(nc::NormConfig{false}, args...) = ()
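# Worked example (illustrative, not part of the diff): for a 4-d WHCN input of size (28, 28, 3, 16),
#   nc = NormConfig(true, true, [1, 2, 4])         # BatchNorm-style reduce_dims
#   getaffine(nc, (28, 28, 3, 16)) == (1, 1, 3, 1)  # broadcastable over everything but the channel dim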

# For InstanceNorm, GroupNorm, and BatchNorm.
# Compute the statistics on the slices specified by reduce_dims.
# reduce_dims=[1,...,N-2,N] for BatchNorm
# reduce_dims=[1,...,N-2] for InstanceNorm and GroupNorm
function _norm_layer_forward(l, x::AbstractArray{T,N}; reduce_dims, affine_shape) where {T, N}
if !_isactive(l) && l.track_stats # testmode with tracked stats
function norm_forward(l, x::AbstractArray{T,N}, nc::NormConfig{A, true}) where {A, T, N}
if !_isactive(l) # testmode with tracked stats
stats_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
μ = reshape(l.μ, stats_shape)
σ² = reshape(l.σ², stats_shape)
else # trainmode or testmode without tracked stats
μ = mean(x; dims=reduce_dims)
σ² = mean((x .- μ).^2; dims=reduce_dims)
if l.track_stats
## update moving mean/std
Zygote.ignore() do
mtm = l.momentum
m = prod(size(x, i) for i in reduce_dims) # needed for computing corrected var
μnew = vec(N ∈ reduce_dims ? μ : mean(μ, dims=N))
σ²new = vec(N ∈ reduce_dims ? σ² : mean(σ², dims=N))
l.μ = (1-mtm) .* l.μ .+ mtm .* μnew
l.σ² = (1-mtm) .* l.σ² .+ mtm .* (m / (m - one(eltype(l.σ²)))) .* σ²new
end
μ = mean(x; dims = getaffine(nc, size(x)))
σ² = sum((x .- μ) .^ 2; dims = getaffine(nc, size(x))) ./ l.chs
μ, σ² = track_stats(x, l.μ, l.σ², l.momentum, reduce_dims = nc.dims)
Zygote.ignore() do
l.μ = reshape(μ, :)
l.σ² = reshape(σ², :)
end
end
if hasaffine(l)
γ = reshape(l.γ, affine_shape)
β = reshape(l.β, affine_shape)
return l.λ.(γ .* (x .- μ) ./ sqrt.(σ² .+ l.ϵ) .+ β)
else
return l.λ.((x .- μ) ./ sqrt.(σ² .+ l.ϵ))
end
μ, σ²
end

function norm_forward(l, x::AbstractArray{T,N}, nc::NormConfig{A, false}) where {A, T, N}
μ = mean(x; dims = nc.dims)
σ² = mean((x .- μ) .^ 2; dims = nc.dims)
μ, σ²
end

function track_stats(x::AbstractArray{T,N}, μ, σ², mtm; reduce_dims) where {T,N}
m = prod(size(x)[collect(reduce_dims)])
μnew = vec(N == last(reduce_dims) ? μ : mean(μ, dims = N))
σ²new = vec(N == last(reduce_dims) ? σ² : mean(σ², dims = N))
μ = (1 - mtm) .* μ .+ mtm .* μnew
σ² = (1 - mtm) .* σ² .+ mtm .* (m / (m - one(T))) .* σ²new
μ, σ²
end
@nograd track_stats
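# The update implemented above is the usual exponential moving average (illustration, not part of the diff):
#   μ  ← (1 - momentum) * μ  + momentum * mean(batch)
#   σ² ← (1 - momentum) * σ² + momentum * m / (m - 1) * var(batch)   # Bessel-corrected, m = number of reduced elements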

function affine(l, x, μ, σ², nc::NormConfig{true})
affine_shape = getaffine(nc, size(x))
γ = reshape(l.γ, affine_shape)
β = reshape(l.β, affine_shape)
μ = reshape(μ, affine_shape)
σ² = reshape(σ², affine_shape)
x̂ = (x .- μ) ./ sqrt.(σ² .+ l.ϵ)
l.λ.(γ .* x̂ .+ β)
end

affine(l, x, μ, σ², nc::NormConfig{false}) = l.λ.((x .- μ) ./ sqrt.(σ² .+ l.ϵ))

# function affine(l, x, μ, σ², affine_shape)
# res = (x .- μ) ./ sqrt.(σ² .+ l.ϵ)
# _affine(l.λ, res, affine_shape)
# end

"""
BatchNorm(channels::Integer, λ=identity;
initβ=zeros, initγ=ones,
ϵ=1f-5, momentum= 0.1f0)
BatchNorm(channels::Integer, λ = identity;
initβ = zeros, initγ = ones,
ϵ = 1f-5, momentum = 0.1f0)

[Batch Normalization](https://arxiv.org/abs/1502.03167) layer.
`channels` should be the size of the channel dimension in your data (see below).
@@ -211,12 +244,12 @@ it's the usual channel dimension.
`BatchNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×D_N`
input slice and normalises the input accordingly.

If `affine=true`, it also applies a shift and a rescale to the input
If `affine = true`, it also applies a shift and a rescale to the input
through to learnable per-channel bias β and scale γ parameters.

After normalisation, elementwise activation `λ` is applied.

If `track_stats=true`, accumulates mean and var statistics in training phase
If `track_stats = true`, accumulates mean and var statistics in training phase
that will be used to renormalize the input in test phase.

Use [`testmode!`](@ref) during inference.
@@ -245,40 +278,42 @@ mutable struct BatchNorm{F,V,N,W}
active::Union{Bool, Nothing}
end

function BatchNorm(chs::Int, λ=identity;
initβ = i -> zeros(Float32, i),
initγ = i -> ones(Float32, i),
affine=true, track_stats=true,
ϵ=1f-5, momentum=0.1f0)
function BatchNorm(chs::Int, λ = identity;
initβ = i -> zeros(Float32, i),
initγ = i -> ones(Float32, i),
affine = true, track_stats = true,
ϵ = 1f-5, momentum = 0.1f0)

β = affine ? initβ(chs) : nothing
γ = affine ? initγ(chs) : nothing
μ = track_stats ? zeros(Float32, chs) : nothing
σ² = track_stats ? ones(Float32, chs) : nothing
β = initβ(chs)
γ = initγ(chs)
Comment on lines +289 to +290

CarloLucibello (Member), Mar 5, 2021:
Why should we have params that we don't use in practice when affine = false?

(Member):
Other thought on this: I know we rejected Affine(BatchNorm(...)) because it is not ergonomically nice. But what about having a field in BatchNorm called affine that is a function? When the affine kwarg is true, then affine = Dense(...), but when it is false, affine = identity. In a Chain the user still sees just BatchNorm, and bn.affine will show a Dense which the user will intuitively understand. We also don't end up storing data we don't need, and trainable will automatically be empty.

(Member):
Kind of like you'd have for an activation function? That does sound pretty appealing; would it make sense to call that field transform or something then?

(Member):
Yeah exactly. transform would certainly capture the field accurately, but affine might be better here because people will understand it.

(Member, Author):
Well, that had been a consideration. It's a bit awkward to also do γ = reshape(l.γ, affine_shape), but I think we can have broadcasting take care of the shape as needed by l.λ.(γ .* x̂ .+ β).

μ = zeros(Float32, chs)
σ² = ones(Float32, chs)

return BatchNorm(chs, λ, β, γ,
μ, σ², ϵ, momentum,
μ, σ², ϵ, momentum,
affine, track_stats, nothing)
end
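A rough sketch of the "affine as a function field" idea from the review thread above (hypothetical: `NormAffineDemo` and its API are made up for illustration and are not part of this PR):

using Flux

struct NormAffineDemo{F, A}
  λ::F        # activation
  affine::A   # Flux.Diagonal(chs) when affine = true, identity when affine = false
end

NormAffineDemo(chs::Int, λ = identity; affine = true) =
  NormAffineDemo(λ, affine ? Flux.Diagonal(chs) : identity)

Flux.@functor NormAffineDemo

# With affine = false the layer stores no trainable arrays, so Flux.params collects nothing
# from it; with affine = true the Diagonal's scale and shift are picked up automatically.
(m::NormAffineDemo)(x) = m.λ.(m.affine(x))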

@functor BatchNorm
trainable(bn::BatchNorm) = hasaffine(bn) ? (bn.β, bn.γ) : ()
# trainable(bn::BatchNorm) = hasaffine(bn) ? (bn.β, bn.γ) : ()
(Member): if you remove this, μ and σ² will end up in the params.
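# Sketch of the point above (illustrative): keeping a method like
#     Flux.trainable(bn::BatchNorm) = bn.affine ? (bn.β, bn.γ) : ()
# restricts what Flux.params collects; with only @functor BatchNorm and no trainable method,
# every array field is exposed, so the running statistics μ and σ² would be collected too.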


function (BN::BatchNorm)(x)
@assert size(x, ndims(x)-1) == BN.chs
N = ndims(x)
N = ndims(x)::Int
@assert size(x, N - 1) == BN.chs
reduce_dims = [1:N-2; N]
affine_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
return _norm_layer_forward(BN, x; reduce_dims, affine_shape)
nc = NormConfig(BN.affine, BN.track_stats, reduce_dims)
@show nc
μ, σ² = norm_forward(BN, x, nc)
affine(BN, x, μ, σ², nc)
end

testmode!(m::BatchNorm, mode=true) =
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)

function Base.show(io::IO, l::BatchNorm)
print(io, "BatchNorm($(l.chs)")
l.λ == identity || print(io, ", $(l.λ)")
hasaffine(l) || print(io, ", affine=false")
print(io, ", $(l.λ)")
print(io, ", affine = ")
print(io, ")")
end

@@ -321,41 +356,41 @@ mutable struct InstanceNorm{F,V,N,W}
active::Union{Bool, Nothing}
end

function InstanceNorm(chs::Int, λ=identity;
function InstanceNorm(chs::Int, λ = identity;
initβ = i -> zeros(Float32, i),
initγ = i -> ones(Float32, i),
affine=false, track_stats=false,
ϵ=1f-5, momentum=0.1f0)

β = affine ? initβ(chs) : nothing
γ = affine ? initγ(chs) : nothing
μ = track_stats ? zeros(Float32, chs) : nothing
σ² = track_stats ? ones(Float32, chs) : nothing

return InstanceNorm(chs, λ, β, γ,
affine = true, track_stats = true,
ϵ = 1f-5, momentum = 0.1f0)

β = initβ(chs)
γ = initγ(chs)
μ = zeros(Float32, chs)
σ² = ones(Float32, chs)
InstanceNorm(chs, λ, β, γ,
μ, σ², ϵ, momentum,
affine, track_stats, nothing)
end

@functor InstanceNorm
trainable(in::InstanceNorm) = hasaffine(in) ? (in.β, in.γ) : ()
# trainable(in::InstanceNorm) = hasaffine(in) ? (in.β, in.γ) : ()

function (l::InstanceNorm)(x)
@assert ndims(x) > 2
@assert size(x, ndims(x)-1) == l.chs
N = ndims(x)
reduce_dims = 1:N-2
affine_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N)
return _norm_layer_forward(l, x; reduce_dims, affine_shape)
nc = NormConfig(l.affine, l.track_stats, reduce_dims)
μ, σ² = norm_forward(l, x, nc)
affine(l, x, μ, σ², nc)
end

testmode!(m::InstanceNorm, mode=true) =
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)

function Base.show(io::IO, l::InstanceNorm)
print(io, "InstanceNorm($(l.chs)")
l.λ == identity || print(io, ", $(l.λ)")
hasaffine(l) || print(io, ", affine=false")
print(io, ", $(l.λ)")
print(io, ", affine = ")
print(io, ")")
end

@@ -401,20 +436,20 @@ mutable struct GroupNorm{F,V,N,W}
end

@functor GroupNorm
trainable(gn::GroupNorm) = hasaffine(gn) ? (gn.β, gn.γ) : ()
# trainable(gn::GroupNorm) = hasaffine(gn) ? (gn.β, gn.γ) : ()

function GroupNorm(chs::Int, G::Int, λ=identity;
initβ = (i) -> zeros(Float32, i),
initγ = (i) -> ones(Float32, i),
affine=true, track_stats=false,
ϵ=1f-5, momentum=0.1f0)
function GroupNorm(chs::Int, G::Int, λ = identity;
initβ = i -> zeros(Float32, i),
initγ = i -> ones(Float32, i),
affine = true, track_stats = false,
ϵ = 1f-5, momentum = 0.1f0)

chs % G == 0 || error("The number of groups ($(G)) must divide the number of channels ($chs)")

β = affine ? initβ(chs) : nothing
γ = affine ? initγ(chs) : nothing
μ = track_stats ? zeros(Float32, G) : nothing
σ² = track_stats ? ones(Float32, G) : nothing
β = initβ(chs)
γ = initγ(chs)
μ = zeros(Float32, G)
σ² = ones(Float32, G)

return GroupNorm(chs, G, λ,
β, γ,
@@ -425,33 +460,33 @@ end

function (gn::GroupNorm)(x)
@assert ndims(x) > 2
@assert size(x, ndims(x)-1) == gn.chs
N = ndims(x)
@assert size(x, ndims(x) - 1) == gn.chs
sz = size(x)
x = reshape(x, sz[1:N-2]..., sz[N-1]÷gn.G, gn.G, sz[N])
N = ndims(x)
x = reshape(x, sz[1:N-2]..., sz[N-1] ÷ gn.G, gn.G, sz[N])
reduce_dims = 1:N-2
affine_shape = ntuple(i -> i ∈ (N-1, N-2) ? size(x, i) : 1, N)
x = _norm_layer_forward(gn, x; reduce_dims, affine_shape)
return reshape(x, sz)
nc = NormConfig(gn.affine, gn.track_stats, reduce_dims)
μ, σ² = norm_forward(gn, x, nc)
res = affine(gn, x, μ, σ², nc)
return reshape(res, sz)
end
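# Size illustration (not part of the diff): for x of size (W, H, 12, B) and gn.G == 4, the reshape
# above yields size (W, H, 3, 4, B); statistics are then computed per (group, batch) slice, reducing
# over the spatial and within-group channel dims, and the result is reshaped back to (W, H, 12, B).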

testmode!(m::GroupNorm, mode = true) =
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)

function Base.show(io::IO, l::GroupNorm)
print(io, "GroupNorm($(l.chs), $(l.G)")
l.λ == identity || print(io, ", $(l.λ)")
hasaffine(l) || print(io, ", affine=false")
print(io, ", $(l.λ)")
print(io, ", affine = ")
print(io, ")")
end

"""
hasaffine(l)

Return `true` if a normalisation layer has trainable shift and
scale parameters, `false` otherwise.

See [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`LayerNorm`](@ref).
"""
hasaffine(l::Union{BatchNorm, InstanceNorm, LayerNorm, GroupNorm}) = l.affine
# """
# hasaffine(l)
#
# Return `true` if a normalisation layer has trainable shift and
# scale parameters, `false` otherwise.
#
# See [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`LayerNorm`](@ref).
# """
# hasaffine(l::Union{BatchNorm, InstanceNorm, LayerNorm, GroupNorm}) = l.affine