diff --git a/Manifest.toml b/Manifest.toml index 926e481317..a22b079f19 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -14,9 +14,9 @@ version = "0.3.3" [[Adapt]] deps = ["LinearAlgebra"] -git-tree-sha1 = "87491f7d03ae1b423a353aff99cf61a45e3c993a" +git-tree-sha1 = "ffcfa2d345aaee0ef3d8346a073d5dd03c983ebe" uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" -version = "3.1.0" +version = "3.2.0" [[Artifacts]] deps = ["Pkg"] @@ -93,9 +93,9 @@ uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" version = "0.3.4+0" [[DataAPI]] -git-tree-sha1 = "ad84f52c0b8f05aa20839484dbaf01690b41ff84" +git-tree-sha1 = "6d64b28d291cb94a0d84e6e41081fb081e7f717f" uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" -version = "1.4.0" +version = "1.5.0" [[DataStructures]] deps = ["Compat", "InteractiveUtils", "OrderedCollections"] @@ -134,9 +134,9 @@ version = "0.1.3" [[FillArrays]] deps = ["LinearAlgebra", "Random", "SparseArrays"] -git-tree-sha1 = "8bd8e47ff5d34b20f0aa9641988eb660590008bc" +git-tree-sha1 = "50eabdace27aa27b143f65b65e762bb0112a7708" uuid = "1a297f60-69ca-5386-bcde-b61e274b549b" -version = "0.11.0" +version = "0.11.1" [[FixedPointNumbers]] deps = ["Statistics"] @@ -146,9 +146,9 @@ version = "0.8.4" [[ForwardDiff]] deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "NaNMath", "Random", "SpecialFunctions", "StaticArrays"] -git-tree-sha1 = "c26b56e9b9f0687f7ca887f6b6ded03d269e0e35" +git-tree-sha1 = "d48a40c0f54f29a5c8748cfb3225719accc72b77" uuid = "f6369f11-7733-5829-9624-2563aa707210" -version = "0.10.15" +version = "0.10.16" [[Functors]] deps = ["MacroTools"] @@ -227,9 +227,9 @@ version = "0.5.0" [[Missings]] deps = ["DataAPI"] -git-tree-sha1 = "ed61674a0864832495ffe0a7e889c0da76b0f4c8" +git-tree-sha1 = "f8c673ccc215eb50fcadb285f522420e29e69e1c" uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28" -version = "0.4.4" +version = "0.4.5" [[Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" @@ -252,9 +252,9 @@ uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e" version = "0.5.3+4" [[OrderedCollections]] -git-tree-sha1 = "cf59cfed2e2c12e8a2ff0a4f1e9b2cd8650da6db" +git-tree-sha1 = "d45739abcfc03b51f6a42712894a593f74c80a23" uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" -version = "1.3.2" +version = "1.3.3" [[Pkg]] deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] @@ -375,9 +375,9 @@ version = "1.2.11+18" [[Zygote]] deps = ["AbstractFFTs", "ChainRules", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"] -git-tree-sha1 = "746c9de7fb87a341c809437007cbd172c4d494b4" +git-tree-sha1 = "52835a83f7c899cfcb95f796d584201812887ea8" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.2" +version = "0.6.3" [[ZygoteRules]] deps = ["MacroTools"] diff --git a/NEWS.md b/NEWS.md index 8b1ef392bf..db7853995e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -11,7 +11,7 @@ * Add [sparse initialization](https://github.com/FluxML/Flux.jl/pull/1454) as described in [Deep learning via Hessian-free optimization](https://dl.acm.org/doi/abs/10.5555/3104322.3104416). * Moved GPU CI to use buildkite instead of GitLab * New [`Parallel` layer](https://github.com/FluxML/Flux.jl/pull/1462) adds inception module-like building blocks. 
- +* Feature additions and bug fixes for BatchNorm, LayerNorm, InstanceNorm, and GroupNorm [normalization layers](https://github.com/FluxML/Flux.jl/pull/1397) ## v0.11.2 diff --git a/docs/src/models/basics.md b/docs/src/models/basics.md index 9e4d0783dd..202efd7433 100644 --- a/docs/src/models/basics.md +++ b/docs/src/models/basics.md @@ -218,4 +218,4 @@ Flux.@functor Affine This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md). -For some more helpful tricks, including parameter freezing, please checkout the [advanced usage guide](advanced.md). \ No newline at end of file +For some more helpful tricks, including parameter freezing, please checkout the [advanced usage guide](advanced.md). diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl index fae297f5b7..0494672791 100644 --- a/src/cuda/cudnn.jl +++ b/src/cuda/cudnn.jl @@ -1,7 +1,20 @@ import CUDA.CUDNN: batchnorm, ∇batchnorm -(BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}}, cache = nothing) where T<:Union{Float32, Float64} = - BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = Flux.istraining())) +function (BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}}, + cache=nothing) where T<:Union{Float32, Float64} + + @assert BN.affine "BatchNorm: only affine=true supported on gpu" + @assert BN.track_stats "BatchNorm: only track_stats=true supported on gpu" + @assert length(BN.β) == size(x, ndims(x)-1) "BatchNorm: input has wrong number of channels" + return BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; + cache=cache, alpha=1, beta=0, eps=BN.ϵ, + training=Flux._isactive(BN))) +end -@adjoint batchnorm(g, b, x, running_mean, running_var, momentum; kw...) = - batchnorm(g, b, x, running_mean, running_var, momentum; kw...), Δ -> (∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kw...)..., nothing, nothing, nothing) +@adjoint function batchnorm(g, b, x, running_mean, running_var, momentum; kw...) + y = batchnorm(g, b, x, running_mean, running_var, momentum; kw...) + function batchnorm_pullback(Δ) + ∇batchnorm(g, b, x, Δ, running_mean, running_var, momentum; kw...)..., nothing, nothing, nothing + end + y, batchnorm_pullback +end diff --git a/src/deprecations.jl b/src/deprecations.jl index 1722184aa8..78eb55b733 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -1,7 +1,7 @@ # v0.12 deprecations @deprecate Dropout(p, dims) Dropout(p; dims=dims) -@deprecate InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum) InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum, nothing) -@deprecate BatchNorm(λ, β, γ, μ, σ², ϵ, momentum) BatchNorm(λ, β, γ, μ, σ², ϵ, momentum, nothing) +@deprecate InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum) InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum, true, true, nothing) +@deprecate BatchNorm(λ, β, γ, μ, σ², ϵ, momentum) BatchNorm(λ, β, γ, μ, σ², ϵ, momentum, true, true, nothing) @deprecate GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum) GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum, nothing) @deprecate outdims(f, inputsize) outputsize(f, inputsize) @deprecate Conv(; weight, bias, activation=identity, kws...) Conv(weight, bias, activation; kws...)
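The new cudnn.jl method above only covers the `affine=true`/`track_stats=true` configuration, which matches the new `BatchNorm` defaults. As a rough usage sketch of that GPU path (not part of this diff; it assumes a working CUDA.jl setup and a CUDA-capable GPU, and the layer and array sizes are arbitrary):

```julia
using Flux, CUDA  # assumes CUDA.jl is functional on this machine

bn = BatchNorm(3, relu)          # defaults: affine=true, track_stats=true
x  = rand(Float32, 8, 8, 3, 4)   # WHCN input with 3 channels

bn_gpu, x_gpu = gpu(bn), gpu(x)
y = bn_gpu(x_gpu)                # a 4D Float32 CuArray dispatches to the CUDNN method above

# A layer constructed with affine=false or track_stats=false would instead
# trip the new @assert checks on this GPU code path.
```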
diff --git a/src/layers/basic.jl b/src/layers/basic.jl index c8d905be46..141a4f4c3d 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -134,32 +134,35 @@ function Base.show(io::IO, l::Dense) end """ - Diagonal(in::Integer) + Diagonal(α, β) + Diagonal(sz::Integer...; initα=ones, initβ=zeros) -Create an element-wise linear transformation layer with learnable -vectors `α` and `β`: +Create an element-wise linear layer with learnable +arrays `α` and `β` of size `sz`. The layer performs y = α .* x .+ β -The input `x` must be a array where `size(x, 1) == in`. +The input `x` must have size broadcast-compatible with `α` and `β`. +The parameters will be created with the calls +`α = initα(sz)` and `β = initβ(sz)`. """ struct Diagonal{T} α::T β::T end -Diagonal(in::Integer; initα = ones, initβ = zeros) = - Diagonal(initα(in), initβ(in)) +function Diagonal(sz::Integer...; + initα = i -> ones(Float32, i), + initβ = i -> zeros(Float32, i)) + Diagonal(initα(sz), initβ(sz)) +end @functor Diagonal -function (a::Diagonal)(x) - α, β = a.α, a.β - α.*x .+ β -end +(a::Diagonal)(x) = a.α .* x .+ a.β function Base.show(io::IO, l::Diagonal) - print(io, "Diagonal(", length(l.α), ")") + print(io, "Diagonal(", size(l.α), ")") end """ diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 51df2e02fc..ddf5e922f9 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -18,7 +18,7 @@ for each input, either sets that input to `0` (with probability e.g. `dims=1` applies dropout along columns and `dims=2` along rows. This is used as a regularisation, i.e. it reduces overfitting during training. -If `active` is `false`, it just returns the input `x` +If `active` is `false`, it just returns the input `x`. Warning: when using this function, you have to manually manage the activation state. Usually in fact, dropout is used while training @@ -47,7 +47,7 @@ function dropout_mask(x, p; dims=:) end """ - Dropout(p, dims=:) + Dropout(p; dims=:) Dropout layer. In the forward pass, apply the [`Flux.dropout`](@ref) function on the input. @@ -70,7 +70,7 @@ end function (a::Dropout)(x) _isactive(a) || return x - return dropout(x, a.p; dims = a.dims, active=true) + return dropout(x, a.p; dims=a.dims, active=true) end testmode!(m::Dropout, mode=true) = @@ -108,52 +108,116 @@ function (a::AlphaDropout)(x) α1 = eltype(x)(-λ*α) noise = randn(eltype(x), size(x)) x = @. x*(noise > (1 - a.p)) + α1 * (noise < (1 - a.p)) - A = (a.p + a.p * (1 - a.p) * α1 ^ 2)^0.5 + A = sqrt(a.p + a.p * (1 - a.p) * α1^2) B = -A * α1 * (1 - a.p) x = @. A * x + B return x end -testmode!(m::AlphaDropout, mode = true) = +testmode!(m::AlphaDropout, mode=true) = (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m) """ - LayerNorm(h::Integer) + LayerNorm(sz, λ=identity; affine=true, ϵ=1f-5) A [normalisation layer](https://arxiv.org/abs/1607.06450) designed to be -used with recurrent hidden states of size `h`. Normalises the mean and standard -deviation of each input before applying a per-neuron gain/bias. +used with recurrent hidden states. +The argument `sz` should be an integer or a tuple of integers. +In the forward pass, the layer normalises the mean and standard +deviation of the input, then applies the elementwise activation `λ`. +The input is normalised along the first `length(sz)` dimensions +for tuple `sz`, along the first dimension for integer `sz`. +The input is expected to have first dimensions' size equal to `sz`.
+ +If `affine=true`, it also applies a learnable shift and rescaling +as in the [`Diagonal`](@ref) layer. + + +See also [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`normalise`](@ref). """ -struct LayerNorm{T} - diag::Diagonal{T} +struct LayerNorm{F,D,T,N} + λ::F + diag::D + ϵ::T + size::NTuple{N,Int} + affine::Bool end -LayerNorm(h::Integer) = - LayerNorm(Diagonal(h)) +function LayerNorm(sz, λ=identity; affine=true, ϵ=1f-5) + sz = sz isa Integer ? (sz,) : sz + diag = affine ? Diagonal(sz...) : nothing + return LayerNorm(λ, diag, ϵ, sz, affine) +end @functor LayerNorm -(a::LayerNorm)(x) = a.diag(normalise(x, dims=1)) +function (a::LayerNorm)(x) + x = normalise(x, dims=1:length(a.size), ϵ=a.ϵ) + a.diag === nothing ? a.λ.(x) : a.λ.(a.diag(x)) +end function Base.show(io::IO, l::LayerNorm) - print(io, "LayerNorm(", length(l.diag.α), ")") + print(io, "LayerNorm($(l.size)") + l.λ == identity || print(io, ", $(l.λ)") + hasaffine(l) || print(io, ", affine=false") + print(io, ")") +end + +# For InstanceNorm, GroupNorm, and BatchNorm. +# Compute the statistics on the slices specified by reduce_dims. +# reduce_dims=[1,...,N-2,N] for BatchNorm +# reduce_dims=[1,...,N-2] for InstanceNorm and GroupNorm +function _norm_layer_forward(l, x::AbstractArray{T,N}; reduce_dims, affine_shape) where {T, N} + if !_isactive(l) && l.track_stats # testmode with tracked stats + stats_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N) + μ = reshape(l.μ, stats_shape) + σ² = reshape(l.σ², stats_shape) + else # trainmode or testmode without tracked stats + μ = mean(x; dims=reduce_dims) + σ² = mean((x .- μ).^2; dims=reduce_dims) + if l.track_stats + ## update moving mean/std + Zygote.ignore() do + mtm = l.momentum + m = prod(size(x, i) for i in reduce_dims) # needed for computing corrected var + μnew = vec(N ∈ reduce_dims ? μ : mean(μ, dims=N)) + σ²new = vec(N ∈ reduce_dims ? σ² : mean(σ², dims=N)) + l.μ = (1-mtm) .* l.μ .+ mtm .* μnew + l.σ² = (1-mtm) .* l.σ² .+ mtm .* (m / (m - one(eltype(l.σ²)))) .* σ²new + end + end + end + if hasaffine(l) + γ = reshape(l.γ, affine_shape) + β = reshape(l.β, affine_shape) + return l.λ.(γ .* (x .- μ) ./ sqrt.(σ² .+ l.ϵ) .+ β) + else + return l.λ.((x .- μ) ./ sqrt.(σ² .+ l.ϵ)) + end end """ - BatchNorm(channels::Integer, σ = identity; - initβ = zeros, initγ = ones, - ϵ = 1e-8, momentum = .1) + BatchNorm(channels::Integer, λ=identity; + initβ=zeros, initγ=ones, + ϵ=1f-5, momentum=0.1f0) [Batch Normalization](https://arxiv.org/abs/1502.03167) layer. `channels` should be the size of the channel dimension in your data (see below). -Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For +Given an array with `N` dimensions, call the `N-1`th the channel dimension. For a batch of feature vectors this is just the data dimension, for `WHCN` images -it's the usual channel dimension.) +it's the usual channel dimension. + +`BatchNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×D_N` +input slice and normalises the input accordingly. + +If `affine=true`, it also applies a shift and a rescale to the input +through learnable per-channel bias β and scale γ parameters. -`BatchNorm` computes the mean and variance for each each `W×H×1×N` slice and -shifts them to have a new mean and variance (corresponding to the learnable, -per-channel `bias` and `scale` parameters). +After normalisation, elementwise activation `λ` is applied.
+ +If `track_stats=true`, accumulates mean and var statistics in training phase +that will be used to renormalize the input in test phase. Use [`testmode!`](@ref) during inference. @@ -167,170 +231,142 @@ m = Chain( softmax) ``` """ -mutable struct BatchNorm{F,V,W,N} +mutable struct BatchNorm{F,V,N,W} + chs::Int # number of channels λ::F # activation function β::V # bias γ::V # scale - μ::W # moving mean - σ²::W # moving std + μ::W # moving mean + σ²::W # moving var ϵ::N momentum::N + affine::Bool + track_stats::Bool active::Union{Bool, Nothing} end -BatchNorm(chs::Integer, λ = identity; - initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) = - BatchNorm(λ, initβ(chs), initγ(chs), - zeros(chs), ones(chs), ϵ, momentum, nothing) +function BatchNorm(chs::Int, λ=identity; + initβ = i -> zeros(Float32, i), + initγ = i -> ones(Float32, i), + affine=true, track_stats=true, + ϵ=1f-5, momentum=0.1f0) -trainable(bn::BatchNorm) = (bn.β, bn.γ) - -function (BN::BatchNorm)(x) - size(x, ndims(x)-1) == length(BN.β) || - error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))") - dims = length(size(x)) - channels = size(x, dims-1) - affine_shape = ntuple(i->i == ndims(x) - 1 ? size(x, i) : 1, ndims(x)) - m = div(prod(size(x)), channels) - γ = reshape(BN.γ, affine_shape...) - β = reshape(BN.β, affine_shape...) - if !_isactive(BN) - μ = reshape(BN.μ, affine_shape...) - σ² = reshape(BN.σ², affine_shape...) - ϵ = BN.ϵ - else - T = eltype(x) - axes = [1:dims-2; dims] # axes to reduce along (all but channels axis) - μ = mean(x, dims = axes) - σ² = sum((x .- μ) .^ 2, dims = axes) ./ m - ϵ = convert(T, BN.ϵ) - # update moving mean/std - mtm = BN.momentum - S = eltype(BN.μ) - BN.μ = (1 - mtm) .* BN.μ .+ mtm .* S.(reshape(μ, :)) - BN.σ² = (1 - mtm) .* BN.σ² .+ (mtm * m / (m - 1)) .* S.(reshape(σ², :)) - end + β = affine ? initβ(chs) : nothing + γ = affine ? initγ(chs) : nothing + μ = track_stats ? zeros(Float32, chs) : nothing + σ² = track_stats ? ones(Float32, chs) : nothing - let λ = BN.λ - x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ) - λ.(γ .* x̂ .+ β) - end + return BatchNorm(chs, λ, β, γ, + μ, σ², ϵ, momentum, + affine, track_stats, nothing) end @functor BatchNorm +trainable(bn::BatchNorm) = hasaffine(bn) ? (bn.β, bn.γ) : () -testmode!(m::BatchNorm, mode = true) = +function (BN::BatchNorm)(x) + @assert size(x, ndims(x)-1) == BN.chs + N = ndims(x) + reduce_dims = [1:N-2; N] + affine_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N) + return _norm_layer_forward(BN, x; reduce_dims, affine_shape) +end + +testmode!(m::BatchNorm, mode=true) = (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m) function Base.show(io::IO, l::BatchNorm) - print(io, "BatchNorm($(join(size(l.β), ", "))") - (l.λ == identity) || print(io, ", λ = $(l.λ)") + print(io, "BatchNorm($(l.chs)") + l.λ == identity || print(io, ", $(l.λ)") + hasaffine(l) || print(io, ", affine=false") print(io, ")") end -expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...) -mutable struct InstanceNorm{F,V,W,N} +""" + InstanceNorm(channels::Integer, λ=identity; + initβ=zeros, initγ=ones, + affine=false, track_stats=false, + ϵ=1f-5, momentum=0.1f0) + +[Instance Normalization](https://arxiv.org/abs/1607.08022) layer. +`channels` should be the size of the channel dimension in your data (see below). + +Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension. +For `WHCN` images it's the usual channel dimension. 
+ +`InstanceNorm` computes the mean and variance for each `D_1×...×D_{N-2}×1×1` +input slice and normalises the input accordingly. + +If `affine=true`, it also applies a shift and a rescale to the input +through learnable per-channel bias `β` and scale `γ` parameters. + +If `track_stats=true`, accumulates mean and var statistics in training phase +that will be used to renormalize the input in test phase. + +**Warning**: the defaults for `affine` and `track_stats` used to be `true` +in previous Flux versions (< v0.12). +""" +mutable struct InstanceNorm{F,V,N,W} + chs::Int # number of channels λ::F # activation function β::V # bias γ::V # scale μ::W # moving mean - σ²::W # moving std + σ²::W # moving var ϵ::N momentum::N + affine::Bool + track_stats::Bool active::Union{Bool, Nothing} end -""" - InstanceNorm(channels::Integer, σ = identity; - initβ = zeros, initγ = ones, - ϵ = 1e-8, momentum = .1) +function InstanceNorm(chs::Int, λ=identity; + initβ = i -> zeros(Float32, i), + initγ = i -> ones(Float32, i), + affine=false, track_stats=false, + ϵ=1f-5, momentum=0.1f0) -[Instance Normalization](https://arxiv.org/abs/1607.08022) layer. -`channels` should be the size of the channel dimension in your data (see below). - -Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For -a batch of feature vectors this is just the data dimension, for `WHCN` images -it's the usual channel dimension.) + β = affine ? initβ(chs) : nothing + γ = affine ? initγ(chs) : nothing + μ = track_stats ? zeros(Float32, chs) : nothing + σ² = track_stats ? ones(Float32, chs) : nothing -`InstanceNorm` computes the mean and variance for each each `W×H×1×1` slice and -shifts them to have a new mean and variance (corresponding to the learnable, -per-channel `bias` and `scale` parameters). - -Use [`testmode!`](@ref) during inference. - -# Examples -```julia -m = Chain( - Dense(28^2, 64), - InstanceNorm(64, relu), - Dense(64, 10), - InstanceNorm(10), - softmax) -``` -""" -InstanceNorm(chs::Integer, λ = identity; - initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) = - InstanceNorm(λ, initβ(chs), initγ(chs), - zeros(chs), ones(chs), ϵ, momentum, nothing) - -trainable(in::InstanceNorm) = (in.β, in.γ) - -function (in::InstanceNorm)(x) - size(x, ndims(x)-1) == length(in.β) || - error("InstanceNorm expected $(length(in.β)) channels, got $(size(x, ndims(x)-1))") - ndims(x) > 2 || - error("InstanceNorm requires at least 3 dimensions. With 2 dimensions an array of zeros would be returned") - # these are repeated later on depending on the batch size - dims = length(size(x)) - c = size(x, dims-1) - bs = size(x, dims) - affine_shape = ntuple(i->i == ndims(x) - 1 || i == ndims(x) ?
size(x, i) : 1, ndims(x)) - m = div(prod(size(x)), c*bs) - γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape) - - if !_isactive(in) - μ = expand_inst(in.μ, affine_shape) - σ² = expand_inst(in.σ², affine_shape) - ϵ = in.ϵ - else - T = eltype(x) - - ϵ = convert(T, in.ϵ) - axes = 1:dims-2 # axes to reduce along (all but channels and batch size axes) - μ = mean(x, dims = axes) - σ² = mean((x .- μ) .^ 2, dims = axes) - S = eltype(in.μ) - # update moving mean/std - mtm = in.momentum - in.μ = dropdims(mean(repeat((1 - mtm) .* in.μ, outer=[1, bs]) .+ mtm .* S.(reshape(μ, (c, bs))), dims = 2), dims=2) - in.σ² = dropdims(mean((repeat((1 - mtm) .* in.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* S.(reshape(σ², (c, bs)))), dims = 2), dims=2) - end - - let λ = in.λ - x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ) - λ.(γ .* x̂ .+ β) - end + return InstanceNorm(chs, λ, β, γ, + μ, σ², ϵ, momentum, + affine, track_stats, nothing) end @functor InstanceNorm +trainable(in::InstanceNorm) = hasaffine(in) ? (in.β, in.γ) : () + +function (l::InstanceNorm)(x) + @assert ndims(x) > 2 + @assert size(x, ndims(x)-1) == l.chs + N = ndims(x) + reduce_dims = 1:N-2 + affine_shape = ntuple(i -> i == N-1 ? size(x, N-1) : 1, N) + return _norm_layer_forward(l, x; reduce_dims, affine_shape) +end -testmode!(m::InstanceNorm, mode = true) = +testmode!(m::InstanceNorm, mode=true) = (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m) function Base.show(io::IO, l::InstanceNorm) - print(io, "InstanceNorm($(join(size(l.β), ", "))") - (l.λ == identity) || print(io, ", λ = $(l.λ)") + print(io, "InstanceNorm($(l.chs)") + l.λ == identity || print(io, ", $(l.λ)") + hasaffine(l) || print(io, ", affine=false") print(io, ")") end """ - GroupNorm(chs::Integer, G::Integer, λ = identity; - initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), - ϵ = 1f-5, momentum = 0.1f0) + GroupNorm(channels::Integer, G::Integer, λ=identity; + initβ = (i) -> zeros(Float32, i), + initγ = (i) -> ones(Float32, i), + affine=true, track_stats=false, + ϵ=1f-5, momentum=0.1f0) [Group Normalization](https://arxiv.org/abs/1803.08494) layer. -This layer can outperform Batch Normalization and Instance Normalization. `chs` is the number of channels, the channel dimension of your input. For an array of N dimensions, the `N-1`th index is the channel dimension. @@ -338,90 +374,84 @@ For an array of N dimensions, the `N-1`th index is the channel dimension. `G` is the number of groups along which the statistics are computed. The number of channels must be an integer multiple of the number of groups. -Use [`testmode!`](@ref) during inference. +`channels` should be the size of the channel dimension in your data (see below). -# Examples -```julia -m = Chain(Conv((3,3), 1=>32, leakyrelu;pad = 1), - GroupNorm(32,16)) - # 32 channels, 16 groups (G = 16), thus 2 channels per group used -``` +Given an array with `N > 2` dimensions, call the `N-1`th the channel dimension. +For `WHCN` images it's the usual channel dimension. + +If `affine=true`, it also applies a shift and a rescale to the input +through to learnable per-channel bias `β` and scale `γ` parameters. + +If `track_stats=true`, accumulates mean and var statistics in training phase +that will be used to renormalize the input in test phase. 
""" -mutable struct GroupNorm{F,V,W,N,T} - G::T # number of groups +mutable struct GroupNorm{F,V,N,W} + chs::Int # number of channels + G::Int # number of groups λ::F # activation function β::V # bias γ::V # scale - μ::W # moving mean - σ²::W # moving std + μ::W # moving mean + σ²::W # moving std ϵ::N momentum::N + affine::Bool + track_stats::Bool active::Union{Bool, Nothing} end -GroupNorm(chs::Integer, G::Integer, λ = identity; - initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) = - GroupNorm(G, λ, initβ(chs), initγ(chs), - zeros(G,1), ones(G,1), ϵ, momentum, nothing) - -trainable(gn::GroupNorm) = (gn.β, gn.γ) - -function(gn::GroupNorm)(x) - size(x,ndims(x)-1) == length(gn.β) || error("Group Norm expected $(length(gn.β)) channels, but got $(size(x,ndims(x)-1)) channels") - ndims(x) > 2 || error("Need to pass at least 3 channels for Group Norm to work") - (size(x,ndims(x) -1))%gn.G == 0 || error("The number of groups ($(gn.G)) must divide the number of channels ($(size(x,ndims(x) -1)))") - - dims = length(size(x)) - groups = gn.G - channels = size(x, dims-1) - batches = size(x,dims) - channels_per_group = div(channels,groups) - affine_shape = ntuple(i->i == ndims(x) - 1 ? size(x, i) : 1, ndims(x)) - - # Output reshaped to (W,H...,C/G,G,N) - μ_affine_shape = ntuple(i->i == ndims(x) ? groups : 1, ndims(x) + 1) - - m = prod(size(x)[1:end-2]) * channels_per_group - γ = reshape(gn.γ, affine_shape...) - β = reshape(gn.β, affine_shape...) - - y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches)) - if !_isactive(gn) - og_shape = size(x) - μ = reshape(gn.μ, μ_affine_shape...) # Shape : (1,1,...C/G,G,1) - σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1) - ϵ = gn.ϵ - else - T = eltype(x) - og_shape = size(x) - axes = [(1:ndims(y)-2)...] # axes to reduce along (all but channels axis) - μ = mean(y, dims = axes) - σ² = mean((y .- μ) .^ 2, dims = axes) - - ϵ = convert(T, gn.ϵ) - # update moving mean/std - mtm = gn.momentum - S = eltype(gn.μ) - gn.μ = mean((1 - mtm) .* gn.μ .+ mtm .* S.(reshape(μ, (groups,batches))),dims=2) - gn.σ² = mean((1 - mtm) .* gn.σ² .+ (mtm * m / (m - 1)) .* S.(reshape(σ², (groups,batches))),dims=2) - end - - let λ = gn.λ - x̂ = (y .- μ) ./ sqrt.(σ² .+ ϵ) - - # Reshape x̂ - x̂ = reshape(x̂,og_shape) - λ.(γ .* x̂ .+ β) - end +@functor GroupNorm +trainable(gn::GroupNorm) = hasaffine(gn) ? (gn.β, gn.γ) : () + +function GroupNorm(chs::Int, G::Int, λ=identity; + initβ = (i) -> zeros(Float32, i), + initγ = (i) -> ones(Float32, i), + affine=true, track_stats=false, + ϵ=1f-5, momentum=0.1f0) + + chs % G == 0 || error("The number of groups ($(G)) must divide the number of channels ($chs)") + + β = affine ? initβ(chs) : nothing + γ = affine ? initγ(chs) : nothing + μ = track_stats ? zeros(Float32, G) : nothing + σ² = track_stats ? ones(Float32, G) : nothing + + return GroupNorm(chs, G, λ, + β, γ, + μ, σ², + ϵ, momentum, + affine, track_stats, nothing) end -@functor GroupNorm +function (gn::GroupNorm)(x) + @assert ndims(x) > 2 + @assert size(x, ndims(x)-1) == gn.chs + N = ndims(x) + sz = size(x) + x = reshape(x, sz[1:N-2]..., sz[N-1]÷gn.G, gn.G, sz[N]) + N = ndims(x) + reduce_dims = 1:N-2 + affine_shape = ntuple(i -> i ∈ (N-1, N-2) ? size(x, i) : 1, N) + x = _norm_layer_forward(gn, x; reduce_dims, affine_shape) + return reshape(x, sz) +end testmode!(m::GroupNorm, mode = true) = (m.active = (isnothing(mode) || mode == :auto) ? 
nothing : !mode; m) function Base.show(io::IO, l::GroupNorm) - print(io, "GroupNorm($(join(size(l.β), ", "))") - (l.λ == identity) || print(io, ", λ = $(l.λ)") + print(io, "GroupNorm($(l.chs), $(l.G)") + l.λ == identity || print(io, ", $(l.λ)") + hasaffine(l) || print(io, ", affine=false") print(io, ")") end + +""" + hasaffine(l) + +Return `true` if a normalisation layer has trainable shift and +scale parameters, `false` otherwise. + +See [`BatchNorm`](@ref), [`InstanceNorm`](@ref), [`GroupNorm`](@ref), and [`LayerNorm`](@ref). +""" +hasaffine(l::Union{BatchNorm, InstanceNorm, LayerNorm, GroupNorm}) = l.affine \ No newline at end of file diff --git a/test/cuda/cudnn.jl b/test/cuda/cudnn.jl index 37a409a2a2..5d1727e984 100644 --- a/test/cuda/cudnn.jl +++ b/test/cuda/cudnn.jl @@ -3,42 +3,42 @@ using Flux: pullback @testset "CUDNN BatchNorm" begin @testset "4D Input" begin - x = Float64.(collect(reshape(1:12, 2, 2, 3, 1))) + x = rand(Float32, 2, 2, 3, 4) m = BatchNorm(3) - cx = gpu(x) - cm = gpu(m) + gx = gpu(x) + gm = gpu(m) y, back = pullback((m, x) -> m(x), m, x) - cy, cback = pullback((m, x) -> m(x), cm, cx) + gy, gback = pullback((m, x) -> m(x), gm, gx) - @test cpu(cy) ≈ y + @test cpu(gy) ≈ y - Δ = randn(size(y)) + Δ = randn(Float32, size(y)) dm, dx = back(Δ) - cdm, cdx = cback(gpu(Δ)) + gdm, gdx = gback(gpu(Δ)) - @test dm[].γ ≈ cpu(cdm[].γ) - @test dm[].β ≈ cpu(cdm[].β) - @test dx ≈ cpu(cdx) + @test dm[].γ ≈ cpu(gdm[].γ) + @test dm[].β ≈ cpu(gdm[].β) + @test dx ≈ cpu(gdx) end @testset "2D Input" begin - x = Float64.(collect(reshape(1:12, 3, 4))) + x = rand(Float32, 3, 4) m = BatchNorm(3) - cx = gpu(x) - cm = gpu(m) + gx = gpu(x) + gm = gpu(m) y, back = pullback((m, x) -> m(x), m, x) - cy, cback = pullback((m, x) -> m(x), cm, cx) + gy, gback = pullback((m, x) -> m(x), gm, gx) - @test cpu(cy) ≈ y + @test cpu(gy) ≈ y - Δ = randn(size(y)) + Δ = randn(Float32, size(y)) dm, dx = back(Δ) - cdm, cdx = cback(gpu(Δ)) + gdm, gdx = gback(gpu(Δ)) - @test dm[].γ ≈ cpu(cdm[].γ) - @test dm[].β ≈ cpu(cdm[].β) - @test dx ≈ cpu(cdx) + @test dm[].γ ≈ cpu(gdm[].γ) + @test dm[].β ≈ cpu(gdm[].β) + @test dx ≈ cpu(gdx) end end diff --git a/test/cuda/layers.jl b/test/cuda/layers.jl index 233a4d69e4..8057aa384d 100644 --- a/test/cuda/layers.jl +++ b/test/cuda/layers.jl @@ -14,48 +14,25 @@ end # `AlphaDropout` throws a compilation error on GPUs, # whereas, the rest are scalar indexing issues. const BROKEN_LAYERS = Union{DepthwiseConv, - AlphaDropout, - InstanceNorm, - GroupNorm} + AlphaDropout} -function gpu_gradtest(name::String, layers::Vector, x_cpu=nothing, args...; test_cpu=true) - isnothing(x_cpu) && error("Missing input to test the layers against.") +function gpu_gradtest(name::String, layers::Vector, x_cpu, args...; + setmode=false, test_cpu=true, rtol=1e-5, atol=1e-5) @testset "$name GPU grad tests" begin for layer in layers @testset "$layer GPU grad test" begin - - # compute output and grad of parameters l_cpu = layer(args...) 
- ps_cpu = Flux.params(l_cpu) - y_cpu, back_cpu = pullback(() -> sum(l_cpu(x_cpu)), ps_cpu) - gs_cpu = back_cpu(1f0) - - x_gpu = gpu(x_cpu) - l_gpu = l_cpu |> gpu - ps_gpu = Flux.params(l_gpu) - - if l_gpu isa BROKEN_LAYERS - @test_broken gradient(() -> sum(l_gpu(x_gpu)), ps_gpu) isa Flux.Zygote.Grads + if l_cpu isa BROKEN_LAYERS + l_gpu, x_gpu = l_cpu |> gpu, x_cpu |> gpu + @test_broken gradient(() -> sum(l_gpu(x_gpu)), Flux.params(l_gpu)) isa Flux.Zygote.Grads else - y_gpu, back_gpu = pullback(() -> sum(l_gpu(x_gpu)), ps_gpu) - gs_gpu = back_gpu(1f0) # TODO many layers error out when backprop int 1, should fix - - # compute grad of input - xg_cpu = gradient(x -> sum(l_cpu(x)), x_cpu)[1] - xg_gpu = gradient(x -> sum(l_gpu(x)), x_gpu)[1] - - # test - if test_cpu - @test y_gpu ≈ y_cpu rtol=1e-4 atol=1e-4 - @test Array(xg_gpu) ≈ xg_cpu rtol=1e-4 atol=1e-4 - end - @test gs_gpu isa Flux.Zygote.Grads - for (p_cpu, p_gpu) in zip(ps_cpu, ps_gpu) - @test gs_gpu[p_gpu] isa Flux.CUDA.CuArray - if test_cpu - @test Array(gs_gpu[p_gpu]) ≈ gs_cpu[p_cpu] rtol=1e-4 atol=1e-4 - end - end + gpu_autodiff_test(l_cpu, x_cpu, + test_equal=test_cpu, rtol=rtol, atol=atol) + if setmode + testmode!(l_cpu) + gpu_autodiff_test(l_cpu, x_cpu, + test_equal=test_cpu, rtol=rtol, atol=atol) + end end end end @@ -67,7 +44,7 @@ end ConvNoBias(args...) = Conv(args...; bias=false) ConvTransposeNoBias(args...) = ConvTranspose(args...; bias=false) CrossCorNoBias(args...) = CrossCor(args...; bias=false) -DepthwiseConvNoBias(args...) = DepthwiseConv(args...;bias=false) +DepthwiseConvNoBias(args...) = DepthwiseConv(args...; bias=false) r = rand(Float32, 28, 28, 1, 1) conv_layers = [Conv, ConvNoBias, ConvTranspose, ConvTransposeNoBias, CrossCor, CrossCorNoBias, DepthwiseConv, DepthwiseConvNoBias] gpu_gradtest("Conv", conv_layers, r, (2,2), 1=>3) @@ -79,27 +56,96 @@ adaptive_pooling_layers = [AdaptiveMaxPool, AdaptiveMeanPool] gpu_gradtest("AdaptivePooling", adaptive_pooling_layers, r, (7,7)) dropout_layers = [Dropout, AlphaDropout] -gpu_gradtest("Dropout", dropout_layers, r, 0.5f0; test_cpu=false) # dropout is not deterministic +gpu_gradtest("Dropout", dropout_layers, r, 0.5f0; test_cpu=false, setmode=true) # dropout is not deterministic -layer_norm = [LayerNorm] -gpu_gradtest("LayerNorm 1", layer_norm, rand(Float32, 28,28,3,4), 1, test_cpu=false) #TODO fix errors -gpu_gradtest("LayerNorm 2", layer_norm, rand(Float32, 5,4), 5) +layer_norm = [i -> LayerNorm(i; affine=false), i -> LayerNorm(i; affine=true)] +gpu_gradtest("LayerNorm 1", layer_norm, rand(Float32, 8, 8, 3, 4), 8) +gpu_gradtest("LayerNorm 2", layer_norm, rand(Float32, 8, 8, 3, 4), (8,8)) +gpu_gradtest("LayerNorm 3", layer_norm, rand(Float32, 5, 4), 5) batch_norm = [BatchNorm] -gpu_gradtest("BatchNorm 1", batch_norm, rand(Float32, 28,28,3,4), 3, test_cpu=false) #TODO fix errors -gpu_gradtest("BatchNorm 2", batch_norm, rand(Float32, 5,4), 5) +gpu_gradtest("BatchNorm 3d", batch_norm, rand(Float32, 8, 8, 8, 3, 4), 3, setmode=false) # bug in CUDA.jl with gradient in testmode +gpu_gradtest("BatchNorm 2d", batch_norm, rand(Float32, 8, 8, 3, 4), 3, setmode=false) # bug in CUDA.jl with gradient in testmode +gpu_gradtest("BatchNorm 1d", batch_norm, rand(Float32, 8, 3, 4), 3, setmode=false) # bug in CUDA.jl with gradient in testmode +gpu_gradtest("BatchNorm fullyconn", batch_norm, rand(Float32, 5,4), 5, setmode=false) -instancenorm = [InstanceNorm] -gpu_gradtest("InstanceNorm", instancenorm, r, 1) +instancenorm = [i -> InstanceNorm(i; affine=false), i -> InstanceNorm(i; 
affine=true)] +gpu_gradtest("InstanceNorm 3d", instancenorm, rand(Float32, 8, 8, 8, 3, 4), 3, setmode=true) +gpu_gradtest("InstanceNorm 2d", instancenorm, rand(Float32, 8, 8, 3, 4), 3, setmode=true) +gpu_gradtest("InstanceNorm 1d", instancenorm, rand(Float32, 8, 3, 4), 3, setmode=true) -groupnorm = [GroupNorm] -gpu_gradtest("GroupNorm", groupnorm, rand(Float32, 28,28,3,1), 3, 1) +groupnorm = [(i, j) -> GroupNorm(i, j; affine=false), (i, j) -> GroupNorm(i, j; affine=true)] +gpu_gradtest("GroupNorm 3d", groupnorm, rand(Float32, 8, 8, 8, 12, 4), 12, 3, setmode=true) +gpu_gradtest("GroupNorm 2d", groupnorm, rand(Float32, 8, 8, 12, 4), 12, 3, setmode=true) +gpu_gradtest("GroupNorm 1d", groupnorm, rand(Float32, 8, 3, 12, 4), 12, 3, setmode=true) @testset "function layers" begin - x = rand(3,3) - gpu_gradtest(x -> sum(Flux.normalise(x; dims=1)), x) - gpu_gradtest(x -> sum(Flux.normalise(x; dims=2)), x) - gpu_gradtest(x -> sum(Flux.normalise(x)), x) + x = rand(Float32, 3,3) + gpu_autodiff_test(x -> sum(Flux.normalise(x; dims=1)), x) + gpu_autodiff_test(x -> sum(Flux.normalise(x; dims=2)), x) + gpu_autodiff_test(x -> sum(Flux.normalise(x)), x) +end + +@testset "BatchNorm mix stuff" begin + m_cpu = BatchNorm(2) + m_gpu = m_cpu |> gpu + x_cpu = rand(Float32, 3, 2, 2) + x_gpu = x_cpu |> gpu + + ## In :auto mode, track statistics only in gradient context + μ_cpu = copy(m_cpu.μ) + m_cpu(x_cpu) + @test m_cpu.μ ≈ μ_cpu + gradient(() -> sum(m_cpu(x_cpu)), Flux.params(m_cpu)) + @test !(m_cpu.μ ≈ μ_cpu) + + μ_gpu = copy(m_gpu.μ) + m_gpu(x_gpu) + @test m_gpu.μ ≈ μ_gpu + gradient(() -> sum(m_gpu(x_gpu)), Flux.params(m_gpu)) + @test !(m_gpu.μ ≈ μ_gpu) + + @test Array(m_gpu.μ) ≈ m_cpu.μ + + ## In testmode, never track statistics + testmode!(m_cpu) + μ_cpu = copy(m_cpu.μ) + m_cpu(x_cpu) + @test m_cpu.μ ≈ μ_cpu + gradient(() -> sum(m_cpu(x_cpu)), Flux.params(m_cpu)) + @test m_cpu.μ ≈ μ_cpu + + testmode!(m_gpu) + μ_gpu = copy(m_gpu.μ) + m_gpu(x_gpu) + @test m_gpu.μ ≈ μ_gpu + gradient(() -> sum(m_gpu(x_gpu)), Flux.params(m_gpu)) + @test m_gpu.μ ≈ μ_gpu + + ## In trainmode, always track statistics + trainmode!(m_cpu) + μ_cpu = copy(m_cpu.μ) + m_cpu(x_cpu) + @test !(m_cpu.μ ≈ μ_cpu) + μ_cpu = copy(m_cpu.μ) + gradient(() -> sum(m_cpu(x_cpu)), Flux.params(m_cpu)) + @test !(m_cpu.μ ≈ μ_cpu) + + trainmode!(m_gpu) + μ_gpu = copy(m_gpu.μ) + m_gpu(x_gpu) + @test !(m_gpu.μ ≈ μ_gpu) + μ_gpu = copy(m_gpu.μ) + gradient(() -> sum(m_gpu(x_gpu)), Flux.params(m_gpu)) + @test !(m_gpu.μ ≈ μ_gpu) + + ## No errors if input type mismatch + x_cpu = rand(Float64, 3, 2, 2) + x_gpu = x_cpu |> gpu + m_cpu(x_cpu) + gradient(() -> sum(m_cpu(x_cpu)), Flux.params(m_cpu)) + m_gpu(x_gpu) + gradient(() -> sum(m_gpu(x_gpu)), Flux.params(m_gpu)) end @testset "Zeros mapped for $cl" for cl in (Conv, ConvTranspose, CrossCor, DepthwiseConv) diff --git a/test/cuda/losses.jl b/test/cuda/losses.jl index 83377c3a24..0913b0eb6a 100644 --- a/test/cuda/losses.jl +++ b/test/cuda/losses.jl @@ -20,7 +20,7 @@ y = [1, 1, 0.] y = rand(Float32, 3,3) for loss in ALL_LOSSES - gpu_gradtest(loss, x, y) + gpu_autodiff_test(loss, x, y) end end diff --git a/test/cuda/runtests.jl b/test/cuda/runtests.jl index 8c19caed59..6a6722b238 100644 --- a/test/cuda/runtests.jl +++ b/test/cuda/runtests.jl @@ -4,22 +4,7 @@ using Zygote: pullback @info "Testing GPU Support" CUDA.allowscalar(false) - -function gpu_gradtest(f, args...) - args_gpu = gpu.(args) - - l_cpu, back_cpu = pullback((x...) -> f(x...), args...) - g_cpu = back_cpu(1f0)[1] - - l_gpu, back_gpu = pullback((x...)
-> f(x...), args_gpu...) - g_gpu = back_gpu(1f0)[1] - - @test l_cpu ≈ l_gpu rtol=1e-4 atol=1e-4 - @test g_gpu isa CuArray - @test g_cpu ≈ collect(g_gpu) rtol=1e-4 atol=1e-4 -end - - +include("test_utils.jl") include("cuda.jl") include("losses.jl") include("layers.jl") diff --git a/test/cuda/test_utils.jl b/test/cuda/test_utils.jl new file mode 100644 index 0000000000..bc0db37474 --- /dev/null +++ b/test/cuda/test_utils.jl @@ -0,0 +1,72 @@ +function check_grad(g_gpu, g_cpu, atol, rtol) + @show g_gpu g_cpu + @test false +end +check_grad(g_gpu::Base.RefValue, g_cpu::Base.RefValue, atol, rtol) = + check_grad(g_gpu[], g_cpu[], atol, rtol) +check_grad(g_gpu::Nothing, g_cpu::Nothing, atol, rtol) = @test true +check_grad(g_gpu::Float32, g_cpu::Float32, atol, rtol) = @test g_cpu ≈ g_gpu rtol=rtol atol=atol +check_grad(g_gpu::CuArray{Float32}, g_cpu::Array{Float32}, atol, rtol) = + @test g_cpu ≈ collect(g_gpu) rtol=rtol atol=atol + +function check_grad(g_gpu::Tuple, g_cpu::Tuple, atol, rtol) + for (v1, v2) in zip(g_gpu, g_cpu) + check_grad(v1, v2, atol, rtol) + end +end + +function check_grad(g_gpu::NamedTuple, g_cpu::NamedTuple, atol, rtol) + for ((k1,v1), (k2,v2)) in zip(pairs(g_gpu), pairs(g_cpu)) + @test k1 == k2 + # @show k2 v2 + check_grad(v1, v2, atol, rtol) + end +end + +function gpu_autodiff_test(f_cpu, xs_cpu::Array{Float32}...; + test_equal=true, rtol=1e-4, atol=1e-4) + + check_type(x) = false + check_type(x::Float32) = true + check_type(x::CuArray{Float32}) = true + check_type(x::Array{Float32}) = true + + ### GRADIENT WITH RESPECT TO INPUT ##### + # y_cpu, back_cpu = pullback((f, x...) -> f(x...), f_cpu, xs_cpu...) + y_cpu, back_cpu = pullback((x...) -> f_cpu(x...), xs_cpu...) + @test check_type(y_cpu) + Δ_cpu = size(y_cpu) == () ? randn(Float32) : randn(Float32, size(y_cpu)) + gs_cpu = back_cpu(Δ_cpu) + + f_gpu = f_cpu |> gpu + xs_gpu = gpu.(xs_cpu) + Δ_gpu = Δ_cpu |> gpu + # y_gpu, back_gpu = pullback((f, x...) -> f(x...), f_gpu, xs_gpu...) + y_gpu, back_gpu = pullback((x...) -> f_gpu(x...), xs_gpu...) 
+ @test check_type(y_gpu) + gs_gpu = back_gpu(Δ_gpu) + + if test_equal + @test collect(y_cpu) ≈ collect(y_gpu) rtol=rtol atol=atol + for (g_gpu, g_cpu) in zip(gs_gpu, gs_cpu) + check_grad(g_gpu, g_cpu, atol, rtol) + end + end + + ### GRADIENT WITH RESPECT TO f ##### + ps_cpu = Flux.params(f_cpu) + y_cpu, back_cpu = pullback(() -> f_cpu(xs_cpu...), ps_cpu) + gs_cpu = back_cpu(Δ_cpu) + + ps_gpu = Flux.params(f_gpu) + y_gpu, back_gpu = pullback(() -> f_gpu(xs_gpu...), ps_gpu) + gs_gpu = back_gpu(Δ_gpu) + + if test_equal + @test collect(y_cpu) ≈ collect(y_gpu) rtol=rtol atol=atol + @assert length(ps_gpu) == length(ps_cpu) + for (p_gpu, p_cpu) in zip(ps_gpu, ps_cpu) + check_grad(gs_gpu[p_gpu], gs_cpu[p_cpu], atol, rtol) + end + end +end diff --git a/test/layers/basic.jl b/test/layers/basic.jl index d8771fb76c..c04d1f97d5 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -63,6 +63,11 @@ import Flux: activations @test Flux.Diagonal(2)([1 2]) == [1 2; 1 2] @test Flux.Diagonal(2)([1,2]) == [1,2] @test Flux.Diagonal(2)([1 2; 3 4]) == [1 2; 3 4] + + @test Flux.Diagonal(2)(rand(2,3,4)) |> size == (2, 3, 4) + @test Flux.Diagonal(2,3)(rand(2,3,4)) |> size == (2, 3, 4) + @test Flux.Diagonal(2,3,4)(rand(2,3,4)) |> size == (2, 3, 4) + @test Flux.Diagonal(2,3)(rand(2,1,4)) |> size == (2, 3, 4) end @testset "Maxout" begin diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index 643f0e510b..89c2f4803e 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -61,6 +61,7 @@ end let m = BatchNorm(2), x = [1.0 3.0 5.0; 2.0 4.0 6.0] + @test Flux.hasaffine(m) == true @test length(params(m)) == 2 @test m.β == [0, 0] # initβ(2) @@ -88,8 +89,8 @@ end # 2×1 Array{Float64,2}: # 1.3 # 1.3 - @test m.σ² ≈ .1 .* var(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.] - + @test m.σ² ≈ .1 .* var(x, dims=2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.] + x′ = m(x) @test isapprox(x′[1], (1 .- 0.3) / sqrt(1.3), atol = 1.0e-5) end @@ -123,17 +124,19 @@ end m(x) @test (@allocated m(x)) < 100_000_000 end + + @test length(Flux.params(BatchNorm(10))) == 2 + @test length(Flux.params(BatchNorm(10, affine=true))) == 2 + @test length(Flux.params(BatchNorm(10, affine=false))) == 0 end @testset "InstanceNorm" begin - # helper functions - expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...) # begin tests - let m = InstanceNorm(2), sizes = (3, 2, 2), + let m = InstanceNorm(2; affine=true, track_stats=true), sizes = (3, 2, 2), x = reshape(collect(1:prod(sizes)), sizes) @test length(params(m)) == 2 - x = Float64.(x) + x = Float32.(x) @test m.β == [0, 0] # initβ(2) @test m.γ == [1, 1] # initγ(2) y = evalwgrad(m, x) @@ -159,29 +162,60 @@ end # ∴ update rule with momentum: # (1. - .1) * 0 + .1 * (2. + 8.) / 2 = .5 # (1. - .1) * 0 + .1 * (5. + 11.) / 2 = .8 + N = ndims(x) @test m.μ ≈ [0.5, 0.8] - # momentum * var * num_items / (num_items - 1) + (1 - momentum) * sigma_sq - # julia> reshape(mean(.1 .* var(x, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1. - # 2-element Array{Float64,1}: - # 1. - # 1. - @test m.σ² ≈ reshape(mean(.1 .* var(x, dims = 1, corrected=false) .* (3 / 2), dims=3), :) .+ .9 .* 1. - - x′ = m(x) - @test isapprox(x′[1], (1 - 0.5) / sqrt(1. 
+ 1f-5), atol = 1.0e-5) + n = prod(size(x,i) for i in 1:N-2) + corr = n / (n-1) + σ² = var(x, dims=1:N-2, corrected=false) + @test m.σ² ≈ 0.1*corr*vec(mean(σ², dims=N)) .+ 0.9 * 1 + + y = m(x) + @test length(m.μ) == 2 + @test length(m.σ²) == 2 + @test y ≈ (x .- reshape(m.μ, 1,2,1)) ./ sqrt.(reshape(m.σ², 1,2,1) .+ 1f-5) atol=1.0e-5 end + # with activation function - let m = InstanceNorm(2, sigmoid), sizes = (3, 2, 2), + let m = InstanceNorm(2, sigmoid; affine=true, track_stats=true), sizes = (3, 2, 2), x = reshape(collect(1:prod(sizes)), sizes) x = Float64.(x) affine_shape = collect(sizes) - affine_shape[1] = 1 + affine_shape[[1,3]] .= 1 + + y = evalwgrad(m, x) + y = m(x) # inference time after a training step + μ = reshape(m.μ, affine_shape...) + σ² = reshape(m.σ², affine_shape...) + @test y ≈ sigmoid.((x .- μ) ./ sqrt.(σ² .+ m.ϵ)) atol=1.0e-7 + end + + # with activation function + let m = InstanceNorm(2, sigmoid; affine=true, track_stats=false), sizes = (3, 2, 2), + x = reshape(collect(1:prod(sizes)), sizes) + @test Flux.hasaffine(m) == true + @test length(params(m)) == 2 + x = Float64.(x) y = m(x) - @test isapprox(y, sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ)), atol = 1.0e-7) + μ = mean(x, dims=1) + σ² = var(x, dims=1, corrected=false) + @test y ≈ sigmoid.((x .- μ) ./ sqrt.(σ² .+ m.ϵ)) atol=1.0e-7 end - let m = trainmode!(InstanceNorm(2)), sizes = (2, 4, 1, 2, 3), + let m = InstanceNorm(2, sigmoid), sizes = (3, 2, 2), + x = reshape(collect(1:prod(sizes)), sizes) + @test Flux.hasaffine(m) == false + @test length(params(m)) == 0 + + x = Float64.(x) + y = m(x) + μ = mean(x, dims=1) + σ² = var(x, dims=1, corrected=false) + @test y ≈ sigmoid.((x .- μ) ./ sqrt.(σ² .+ m.ϵ)) atol=1.0e-7 + end + + + let m = trainmode!(InstanceNorm(2; affine=true)), sizes = (2, 4, 1, 2, 3), x = Float32.(reshape(collect(1:prod(sizes)), sizes)) y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3) y = reshape(m(y), sizes...) 
@@ -189,7 +223,7 @@ end end # check that μ, σ², and the output are the correct size for higher rank tensors - let m = InstanceNorm(2), sizes = (5, 5, 3, 4, 2, 6), + let m = InstanceNorm(2; affine=true,track_stats=true), sizes = (5, 5, 3, 4, 2, 6), x = reshape(Float32.(collect(1:prod(sizes))), sizes) y = evalwgrad(m, x) @test size(m.μ) == (sizes[end - 1], ) @@ -198,7 +232,7 @@ end end # show that instance norm is equal to batch norm when channel and batch dims are squashed - let m_inorm = trainmode!(InstanceNorm(2)), m_bnorm = trainmode!(BatchNorm(12)), sizes = (5, 5, 3, 4, 2, 6), + let m_inorm = trainmode!(InstanceNorm(2; affine=true)), m_bnorm = trainmode!(BatchNorm(12)), sizes = (5, 5, 3, 4, 2, 6), x = reshape(Float32.(collect(1:prod(sizes))), sizes) @test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes) end @@ -208,17 +242,43 @@ end @test (@allocated m(x)) < 100_000_000 end + @test length(Flux.params(InstanceNorm(10))) == 0 + @test length(Flux.params(InstanceNorm(10, affine=true))) == 2 + @test length(Flux.params(InstanceNorm(10, affine=false))) == 0 +end + +@testset "LayerNorm" begin + x = rand(2,3) + @test LayerNorm(2)(x) ≈ Flux.normalise(x, dims=1) + x = rand(2,3,4) + @test LayerNorm(2)(x) ≈ Flux.normalise(x, dims=1) + x = rand(2,3,4,5) + @test LayerNorm(2)(x) ≈ Flux.normalise(x, dims=1) + x = rand(2) + @test LayerNorm(2, tanh)(x) ≈ tanh.(Flux.normalise(x, dims=1)) + + x = rand(2,3,4,5) + @test LayerNorm((2,3))(x) ≈ Flux.normalise(x, dims=(1,2)) + x = rand(2,3,4,5) + @test LayerNorm((2,3,4))(x) ≈ Flux.normalise(x, dims=1:3) + + m = LayerNorm((2,3,4)) + @test Flux.hasaffine(m) == true + @test length(params(m)) == 2 + m = LayerNorm((2,3,4), affine=false) + @test Flux.hasaffine(m) == false + @test length(params(m)) == 0 end @testset "GroupNorm" begin # begin tests squeeze(x) = dropdims(x, dims = tuple(findall(size(x) .== 1)...)) # To remove all singular dimensions - let m = GroupNorm(4,2), sizes = (3,4,2), + let m = GroupNorm(4,2, track_stats=true), sizes = (3,4,2), x = reshape(collect(1:prod(sizes)), sizes) @test length(params(m)) == 2 - x = Float64.(x) + x = Float32.(x) @test m.β == [0, 0, 0, 0] # initβ(32) @test m.γ == [1, 1, 1, 1] # initγ(32) @@ -250,20 +310,20 @@ end # (1. - .1) * 0 + .1 * (3.5 + 15.5) / 2 = 0.95 # (1. - .1) * 0 + .1 * (9.5 + 21.5) / 2 = 1.55 @test m.μ ≈ [0.95, 1.55] - - # julia> mean(var(reshape(x,3,2,2,2),dims=(1,2)).* .1,dims=2) .+ .9*1. - # 2-element Array{Float64,1}: - # 1.25 - # 1.25 - @test m.σ² ≈ mean(squeeze(var(reshape(x,3,2,2,2),dims=(1,2))).*.1,dims=2) .+ .9*1. - - x′ = m(x) - @test isapprox(x′[1], (1 - 0.95) / sqrt(1.25 + 1f-5), atol = 1.0e-5) + n = prod(size(x)) ÷ m.G ÷ size(x)[end] + corr = n / (n-1) + z = reshape(x,3,2,2,2) + σ² = var(z, dims=(1,2), corrected=false) + @test m.σ² ≈ 0.1*corr*vec(mean(σ², dims=4)) .+ 0.9 * 1 + + y = m(x) + out = (z .- reshape(m.μ, 1,1,2,1)) ./ sqrt.(reshape(m.σ², 1,1,2,1) .+ 1f-5) + @test y ≈ reshape(out, size(x)) atol=1.0e-5 end # with activation function - let m = GroupNorm(4,2, sigmoid), sizes = (3, 4, 2), + let m = GroupNorm(4,2, sigmoid, track_stats=true), sizes = (3, 4, 2), x = reshape(collect(1:prod(sizes)), sizes) - x = Float64.(x) + x = Float32.(x) μ_affine_shape = ones(Int,length(sizes) + 1) μ_affine_shape[end-1] = 2 # Number of groups @@ -278,10 +338,10 @@ end y = m(x) x_ = reshape(x,affine_shape...) out = reshape(sigmoid.((x_ .- reshape(m.μ,μ_affine_shape...)) ./ sqrt.(reshape(m.σ²,μ_affine_shape...) 
.+ m.ϵ)),og_shape) - @test isapprox(y, out, atol = 1.0e-7) + @test y ≈ out atol=1e-7 end - let m = trainmode!(GroupNorm(2,2)), sizes = (2, 4, 1, 2, 3), + let m = trainmode!(GroupNorm(2,2, track_stats=true)), sizes = (2, 4, 1, 2, 3), x = Float32.(reshape(collect(1:prod(sizes)), sizes)) y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3) y = reshape(m(y), sizes...) @@ -289,16 +349,16 @@ end end # check that μ, σ², and the output are the correct size for higher rank tensors - let m = GroupNorm(4,2), sizes = (5, 5, 3, 4, 4, 6), + let m = GroupNorm(4,2, track_stats=true), sizes = (5, 5, 3, 4, 4, 6), x = Float32.(reshape(collect(1:prod(sizes)), sizes)) y = evalwgrad(m, x) - @test size(m.μ) == (m.G,1) - @test size(m.σ²) == (m.G,1) + @test size(m.μ) == (m.G,) + @test size(m.σ²) == (m.G,) @test size(y) == sizes end # show that group norm is the same as instance norm when the group size is the same as the number of channels - let IN = trainmode!(InstanceNorm(4)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,5), + let IN = trainmode!(InstanceNorm(4; affine=true)), GN = trainmode!(GroupNorm(4,4)), sizes = (2,2,3,4,5), x = Float32.(reshape(collect(1:prod(sizes)), sizes)) @test IN(x) ≈ GN(x) end
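Taken together, the layer and test changes above settle on a single keyword API for the normalisation layers. A short usage sketch of that API (illustrative only, not part of the diff; layer sizes and variable names are arbitrary):

```julia
using Flux

bn = BatchNorm(16)                      # defaults: affine=true, track_stats=true
ln = LayerNorm((28, 28, 16))            # normalises over the first length(sz) dimensions
inorm = InstanceNorm(16; affine=true)   # affine and track_stats now default to false
gn = GroupNorm(16, 4; track_stats=true) # 16 channels split into 4 groups

Flux.hasaffine(bn)                      # true: β and γ are trainable
length(Flux.params(InstanceNorm(16)))   # 0, since affine defaults to false

x = rand(Float32, 28, 28, 16, 8)
y = bn(x)            # in the default :auto mode, μ and σ² update only inside gradient calls
Flux.testmode!(bn)   # after this, the accumulated μ and σ² are used instead
```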