nan & complex fixes
mcabbott committed Jan 3, 2023
1 parent 4111cb2 commit 814bfff
Showing 2 changed files with 64 additions and 8 deletions.
18 changes: 10 additions & 8 deletions src/dropout.jl
@@ -72,23 +72,25 @@ end

 # This is the easy case in that we can safely use the output array for random numbers.
 function _dropout!(rng::AbstractRNG, dst::AbstractArray, src::AbstractArray, p::Real, dims::Colon)
-    val = convert(eltype(dst), 1/(1-p))
+    T = real(eltype(dst))
+    val = convert(T, 1/(1-p))
     rand!(rng, dst)
     ## This is what we want, but it hits a SIMD bug, solved by _fast_broadcast!
     # dst .= (dst.>p) .* val .* src
     _fast_broadcast!(dst, src) do q, x
-        (q>p) * val * x
+        ((real(q)>p) * val) * x
     end
     dst
 end

 # For other dims, we do need to allocate something.
 function _dropout!(rng::AbstractRNG, dst::AbstractArray, src::AbstractArray, p::Real, dims)
-    tmp = similar(dst, ntuple(d -> d in dims ? size(src,d) : 1, ndims(src)))
+    T = real(eltype(dst))
+    tmp = similar(dst, T, ntuple(d -> d in dims ? size(src,d) : 1, ndims(src)))
     rand!(rng, tmp)
-    val = convert(eltype(dst), 1/(1-p))
+    val = convert(T, 1/(1-p))
     ## One-pass strategy -- faster on GPU
-    dst .= (tmp.>p) .* val .* src
+    dst .= ((tmp.>p) .* val) .* src
     ## Two-pass strategy -- slightly faster on some CPUs?
     # _fast_broadcast!(tmp) do q
     #     (q>p) * val
@@ -99,18 +101,18 @@ end
 # The gradient needs to keep the random choices made, thus store at least a BitArray,
 # but the following way turns out to be faster & simpler:
 function ChainRulesCore.rrule(::typeof(dropout), rng::AbstractRNG, A::AbstractArray, p::Real; dims = :)
-    T = float(eltype(A))
+    T = float(real(eltype(A)))
     val = convert(T, 1/(1-p))
     keep = if dims isa Colon
         similar(A, T)
     else
         similar(A, T, ntuple(d -> d in dims ? size(A,d) : 1, ndims(A)))
     end
     rand!(rng, keep)
-    Y = @. (keep>p) * A * val
+    Y = @. ((keep>p) * val) * A
     function dropout_back(Δ)
         dY = unthunk(Δ)
-        dA = @. (keep>p) * dY * val
+        dA = @. ((keep>p) * val) * dY
         (NoTangent(), NoTangent(), dA, NoTangent())
     end
     return Y, dropout_back
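Both fixes rest on two Julia facts. First, Bool is a strong zero under multiplication, so `false * Inf == 0.0`; multiplying the mask by `val` first, as the new `rrule` grouping does, therefore yields an exact zero when `p == 1` makes `val == Inf`, where the old grouping evaluated `(false * A) * Inf == 0.0 * Inf == NaN`. Second, `<` is not defined for complex numbers, and `rand!` on a complex array draws real and imaginary parts independently from [0, 1), so `real(q) > p` is both necessary and still a valid uniform draw for the mask. A minimal REPL sketch of both points (outputs hand-checked; exact error text may vary by Julia version):

julia> false * Inf                 # Bool is a strong zero
0.0

julia> (false * 1.0) * Inf         # old grouping at p == 1: 0.0 * Inf
NaN

julia> (false * Inf) * 1.0         # new grouping: the strong zero absorbs Inf
0.0

julia> (0.5 + 0.5im) > 0.3         # why the mask compares real(q), not q
ERROR: MethodError: no method matching isless(::Float64, ::ComplexF64)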
54 changes: 54 additions & 0 deletions test/dropout.jl
@@ -16,27 +16,81 @@ using Zygote, StableRNGs, ChainRulesCore
    @test size(@inferred dropout(rng, x1, 0.1; dims=2)) == (3, 4)

    # Values
    @test dropout(x1, 0) == x1
    @test dropout(x1.+0im, 0) == x1
    @test dropout(x1, 1) == zero.(x1)
    @test dropout(x1.+im, 1) == zero.(x1)

    d45 = dropout(trues(100, 100, 100), 0.45)
    @test mean(d45) ≈ 1 atol=1e-2
    dpi2 = dropout(fill(pi, 1000), 0.2)
    @test sort(unique(dpi2)) ≈ [0, 5pi/4]
    d33 = dropout(fill(3, 10, 1000), 0.3, dims=2)
    @test sort(unique(vec(d33))) ≈ [0, 3/(1-0.3)]

    # Complex -- not worth too much optimisation, but should work!
    x2 = [1.0+0im, 2.0+1im, 3.0+3im] # from Flux's tests
    @test dropout(x2, 0.5) isa Vector{ComplexF64}
    @test dropout(x2, 0.5; dims=1) isa Vector{ComplexF64}

    # Gradient rule
    y, back = rrule(dropout, rng, hcat(trues(1000), falses(1000)), 0.45)
    dx = back(fill(3, 1000, 2))[3]
    @test !all(iszero, dx[:,2]) # this is why we save the random choices
    @test sort(unique(vec(dx))) ≈ [0, 3/(1-0.45)]

    y2, back2 = rrule(dropout, rng, x2, 0.5)
    @test y2 isa Vector{ComplexF64}
    @test back2(one.(y2))[3] isa Vector{ComplexF64}

@testset "Zygote" begin
@test Zygote.gradient(x -> sum(dropout(x, 0.3)), x1)[1] isa Matrix{Float32}
@test Zygote.gradient(x -> sum(dropout(rng, x, 0.3)), x1)[1] isa Matrix{Float32}
@test Zygote.gradient(x -> sum(dropout(x, 0.3, dims=1)), x1)[1] isa Matrix{Float32}

# p=0 & p=1
@test Zygote.gradient(x -> sum(dropout(x, 0)), x1)[1] == ones(3,4)
@test Zygote.gradient(x -> sum(dropout(x, 1)), x1)[1] == zeros(3,4)

# Second order
f1(x) = sum(dropout(x, 0.5))
@test_broken Zygote.hessian(f1, [1.0,2.0,3.0]) == zeros(3, 3) # forward over reverse
@test Zygote.hessian_reverse(f1, [1.0,2.0,3.0]) == zeros(3, 3)
end

    # Errors
    @test_throws ArgumentError dropout(x1, -1)
    @test_throws ArgumentError dropout(x1, 2)
end
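The `!all(iszero, dx[:,2])` assertion pins down the design choice noted in src/dropout.jl: the rrule stores the random draws in `keep` rather than reconstructing the mask from the forward output. A pullback that took that shortcut (`naive_back` below is illustrative only, not NNlib code) fails on exactly this input, because a zero in `A` is indistinguishable from a dropped entry:

# Illustrative only: recover the mask from the output Y instead of storing keep.
# Wrong whenever A contains zeros: those entries look dropped even when kept,
# so their gradient is silently zeroed.
naive_back(Y, dY, val) = @. (Y != 0) * val * dY

For `hcat(trues(1000), falses(1000))`, the second column of `Y` is all zeros, so `naive_back` would return a zero gradient there; the stored `keep` correctly yields `3/(1-0.45)` at every kept position, as the tests above require.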

@testset "dropout + CUDA" begin
# Basics
x1 = CUDA.randn(3, 4)
@test size(@inferred dropout(x1, 0.1)) == (3, 4)
@test size(@inferred dropout(x1, 0.2; dims=2)) == (3, 4)
@test size(@inferred dropout(x1, 0.3; dims=(1,2))) == (3, 4)

rng = CUDA.default_rng()
@test size(@inferred dropout(rng, x1, 0.1)) == (3, 4)
@test size(@inferred dropout(rng, x1, 0.1; dims=2)) == (3, 4)

# Values
d45 = dropout(CUDA.ones(100, 100, 100), 0.45)
@test mean(d45) 1 atol=1e-2
dpi2 = dropout(CUDA.fill(1f0 * pi, 1000), 0.2)
@test sort(unique(Array(dpi2))) [0, 5pi/4]
d33 = dropout(CUDA.fill(3f0, 10, 1000), 0.3, dims=2)
@test sort(unique(vec(Array(d33)))) [0, 3/(1-0.3)]

# Gradient rule
y, back = rrule(dropout, rng, hcat(CUDA.ones(1000), CUDA.zeros(1000)), 0.45)
dx = back(CUDA.fill(3f0, 1000, 2))[3]
@test !all(iszero, dx[:,2]) # this is why we save the random choices
@test sort(unique(vec(Array(dx)))) [0, 3/(1-0.45)]

@testset "Zygote" begin
@test Zygote.gradient(x -> sum(dropout(x, 0.3)), x1)[1] isa CuArray{Float32}
@test Zygote.gradient(x -> sum(dropout(rng, x, 0.3)), x1)[1] isa CuArray{Float32}
@test Zygote.gradient(x -> sum(dropout(x, 0.3, dims=1)), x1)[1] isa CuArray{Float32}
end
end
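The expected values in both testsets follow from the scaling rule: kept entries are multiplied by `val = 1/(1-p)`, which makes dropout mean-preserving, since `E[(q > p) * val * x] = (1-p) * x/(1-p) = x`. A quick hand-evaluated check of the constants used above:

julia> pi / (1 - 0.2)      # dpi2: kept entries of fill(pi, 1000) at p = 0.2
3.9269908169872414

julia> 5pi / 4             # the reference value in the test
3.9269908169872414

julia> 3 / (1 - 0.3)       # d33: kept entries at p = 0.3
4.285714285714286

Mean preservation is also why `mean(d45) ≈ 1` holds for an all-ones array at `p = 0.45`.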
