Merge pull request #1849 from darsnack/darsnack/dropout-rng
Add RNG support for Dropout/AlphaDropout
darsnack authored Jan 27, 2022
2 parents 7467e6b + f922c16 commit 8d3b8d3
Showing 7 changed files with 196 additions and 95 deletions.
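In short, the change threads a user-supplied RNG through the dropout code path. A minimal sketch of the resulting API, with illustrative seeds and sizes:

```julia
using Flux, Random

# Functional form: the RNG is now an optional first positional argument.
x = rand(Float32, 4, 4)
y = Flux.dropout(MersenneTwister(0), x, 0.5)

# Layer form: the RNG is passed via the new `rng` keyword.
d = Dropout(0.5; rng = MersenneTwister(0))
a = AlphaDropout(0.5; rng = MersenneTwister(0))
```

Custom RNGs are CPU-only; on the GPU only `CUDA.default_rng()` is supported, as the diffs below show.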
3 changes: 3 additions & 0 deletions NEWS.md
@@ -1,5 +1,8 @@
# Flux Release Notes

## v0.12.10
* `Dropout`/`AlphaDropout` now supports [user-specified RNGs](https://github.com/FluxML/Flux.jl/pull/1838)

## v0.12.9
* Fixed incorrect output and added GPU compatibility for [AlphaDropout](https://github.com/FluxML/Flux.jl/pull/1781).
* Add trilinear [Upsample layer](https://github.com/FluxML/Flux.jl/pull/1792).
15 changes: 14 additions & 1 deletion src/functor.jl
@@ -96,6 +96,14 @@ end
struct FluxCUDAAdaptor end
adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x)
adapt_storage(to::FluxCUDAAdaptor, x::Zygote.FillArrays.AbstractFill) = CUDA.cu(collect(x))
if VERSION >= v"1.7"
adapt_storage(to::FluxCUDAAdaptor, x::Random.TaskLocalRNG) = CUDA.default_rng()
else
adapt_storage(to::FluxCUDAAdaptor, x::Random._GLOBAL_RNG) = CUDA.default_rng()
end
adapt_storage(to::FluxCUDAAdaptor, x::CUDA.RNG) = x
adapt_storage(to::FluxCUDAAdaptor, x::AbstractRNG) =
error("Cannot map RNG of type $(typeof(x)) to GPU. GPU execution only supports Random.default_rng().")

# TODO: figure out the correct design for OneElement
adapt_storage(to::FluxCUDAAdaptor, x::Zygote.OneElement) = CUDA.cu(collect(x))
@@ -109,6 +117,8 @@ adapt_storage(to::FluxCPUAdaptor, x::Zygote.FillArrays.AbstractFill) = x
adapt_storage(to::FluxCPUAdaptor, x::T) where T <: CUDA.CUSPARSE.CUDA.CUSPARSE.AbstractCuSparseMatrix = adapt(Array, x)
adapt_storage(to::FluxCPUAdaptor, x::Zygote.OneElement) = x
adapt_storage(to::FluxCPUAdaptor, x::AbstractSparseArray) = x
adapt_storage(to::FluxCPUAdaptor, x::CUDA.RNG) = Random.default_rng()
adapt_storage(to::FluxCPUAdaptor, x::AbstractRNG) = x

Zygote.@adjoint function Array(x::CUDA.CuArray)
Array(x), d -> (CUDA.cu(d),)
@@ -149,6 +159,9 @@ _isbitsarray(::AbstractArray{<:Number}) = true
_isbitsarray(::AbstractArray{T}) where T = isbitstype(T)
_isbitsarray(x) = false

_isleaf(::AbstractRNG) = true
_isleaf(x) = _isbitsarray(x) || Functors.isleaf(x)

"""
gpu(x)
@@ -174,7 +187,7 @@ CuArray{Float32, 2}
"""
function gpu(x)
check_use_cuda()
use_cuda[] ? fmap(x -> Adapt.adapt(FluxCUDAAdaptor(), x), x; exclude = _isbitsarray) : x
use_cuda[] ? fmap(x -> Adapt.adapt(FluxCUDAAdaptor(), x), x; exclude = _isleaf) : x
end

function check_use_cuda()
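Together with the new `_isleaf` methods, these adaptor rules make `gpu`/`cpu` treat RNGs as leaves: the default CPU RNG becomes `CUDA.default_rng()` on the GPU, a `CUDA.RNG` becomes `Random.default_rng()` on the CPU, and any other RNG triggers the error above. A rough sketch, assuming a working CUDA device:

```julia
using Flux, CUDA, Random

d = Dropout(0.1)                               # rng defaults to rng_from_array()
gpu(d).rng isa CUDA.RNG                        # true: the default RNG is replaced by CUDA's RNG

m = Dropout(0.1; rng = MersenneTwister(123))
# gpu(m)                                       # errors: only the default RNG can be moved to the GPU

g = Dropout(0.1; rng = CUDA.default_rng())
cpu(g).rng                                     # Random.default_rng(), per the FluxCPUAdaptor rule
```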
60 changes: 42 additions & 18 deletions src/layers/normalise.jl
@@ -10,7 +10,7 @@ _dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(s
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)

"""
dropout(x, p; dims=:, active=true)
dropout([rng = rng_from_array(x)], x, p; dims=:, active=true)
The dropout function. If `active` is `true`,
for each input, either sets that input to `0` (with probability
@@ -20,6 +20,9 @@ This is used as a regularisation, i.e. it reduces overfitting during training.
If `active` is `false`, it just returns the input `x`.
Specify `rng` for custom RNGs instead of the default RNG.
Note that custom RNGs are only supported on the CPU.
Warning: when using this function, you have to manually manage the activation
state. Usually in fact, dropout is used while training
but is deactivated in the inference phase. This can be
@@ -28,49 +31,63 @@ automatically managed using the [`Dropout`](@ref) layer instead of the
The [`Dropout`](@ref) layer is what you should use in most scenarios.
"""
function dropout(x, p; dims=:, active::Bool=true)
function dropout(rng, x, p; dims=:, active::Bool=true)
active || return x
y = dropout_mask(x, p, dims=dims)
y = dropout_mask(rng, x, p, dims=dims)
return x .* y
end
dropout(x, p; kwargs...) = dropout(rng_from_array(x), x, p; kwargs...)

@adjoint function dropout(x, p; dims=:, active::Bool=true)
@adjoint function dropout(rng, x, p; dims=:, active::Bool=true)
active || return x, Δ -> (Δ, nothing)
y = dropout_mask(x, p, dims=dims)
return x .* y, Δ -> (Δ .* y, nothing)
y = dropout_mask(rng, x, p, dims=dims)
return x .* y, Δ -> (nothing, Δ .* y, nothing)
end

function dropout_mask(x, p; dims=:)
y = rand!(similar(x, _dropout_shape(x, dims)))
dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
dropout_mask(rng, x::CuArray, p; kwargs...) =
throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropout_mask only supports CUDA.RNG for CuArrays."))
dropout_mask(rng, x, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
function _dropout_mask(rng, x, p; dims=:)
y = rand!(rng, similar(x, _dropout_shape(x, dims)))
y .= _dropout_kernel.(y, p, 1 - p)
return y
end
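A small usage sketch of the functional `dropout` and the `dropout_mask` dispatch above; the seed and array sizes are illustrative:

```julia
using Flux, Random

x = rand(Float32, 3, 3)

# The mask is drawn from the RNG that is passed in, so a fixed seed is reproducible.
y1 = Flux.dropout(MersenneTwister(42), x, 0.5)
y2 = Flux.dropout(MersenneTwister(42), x, 0.5)
y1 == y2   # true

# Mixing a CPU RNG with a CuArray is rejected by dropout_mask:
# Flux.dropout(MersenneTwister(42), CUDA.rand(Float32, 3, 3), 0.5)   # ArgumentError
```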

"""
Dropout(p; dims=:)
Dropout(p; dims=:, rng = rng_from_array())
Dropout layer. In the forward pass, apply the [`Flux.dropout`](@ref) function on the input.
To apply dropout along certain dimension(s), specify the `dims` keyword.
e.g. `Dropout(p; dims = 3)` will randomly zero out entire channels on WHCN input
(also called 2D dropout).
Specify `rng` to use a custom RNG instead of the default.
Custom RNGs are only supported on the CPU.
Does nothing to the input once [`Flux.testmode!`](@ref) is `true`.
"""
mutable struct Dropout{F,D}
mutable struct Dropout{F,D,R<:AbstractRNG}
p::F
dims::D
active::Union{Bool, Nothing}
rng::R
end
Dropout(p, dims, active) = Dropout(p, dims, active, rng_from_array())

function Dropout(p; dims=:)
function Dropout(p; dims=:, rng = rng_from_array())
@assert 0 ≤ p ≤ 1
Dropout(p, dims, nothing)
Dropout(p, dims, nothing, rng)
end

@functor Dropout

trainable(a::Dropout) = ()

function (a::Dropout)(x)
_isactive(a) || return x
return dropout(x, a.p; dims=a.dims, active=true)
return dropout(a.rng, x, a.p; dims=a.dims, active=true)
end

testmode!(m::Dropout, mode=true) =
@@ -83,7 +100,7 @@ function Base.show(io::IO, d::Dropout)
end
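The layer now stores its RNG in the `rng` field, and `trainable` still returns `()`, so no new parameters are trained. A brief sketch using `Flux.trainmode!` (existing Flux API) to force the active state:

```julia
using Flux, Random

d = Dropout(0.5; rng = MersenneTwister(1))
Flux.trainmode!(d)            # set active = true outside of a gradient context

x = ones(Float32, 4, 2)
y = d(x)                      # roughly half the entries zeroed, the rest scaled by 1/(1 - p)

d.rng                         # the MersenneTwister stored on the layer
Flux.trainable(d)             # () — neither p nor the RNG is a trainable parameter
```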

"""
AlphaDropout(p)
AlphaDropout(p; rng = rng_from_array())
A dropout layer. Used in
[Self-Normalizing Neural Networks](https://arxiv.org/abs/1706.02515).
@@ -92,14 +109,21 @@ remain the same as before.
Does nothing to the input once [`testmode!`](@ref) is true.
"""
mutable struct AlphaDropout{F}
mutable struct AlphaDropout{F,R<:AbstractRNG}
p::F
active::Union{Bool, Nothing}
function AlphaDropout(p, active = nothing)
rng::R
function AlphaDropout(p, active, rng)
@assert 0 ≤ p ≤ 1
new{typeof(p)}(p, active)
new{typeof(p), typeof(rng)}(p, active, rng)
end
end
AlphaDropout(p, active) = AlphaDropout(p, active, rng_from_array())
AlphaDropout(p; rng = rng_from_array()) = AlphaDropout(p, nothing, rng)

@functor AlphaDropout

trainable(a::AlphaDropout) = ()

function (a::AlphaDropout)(x::AbstractArray{T}) where T
_isactive(a) || return x
@@ -111,7 +135,7 @@ function (a::AlphaDropout)(x::AbstractArray{T}) where T
A = T(inv(sqrt((1 - p) * (1 + p * α′^2))))
B = T(-A * α′ * p)

noise = rand!(similar(x))
noise = rand!(a.rng, similar(x))
return A .* ifelse.(noise .> p, x, α′) .+ B
end
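`AlphaDropout` likewise draws its noise from the stored RNG, so identically seeded layers produce identical outputs, while the mean/variance-preserving behaviour described in the docstring is unchanged. A sketch with an illustrative seed:

```julia
using Flux, Random, Statistics

a1 = AlphaDropout(0.2; rng = MersenneTwister(7))
a2 = AlphaDropout(0.2; rng = MersenneTwister(7))
Flux.trainmode!(a1); Flux.trainmode!(a2)

x = randn(Float32, 1000)
a1(x) == a2(x)                # true: identical RNG state gives identical noise

y = a1(x)
mean(y), std(y)               # ≈ (mean(x), std(x)) — the self-normalizing property
```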

19 changes: 19 additions & 0 deletions src/utils.jl
@@ -33,6 +33,25 @@ nfan(dims...) = prod(dims[1:end-2]) .* (dims[end-1], dims[end]) # In case of con
ofeltype(x, y) = convert(float(eltype(x)), y)
epseltype(x) = eps(float(eltype(x)))

"""
rng_from_array([x])
Create an instance of the RNG most appropriate for `x`.
The current defaults are:
- `x isa AbstractArray`
- Julia version is < 1.7: `Random.GLOBAL_RNG`
- Julia version is >= 1.7: `Random.default_rng()`
- `x isa CuArray`: `CUDA.default_rng()`
When `x` is unspecified, it is assumed to be an `AbstractArray`.
"""
rng_from_array(::AbstractArray) = rng_from_array()
rng_from_array(::CuArray) = CUDA.default_rng()
if VERSION >= v"1.7"
rng_from_array() = Random.default_rng()
else
rng_from_array() = Random.GLOBAL_RNG
end
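`rng_from_array` is the helper behind these defaults: it returns the RNG matching the array's device, with the CPU case depending on the Julia version. A short sketch; the `CuArray` line assumes a working CUDA device:

```julia
using Flux, CUDA, Random

Flux.rng_from_array(rand(3))          # Random.default_rng() on Julia ≥ 1.7, Random.GLOBAL_RNG below
Flux.rng_from_array(CUDA.rand(3))     # CUDA.default_rng()
Flux.rng_from_array()                 # same as for a CPU AbstractArray
```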

"""
glorot_uniform([rng=GLOBAL_RNG], dims...)
10 changes: 10 additions & 0 deletions test/cuda/layers.jl
@@ -280,3 +280,13 @@ end
end
end
end

@testset "Dropout RNGs" begin
@test_throws ArgumentError Flux.dropout(MersenneTwister(), CUDA.rand(Float32, 2, 3), 0.1)
@testset for layer in (Dropout, AlphaDropout)
m = layer(0.1; rng = MersenneTwister(123))
@test_throws ErrorException gpu(m)
m = layer(0.1; rng = CUDA.default_rng())
@test gpu(m).rng isa CUDA.RNG
end
end
1 change: 1 addition & 0 deletions test/cuda/runtests.jl
@@ -1,6 +1,7 @@
using Flux, Test, CUDA
using Zygote
using Zygote: pullback
using Random

@info "Testing GPU Support"
CUDA.allowscalar(false)
(Diff for the remaining changed file not shown.)