
Commit

fix and refactor normalization layers
CarloLucibello committed Feb 4, 2021
1 parent eddc9cb commit ff18866
Showing 14 changed files with 584 additions and 370 deletions.
28 changes: 14 additions & 14 deletions Manifest.toml
@@ -14,9 +14,9 @@ version = "0.3.3"

[[Adapt]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "87491f7d03ae1b423a353aff99cf61a45e3c993a"
git-tree-sha1 = "ffcfa2d345aaee0ef3d8346a073d5dd03c983ebe"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "3.1.0"
version = "3.2.0"

[[Artifacts]]
deps = ["Pkg"]
@@ -93,9 +93,9 @@ uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "0.3.4+0"

[[DataAPI]]
git-tree-sha1 = "ad84f52c0b8f05aa20839484dbaf01690b41ff84"
git-tree-sha1 = "6d64b28d291cb94a0d84e6e41081fb081e7f717f"
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
version = "1.4.0"
version = "1.5.0"

[[DataStructures]]
deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
@@ -134,9 +134,9 @@ version = "0.1.3"

[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays"]
git-tree-sha1 = "8bd8e47ff5d34b20f0aa9641988eb660590008bc"
git-tree-sha1 = "50eabdace27aa27b143f65b65e762bb0112a7708"
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
version = "0.11.0"
version = "0.11.1"

[[FixedPointNumbers]]
deps = ["Statistics"]
@@ -146,9 +146,9 @@ version = "0.8.4"

[[ForwardDiff]]
deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "NaNMath", "Random", "SpecialFunctions", "StaticArrays"]
git-tree-sha1 = "c26b56e9b9f0687f7ca887f6b6ded03d269e0e35"
git-tree-sha1 = "d48a40c0f54f29a5c8748cfb3225719accc72b77"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
version = "0.10.15"
version = "0.10.16"

[[Functors]]
deps = ["MacroTools"]
@@ -227,9 +227,9 @@ version = "0.5.0"

[[Missings]]
deps = ["DataAPI"]
git-tree-sha1 = "ed61674a0864832495ffe0a7e889c0da76b0f4c8"
git-tree-sha1 = "f8c673ccc215eb50fcadb285f522420e29e69e1c"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
version = "0.4.4"
version = "0.4.5"

[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
@@ -252,9 +252,9 @@ uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
version = "0.5.3+4"

[[OrderedCollections]]
git-tree-sha1 = "cf59cfed2e2c12e8a2ff0a4f1e9b2cd8650da6db"
git-tree-sha1 = "d45739abcfc03b51f6a42712894a593f74c80a23"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.3.2"
version = "1.3.3"

[[Pkg]]
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
@@ -375,9 +375,9 @@ version = "1.2.11+18"

[[Zygote]]
deps = ["AbstractFFTs", "ChainRules", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
git-tree-sha1 = "746c9de7fb87a341c809437007cbd172c4d494b4"
git-tree-sha1 = "52835a83f7c899cfcb95f796d584201812887ea8"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.2"
version = "0.6.3"

[[ZygoteRules]]
deps = ["MacroTools"]
2 changes: 1 addition & 1 deletion NEWS.md
@@ -11,7 +11,7 @@
* Add [sparse initialization](https://github.com/FluxML/Flux.jl/pull/1454) as described in [Deep learning via Hessian-free optimization](https://dl.acm.org/doi/abs/10.5555/3104322.3104416).
* Moved GPU CI to use buildkite instead of GitLab
* New [`Parallel` layer](https://github.com/FluxML/Flux.jl/pull/1462) adds inception module-like building blocks.

+* Feature additions and bug fixes for BatchNorm, LayerNorm, InstanceNorm, and GroupNorm [normalization layers](https://github.com/FluxML/Flux.jl/pull/1397)

## v0.11.2

2 changes: 1 addition & 1 deletion docs/src/models/basics.md
@@ -218,4 +218,4 @@ Flux.@functor Affine

This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md).

-For some more helpful tricks, including parameter freezing, please check out the [advanced usage guide](advanced.md).
+For some more helpful tricks, including parameter freezing, please check out the [advanced usage guide](advanced.md).
21 changes: 17 additions & 4 deletions src/cuda/cudnn.jl
@@ -1,7 +1,20 @@
import CUDA.CUDNN: batchnorm, ∇batchnorm

-(BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}}, cache = nothing) where T<:Union{Float32, Float64} =
-  BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = Flux.istraining()))
+function (BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}},
+                              cache=nothing) where T<:Union{Float32, Float64}
+
+  @assert BN.affine "BatchNorm: only affine=true supported on gpu"
+  @assert BN.track_stats "BatchNorm: only track_stats=true supported on gpu"
+  @assert length(BN.β) == size(x, ndims(x)-1) "BatchNorm: input has wrong number of channels"
+  return BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum;
+                         cache=cache, alpha=1, beta=0, eps=BN.ϵ,
+                         training=Flux._isactive(BN)))
+end

-@adjoint batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
-  batchnorm(g, b, x, running_mean, running_var, momentum; kw...), Δ -> (∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kw...)..., nothing, nothing, nothing)
+@adjoint function batchnorm(g, b, x, running_mean, running_var, momentum; kw...)
+  y = batchnorm(g, b, x, running_mean, running_var, momentum; kw...)
+  function batchnorm_pullback(Δ)
+    ∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kw...)..., nothing, nothing, nothing
+  end
+  y, batchnorm_pullback
+end
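
For orientation, here is a minimal usage sketch of the GPU path above (illustrative only, not part of the commit; it assumes a CUDA-capable device and that the layer constructor accepts the `affine`/`track_stats` keywords this PR introduces elsewhere):

using Flux, CUDA

# The asserts above require affine parameters and tracked running statistics;
# both are assumed here to be the constructor defaults.
bn = BatchNorm(3, relu; affine=true, track_stats=true) |> gpu
x = CUDA.rand(Float32, 28, 28, 3, 16)   # WHCN input with 3 channels
y = bn(x)                                # dispatches to CUDNN's batchnorm

# Gradients flow through the custom batchnorm adjoint defined above.
gs = gradient(() -> sum(bn(x)), Flux.params(bn))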
4 changes: 2 additions & 2 deletions src/deprecations.jl
@@ -1,7 +1,7 @@
# v0.12 deprecations
@deprecate Dropout(p, dims) Dropout(p; dims=dims)
-@deprecate InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum) InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum, nothing)
-@deprecate BatchNorm(λ, β, γ, μ, σ², ϵ, momentum) BatchNorm(λ, β, γ, μ, σ², ϵ, momentum, nothing)
+@deprecate InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum) InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum, true, true, nothing)
+@deprecate BatchNorm(λ, β, γ, μ, σ², ϵ, momentum) BatchNorm(λ, β, γ, μ, σ², ϵ, momentum, true, true, nothing)
@deprecate GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum) GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum, nothing)
@deprecate outdims(f, inputsize) outputsize(f, inputsize)
@deprecate Conv(; weight, bias, activation=identity, kws...) Conv(weight, bias, activation; kws...)
25 changes: 14 additions & 11 deletions src/layers/basic.jl
@@ -134,32 +134,35 @@ function Base.show(io::IO, l::Dense)
end

"""
-    Diagonal(in::Integer)
+    Diagonal(α, β)
+    Diagonal(sz::Integer...; initα=ones, initβ=zeros)
-Create an element-wise linear transformation layer with learnable
-vectors `α` and `β`:
+Create an element-wise linear layer with learnable
+arrays `α` and `β` of size `sz`. The layer performs
    y = α .* x .+ β
-The input `x` must be a array where `size(x, 1) == in`.
+The input `x` must have size broadcast-compatible with `α` and `β`.
+The parameters will be created with the calls
+`α = initα(sz)` and `β = initβ(sz)`.
"""
struct Diagonal{T}
α::T
β::T
end

-Diagonal(in::Integer; initα = ones, initβ = zeros) =
-  Diagonal(initα(in), initβ(in))
+function Diagonal(sz::Integer...;
+                  initα = i -> ones(Float32, i),
+                  initβ = i -> zeros(Float32, i))
+  Diagonal(initα(sz), initβ(sz))
+end

@functor Diagonal

-function (a::Diagonal)(x)
-  α, β = a.α, a.β
-  α.*x .+ β
-end
+(a::Diagonal)(x) = a.α .* x .+ a.β

function Base.show(io::IO, l::Diagonal)
print(io, "Diagonal(", length(l.α), ")")
print(io, "Diagonal(", size(l.α), ")")
end

"""
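
For orientation, a short usage sketch of the refactored `Diagonal` layer (illustrative only, not part of the commit; it relies on the `Diagonal(sz::Integer...)` constructor and the Float32 `ones`/`zeros` defaults shown above):

using Flux

d = Flux.Diagonal(5)             # α = ones(Float32, 5), β = zeros(Float32, 5)
x = rand(Float32, 5, 3)          # features × batch
y = d(x)                         # computes d.α .* x .+ d.β element-wise

# Sizes with several dimensions are accepted as well; α and β broadcast over
# any trailing (e.g. batch) dimensions of the input.
d2 = Flux.Diagonal(5, 2)
y2 = d2(rand(Float32, 5, 2, 3))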
