
Add dropout #454

Merged: 6 commits, Jan 5, 2023
Changes from 3 commits
3 changes: 2 additions & 1 deletion Project.toml
@@ -1,12 +1,13 @@
name = "NNlib"
uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
version = "0.8.13"
version = "0.8.14"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

4 changes: 4 additions & 0 deletions src/NNlib.jl
@@ -6,6 +6,7 @@ using ChainRulesCore
import ChainRulesCore: rrule
using Base.Broadcast: broadcasted
using Base.Threads
using Random
using Statistics
using Statistics: mean
using LinearAlgebra
@@ -40,6 +41,9 @@ for f in ACTIVATIONS
end
export sigmoid, hardsigmoid, logsigmoid, thresholdrelu # Aliases

include("dropout.jl")
export dropout, dropout!

include("softmax.jl")
export softmax, softmax!, ∇softmax, ∇softmax!, logsoftmax,
logsoftmax!, ∇logsoftmax, ∇logsoftmax!, logsumexp
156 changes: 156 additions & 0 deletions src/dropout.jl
@@ -0,0 +1,156 @@

"""
dropout([rng], A, p; dims=:)

Returns an array in which each element of `A` is either replaced with zero,
with probability `p`, or else multiplied by `1/(1-p)`.

By default every element is treated independently.
With `dims=1`, a choice is made for every value of the 1st index,
i.e. each row of a matrix is either zeroed or not.

The optional first argument is the random number generator used.

# Examples
```
julia> dropout(ones(2, 10), 0.2)
2×10 Matrix{Float64}:
1.25 1.25 0.0 1.25 1.25 1.25 1.25 1.25 1.25 1.25
1.25 1.25 1.25 0.0 1.25 1.25 0.0 1.25 1.25 1.25

julia> mean(dropout(ones(10^4, 5), 0.2), dims=1)
1×5 Matrix{Float64}:
0.998 1.00075 0.99125 0.99575 1.00075

julia> dropout(ones(5, 5), 0.7, dims=1) # whole row the same
5×5 Matrix{Float64}:
3.33333 3.33333 3.33333 3.33333 3.33333
0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0
3.33333 3.33333 3.33333 3.33333 3.33333
0.0 0.0 0.0 0.0 0.0

julia> mean(dropout(ones(10^4, 5), 0.3, dims=1), dims=1)
1×5 Matrix{Float64}:
1.00571 1.00571 1.00571 1.00571 1.00571
```
"""
dropout(A::AbstractArray, p::Real; dims = :) = dropout(_rng_from_array(A), A, p; dims)

function dropout(rng::AbstractRNG, A::AbstractArray, p::Real; dims = :)
T = float(eltype(A))
0 <= p <= 1 || throw(ArgumentError("dropout expects a probability 0 <= p <= 1"))
if p > 0
dst = similar(A, T)
pT = convert(real(T), p)
_dropout!(rng, dst, A, pT, dims)
else
# Not so sure we want fast paths... this tries but doesn't guarantee type-stability,
# and the rrule does not have such a fast path.
convert(AbstractArray{T}, A)
end
end

"""
dropout!(B, A, p; dims=:)

This does exactly `B .= dropout(A, p; dims)`,
or rather, it's the implementation of out-of-place [`dropout`](@ref).
"""
function dropout!(dst::AbstractArray, src::AbstractArray, p::Real; dims=:)
size(dst) == size(src) || throw(DimensionMismatch("dropout! expects output array the same size as input"))
0 <= p <= 1 || throw(ArgumentError("dropout expects a probability 0 <= p <= 1"))
if p > 0
rng = _rng_from_array(src)
pT = convert(real(eltype(dst)), p)
_dropout!(rng, dst, src, pT, dims)
else
# This fast path isn't free, but no concerns about types changing:
copyto!(dst, src)
end
end
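# Example (illustration only): `B = similar(A); dropout!(B, A, 0.5)` fills `B` in place
# with a dropped-out copy of `A`.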

# This is the easy case in that we can safely use the output array for random numbers.
function _dropout!(rng::AbstractRNG, dst::AbstractArray, src::AbstractArray, p::Real, dims::Colon)
T = real(eltype(dst))
val = convert(T, 1/(1-p))
rand!(rng, dst)
## This is what we want, but it hits a SIMD bug, solved by _fast_broadcast!
# dst .= (dst.>p) .* val .* src
_fast_broadcast!(dst, src) do q, x
((real(q)>p) * val) * x
end
dst
end

# For other dims, we do need to allocate something.
function _dropout!(rng::AbstractRNG, dst::AbstractArray, src::AbstractArray, p::Real, dims)
T = real(eltype(dst))
tmp = similar(dst, T, ntuple(d -> d in dims ? size(src,d) : 1, ndims(src)))
rand!(rng, tmp)
val = convert(T, 1/(1-p))
## One-pass strategy -- faster on GPU
dst .= ((tmp.>p) .* val) .* src
## Two-pass strategy -- slightly faster on some CPUs?
# _fast_broadcast!(tmp) do q
# (q>p) * val
# end
# dst .= tmp .* src
end

# The gradient needs to keep the random choices made, and thus must store at least a BitArray,
# but the following way turns out to be faster & simpler:
function ChainRulesCore.rrule(::typeof(dropout), rng::AbstractRNG, A::AbstractArray, p::Real; dims = :)
T = float(real(eltype(A)))
val = convert(T, 1/(1-p))
keep = if dims isa Colon
similar(A, T)
else
similar(A, T, ntuple(d -> d in dims ? size(A,d) : 1, ndims(A)))
end
rand!(rng, keep)
Y = @. ((keep>p) * val) * A
function dropout_back(Δ)
dY = unthunk(Δ)
dA = @. ((keep>p) * val) * dY
(NoTangent(), NoTangent(), dA, NoTangent())
end
return Y, dropout_back
end
# Possibly TODO: another approach to the gradient would be to copy the RNG
# and then re-generate the same mask, instead of storing it. This saves memory
# and seems about as fast, but needs a method `copy(::CUDA.RNG)` and careful checking.
# https://github.com/FluxML/NNlib.jl/pull/454#issuecomment-1369357402
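# A rough sketch of that alternative (illustration only, not part of this PR; it assumes
# the RNG can be `copy`ed and the copy then reproduces the same stream):
#
#   function _dropout_pullback_via_rng(rng::AbstractRNG, A::AbstractArray, p::Real)
#       rng2 = copy(rng)                       # snapshot the state before the mask is drawn
#       Y = dropout(rng, A, p)                 # forward pass consumes `rng`
#       function back(dY)
#           T = float(real(eltype(A)))
#           val = convert(T, 1/(1-p))
#           keep = rand!(rng2, similar(A, T))  # re-draw the identical mask from the snapshot
#           @. ((keep > p) * val) * dY
#       end
#       return Y, back
#   end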

"""
_fast_broadcast!(f, x, y, z...)

This does `x .= f.(x, y, z...)`, but works around
an issue with broadcasting that prevents SIMD in such cases.
Can be removed once https://github.com/JuliaLang/julia/issues/43153 is fixed.

Not intended for general use. Does not check sizes!
"""
function _fast_broadcast!(f::F, x::Array, yz...) where {F<:Function}
bc = Broadcast.instantiate(Broadcast.broadcasted(f, x, yz...))
@simd ivdep for I in eachindex(bc)
@inbounds x[I] = bc[I]
end
return x
end
function _fast_broadcast!(f::F, x::AbstractArray, yz...) where {F<:Function}
# CUDA does not suffer from this bug
broadcast!(f, x, x, yz...)
end
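# Example (illustration only): `_fast_broadcast!(*, A, B)` is `A .= A .* B`, and
# `_fast_broadcast!(x -> 2x, A)` is `A .= 2 .* A`; the `Array` method above uses `@simd ivdep`.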


"""
_rng_from_array(x)

Return the random number generator most appropriate for `x`:
`CUDA.default_rng()` for `CuArray`, else `Random.default_rng()`
"""
_rng_from_array(::AbstractArray) = Random.default_rng()

@non_differentiable _rng_from_array(::Any)
Comment on lines +148 to +156

Member Author:
Just simplifying a bit; I couldn't figure out why there were so many functions. Why make a different choice on 1.6?

julia> using Random

julia> Random.default_rng()
MersenneTwister(0x9687b6121c4ccb062f473c9c3c8bccc6)

julia> Random.GLOBAL_RNG
Random._GLOBAL_RNG()

julia> VERSION
v"1.6.0"

compared to master:

julia> using Random

julia> Random.default_rng()
TaskLocalRNG()

julia> Random.GLOBAL_RNG
Random._GLOBAL_RNG()

julia> VERSION
v"1.10.0-DEV.204"

Member:

I don't remember now, but based on FluxML/Flux.jl#1849 (comment) it might've been related to thread safety?

Member (@ToucheSir, Jan 3, 2023):

Cthulhu tells me that rand(...) uses default_rng() on 1.6 as well and it returns a thread-local RNG, so maybe this was much ado about nothing. cc @darsnack if I've missed something though, and I think this function can be public like the Flux one.
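As an aside on `_rng_from_array`: its docstring promises `CUDA.default_rng()` for `CuArray`, but only the CPU fallback is defined in this file, so the GPU method presumably lives in the CUDA glue package. A minimal sketch of what such a method could look like (hypothetical, not part of this diff):

using CUDA, NNlib
NNlib._rng_from_array(::CUDA.CuArray) = CUDA.default_rng()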


96 changes: 96 additions & 0 deletions test/dropout.jl
@@ -0,0 +1,96 @@
using NNlib, Test, Statistics, Random
using Zygote, StableRNGs, ChainRulesCore

@testset "dropout" begin
# Basics
x1 = randn(Float32, 3, 4)
@test size(@inferred dropout(x1, 0.1)) == (3, 4)
@test size(@inferred dropout(x1, 0.2; dims=2)) == (3, 4)
@test size(@inferred dropout(x1, 0.3; dims=(1,2))) == (3, 4)
@test eltype(dropout(x1, 0.1)) == Float32
@test eltype(dropout(x1, 0.1; dims=1)) == Float32
@test eltype(dropout(x1, 0.1; dims=(1,2))) == Float32

rng = Random.default_rng()
@test size(@inferred dropout(rng, x1, 0.1)) == (3, 4)
@test size(@inferred dropout(rng, x1, 0.1; dims=2)) == (3, 4)

# Values
@test dropout(x1, 0) == x1
@test dropout(x1.+0im, 0) == x1
@test dropout(x1, 1) == zero.(x1)
@test dropout(x1.+im, 1) == zero.(x1)

d45 = dropout(trues(100, 100, 100), 0.45)
@test mean(d45) ≈ 1 atol=1e-2
dpi2 = dropout(fill(pi, 1000), 0.2)
@test sort(unique(dpi2)) ≈ [0, 5pi/4]
d33 = dropout(fill(3, 10, 1000), 0.3, dims=2)
@test sort(unique(vec(d33))) ≈ [0, 3/(1-0.3)]

# Complex -- not worth too much optimisation, but should work!
x2 = [1.0+0im,2.0+1im,3.0+3im] # from Flux's tests
@test dropout(x2, 0.5) isa Vector{ComplexF64}
@test dropout(x2, 0.5; dims=1) isa Vector{ComplexF64}

# Gradient rule
y, back = rrule(dropout, rng, hcat(trues(1000), falses(1000)), 0.45)
dx = back(fill(3, 1000, 2))[3]
@test !all(iszero, dx[:,2]) # this is why we save the random choices
@test sort(unique(vec(dx))) ≈ [0, 3/(1-0.45)]

y2, back2 = rrule(dropout, rng, x2, 0.5)
@test y2 isa Vector{ComplexF64}
@test back2(one.(y2))[3] isa Vector{ComplexF64}

@testset "Zygote" begin
@test Zygote.gradient(x -> sum(dropout(x, 0.3)), x1)[1] isa Matrix{Float32}
@test Zygote.gradient(x -> sum(dropout(rng, x, 0.3)), x1)[1] isa Matrix{Float32}
@test Zygote.gradient(x -> sum(dropout(x, 0.3, dims=1)), x1)[1] isa Matrix{Float32}

# p=0 & p=1
@test Zygote.gradient(x -> sum(dropout(x, 0)), x1)[1] == ones(3,4)
@test Zygote.gradient(x -> sum(dropout(x, 1)), x1)[1] == zeros(3,4)

# Second order
f1(x) = sum(dropout(x, 0.5))
@test_broken Zygote.hessian(f1, [1.0,2.0,3.0]) == zeros(3, 3) # forward over reverse
@test Zygote.hessian_reverse(f1, [1.0,2.0,3.0]) == zeros(3, 3)
end

# Errors
@test_throws ArgumentError dropout(x1, -1)
@test_throws ArgumentError dropout(x1, 2)
end

@testset "dropout + CUDA" begin
# Basics
x1 = CUDA.randn(3, 4)
@test size(@inferred dropout(x1, 0.1)) == (3, 4)
@test size(@inferred dropout(x1, 0.2; dims=2)) == (3, 4)
@test size(@inferred dropout(x1, 0.3; dims=(1,2))) == (3, 4)

rng = CUDA.default_rng()
@test size(@inferred dropout(rng, x1, 0.1)) == (3, 4)
@test size(@inferred dropout(rng, x1, 0.1; dims=2)) == (3, 4)

# Values
d45 = dropout(CUDA.ones(100, 100, 100), 0.45)
@test mean(d45) ≈ 1 atol=1e-2
dpi2 = dropout(CUDA.fill(1f0 * pi, 1000), 0.2)
@test sort(unique(Array(dpi2))) ≈ [0, 5pi/4]
d33 = dropout(CUDA.fill(3f0, 10, 1000), 0.3, dims=2)
@test sort(unique(vec(Array(d33)))) ≈ [0, 3/(1-0.3)]

# Gradient rule
y, back = rrule(dropout, rng, hcat(CUDA.ones(1000), CUDA.zeros(1000)), 0.45)
dx = back(CUDA.fill(3f0, 1000, 2))[3]
@test !all(iszero, dx[:,2]) # this is why we save the random choices
@test sort(unique(vec(Array(dx)))) ≈ [0, 3/(1-0.45)]

@testset "Zygote" begin
@test Zygote.gradient(x -> sum(dropout(x, 0.3)), x1)[1] isa CuArray{Float32}
@test Zygote.gradient(x -> sum(dropout(rng, x, 0.3)), x1)[1] isa CuArray{Float32}
@test Zygote.gradient(x -> sum(dropout(x, 0.3, dims=1)), x1)[1] isa CuArray{Float32}
end
end
4 changes: 4 additions & 0 deletions test/runtests.jl
@@ -52,6 +52,10 @@ include("test_utils.jl")
include("ctc.jl")
end

@testset "Dropout" begin
include("dropout.jl")
end

@testset "Fold/Unfold" begin
include("fold.jl")
end