diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index 5f9e116f29..51df2e02fc 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -51,6 +51,10 @@ end
 
 Dropout layer. In the forward pass, apply the [`Flux.dropout`](@ref) function on the input.
 
+To apply dropout along certain dimension(s), specify the `dims` keyword.
+e.g. `Dropout(p; dims = 3)` will randomly zero out entire channels on WHCN input
+(also called 2D dropout).
+
 Does nothing to the input once [`Flux.testmode!`](@ref) is `true`.
 """
 mutable struct Dropout{F,D}
@@ -420,4 +424,4 @@ function Base.show(io::IO, l::GroupNorm)
   print(io, "GroupNorm($(join(size(l.β), ", "))")
   (l.λ == identity) || print(io, ", λ = $(l.λ)")
   print(io, ")")
-end
\ No newline at end of file
+end
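
Not part of the patch itself, but a minimal usage sketch of the behaviour the new docstring describes. The array sizes, probability, and the `Flux.trainmode!` call are illustrative assumptions, not part of the diff:

```julia
using Flux

# Illustrative only: with `dims = 3`, a single dropout decision is broadcast
# across width, height and batch, so whole channels of a WHCN array are
# zeroed together ("2D dropout").
d = Dropout(0.5; dims = 3)
Flux.trainmode!(d)              # force the stochastic (training-time) path

x = rand(Float32, 4, 4, 8, 2)   # WHCN: 4×4 images, 8 channels, batch of 2
y = d(x)

# Each channel slice is either entirely zero or a rescaled copy of the input.
for c in 1:size(y, 3)
    dropped = all(iszero, y[:, :, c, :])
    println("channel $c ", dropped ? "dropped" : "kept (scaled by 1 / (1 - p))")
end
```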