add kaiming initialization and relevant docstrings #1243

Merged · 2 commits · Jun 30, 2020
1 change: 1 addition & 0 deletions NEWS.md
@@ -1,4 +1,5 @@
# v0.11
* Add [kaiming initialization](https://arxiv.org/abs/1502.01852) methods: `kaiming_uniform` and `kaiming_normal` [https://github.com/FluxML/Flux.jl/pull/1243]
* Change to `DataLoader`'s constructor [https://github.com/FluxML/Flux.jl/pull/1152]
* Use `DataLoader` with `NamedTuple`s, so that tensors can be accessed by name [https://github.com/FluxML/Flux.jl/pull/1221].
* Error if Dense layers weights and biases are not arrays [https://github.com/FluxML/Flux.jl/pull/1218].
124 changes: 122 additions & 2 deletions src/utils.jl
@@ -1,14 +1,44 @@
# Arrays
"""
nfan(n_out, n_in=1) -> Tuple
nfan(dims...)
nfan(dims::Tuple)

For a layer characterized by dimensions `dims`, return a tuple `(fan_in, fan_out)`, where `fan_in`
is the number of input neurons connected to an output one, and `fan_out` is the number of output neurons
connected to an input one.

This function is mainly used by weight initializers, e.g., [`kaiming_normal`](@ref Flux.kaiming_normal).

# Examples

```jldoctest
julia> layer = Dense(10, 20)
Dense(10, 20)

julia> Flux.nfan(size(layer.W))
(10, 20)

julia> layer = Conv((3, 3), 2=>10)
Conv((3, 3), 2=>10)

julia> Flux.nfan(size(layer.weight))
(18, 90)
```
"""
nfan() = 1, 1 # fan_in, fan_out
nfan(n) = 1, n # A vector is treated as a n×1 matrix
nfan(n_out, n_in) = n_in, n_out # In case of Dense kernels: arranged as matrices
nfan(dims::Tuple) = nfan(dims...)
nfan(dims...) = prod(dims[1:end-2]) .* (dims[end-1], dims[end]) # In case of convolution kernels
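
A quick sketch (using only the methods defined above) of how `nfan` reads dense and convolutional kernel shapes:

```julia
using Flux

# Dense stores W as (n_out, n_in); nfan swaps the pair so that
# fan_in counts the inputs feeding each output neuron:
Flux.nfan(20, 10)        # (10, 20): fan_in = 10, fan_out = 20

# For a 3×3 kernel mapping 2 => 10 channels, the receptive field is
# multiplied in: fan_in = 3*3*2 = 18, fan_out = 3*3*10 = 90
Flux.nfan(3, 3, 2, 10)   # (18, 90)
```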

"""
glorot_uniform(dims...)

Return an `Array` of size `dims` containing random variables taken from a uniform
- distribution in the interval ``[-x, x]``, where `x = sqrt(24 / sum(dims)) / 2`.
+ distribution in the interval ``[-x, x]``, where `x = sqrt(6 / (fan_in + fan_out))`.

This method is described in [1] and also known as Xavier initialization.

# Examples
```jldoctest; setup = :(using Random; Random.seed!(0))
@@ -17,14 +47,27 @@
julia> Flux.glorot_uniform(2, 3)
0.601094 -0.57414 -0.814925
0.900868 0.805994 0.057514
```

# See also

* glorot initialization using normal distribution: [`glorot_normal`](@ref Flux.glorot_normal)
* kaiming initialization using normal distribution: [`kaiming_normal`](@ref Flux.kaiming_normal)
* kaiming initialization using uniform distribution: [`kaiming_uniform`](@ref Flux.kaiming_uniform)
* calculation of `fan_in` and `fan_out`: [`nfan`](@ref Flux.nfan)

# References

[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010.
"""
glorot_uniform(dims...) = (rand(Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / sum(nfan(dims...)))
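
The old and new docstring wordings describe the same bound, since `sqrt(24 / n) / 2 == sqrt(6 / n)` with `n = fan_in + fan_out`. A minimal sketch, assuming only the code above, checking that samples stay inside it:

```julia
using Flux

fan_in, fan_out = Flux.nfan(2, 3)        # (3, 2) for a 2×3 kernel
bound = sqrt(6f0 / (fan_in + fan_out))   # ≈ 1.095f0
w = Flux.glorot_uniform(2, 3)
all(-bound .<= w .<= bound)              # true: (rand .- 0.5) keeps samples in [-bound, bound)
```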

"""
glorot_normal(dims...)

Return an `Array` of size `dims` containing random variables taken from a normal
- distribution with mean 0 and standard deviation `sqrt(2 / sum(dims))`.
+ distribution with mean 0 and standard deviation `sqrt(2 / (fan_in + fan_out))`.

This method is described in [1] and also known as Xavier initialization.

# Examples
```jldoctest; setup = :(using Random; Random.seed!(0))
@@ -34,9 +77,86 @@
julia> Flux.glorot_normal(3, 2)
0.523935 0.371009
-0.223261 0.188052
```

# See also

* glorot initialization using uniform distribution: [`glorot_uniform`](@ref Flux.glorot_uniform)
* kaiming initialization using normal distribution: [`kaiming_normal`](@ref Flux.kaiming_normal)
* kaiming initialization using uniform distribution: [`kaiming_uniform`](@ref Flux.kaiming_uniform)
* calculation of `fan_in` and `fan_out`: [`nfan`](@ref Flux.nfan)

# References

[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010.
"""
glorot_normal(dims...) = randn(Float32, dims...) .* sqrt(2.0f0 / sum(nfan(dims...)))
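
A corresponding numeric sketch for the normal variant; the tolerance is loose because the standard deviation is only matched in expectation:

```julia
using Flux
using Statistics: std

w = Flux.glorot_normal(1000, 1000)
fan_in, fan_out = Flux.nfan(1000, 1000)
isapprox(std(w), sqrt(2 / (fan_in + fan_out)); rtol = 0.05)  # true, up to sampling noise
```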

"""
kaiming_uniform(dims...; gain = √2)

Return an `Array` of size `dims` containing random variables taken from a uniform distribution in the
interval `[-x, x]`, where `x = gain * sqrt(3/fan_in)`.

This method is described in [1] and also known as He initialization.

# Examples
```jldoctest; setup = :(using Random; Random.seed!(0))
julia> Flux.kaiming_uniform(3, 2)
3×2 Array{Float32,2}:
0.950413 1.27439
1.4244 -1.28851
-0.907795 0.0909376
```

# See also

* kaiming initialization using normal distribution: [`kaiming_normal`](@ref Flux.kaiming_normal)
* glorot initialization using normal distribution: [`glorot_normal`](@ref Flux.glorot_normal)
* glorot initialization using uniform distribution: [`glorot_uniform`](@ref Flux.glorot_uniform)
* calculation of `fan_in` and `fan_out`: [`nfan`](@ref Flux.nfan)

# References

[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on imagenet classification." _Proceedings of the IEEE international conference on computer vision_. 2015.
"""
function kaiming_uniform(dims...; gain = √2)
bound = Float32(√3 * gain / sqrt(first(nfan(dims...)))) # fan_in
return (rand(Float32, dims...) .- 0.5f0) .* 2bound
end
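
For context, the new initializers drop into layer constructors like any other init function. A hedged sketch: the `initW` and `init` keyword names below are how `Dense` and `Conv` were parameterized around Flux 0.11 and are not part of this diff.

```julia
using Flux

# Kaiming initialization suits ReLU networks, where glorot tends to under-scale
# (keyword names as of Flux ~0.11; check your version's docs):
layer = Dense(784, 256, relu; initW = Flux.kaiming_uniform)

# The gain keyword is reached by wrapping the initializer in a closure:
conv = Conv((3, 3), 3 => 16, relu; init = (dims...) -> Flux.kaiming_uniform(dims...; gain = 1))
```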

"""
kaiming_normal(dims...; gain = √2)

Return an `Array` of size `dims` containing random variables taken from a normal
distribution with mean 0 and standard deviation `gain / sqrt(fan_in)`.

This method is described in [1] and also known as He initialization.

# Examples
```jldoctest; setup = :(using Random; Random.seed!(0))
julia> Flux.kaiming_normal(3, 2)
3×2 Array{Float32,2}:
0.679107 -0.134854
0.828413 0.586617
-0.353007 0.297336
```

# See also

* kaiming initialization using uniform distribution: [`kaiming_uniform`](@ref Flux.kaiming_uniform)
* glorot initialization using normal distribution: [`glorot_normal`](@ref Flux.glorot_normal)
* glorot initialization using uniform distribution: [`glorot_uniform`](@ref Flux.glorot_uniform)
* calculation of `fan_in` and `fan_out`: [`nfan`](@ref Flux.nfan)

# References

[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on imagenet classification." _Proceedings of the IEEE international conference on computer vision_. 2015.
"""
function kaiming_normal(dims...; gain = √2f0)
std = Float32(gain / sqrt(first(nfan(dims...)))) # fan_in
return randn(Float32, dims...) .* std
end
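
A minimal numeric check (a sketch, assuming only the function above) that the resulting standard deviation is `gain / sqrt(fan_in)`:

```julia
using Flux
using Statistics: std

w = Flux.kaiming_normal(1000, 1000)
fan_in = first(Flux.nfan(1000, 1000))                  # 1000
isapprox(std(w), sqrt(2) / sqrt(fan_in); rtol = 0.05)  # true: std ≈ 0.0447
```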

ones(T::Type, dims...) = Base.ones(T, dims...)
zeros(T::Type, dims...) = Base.zeros(T, dims...)

24 changes: 21 additions & 3 deletions test/utils.jl
@@ -1,6 +1,6 @@
using Flux
- using Flux: throttle, nfan, glorot_uniform, glorot_normal, stack, unstack
- using StatsBase: var
+ using Flux: throttle, nfan, glorot_uniform, glorot_normal, kaiming_normal, kaiming_uniform, stack, unstack
+ using StatsBase: var, std
using Random
using Test

@@ -59,7 +59,7 @@ end
@testset "Fan in/out" begin
@test nfan() == (1, 1) #For a constant
@test nfan(100) == (1, 100) #For vector
- @test nfan(100, 200) == (200, 100) #For Dense layer
+ @test nfan(100, 200) == (200, 100) == nfan((100, 200)) #For Dense layer
@test nfan(2, 30, 40) == (2 * 30, 2 * 40) #For 1D Conv layer
@test nfan(2, 3, 40, 50) == (2 * 3 * 40, 2 * 3 * 50) #For 2D Conv layer
@test nfan(2, 3, 4, 50, 60) == (2 * 3 * 4 * 50, 2 * 3 * 4 * 60) #For 3D Conv layer
@@ -74,9 +74,27 @@
fan_in, fan_out = nfan(dims...)
σ2 = 2 / (fan_in + fan_out)
@test 0.9σ2 < var(v) < 1.1σ2
@test eltype(v) == Float32
end
end
end

@testset "kaiming" begin
# kaiming_uniform should yield a kernel in range [-sqrt(6/n_out), sqrt(6/n_out)]
# and kaiming_normal should yield a kernel with stddev ~= sqrt(2/n_out)
for (n_in, n_out) in [(100, 100), (100, 400)]
v = kaiming_uniform(n_in, n_out)
σ2 = sqrt(6/n_out)
@test -1σ2 < minimum(v) < -0.9σ2
@test 0.9σ2 < maximum(v) < 1σ2
@test eltype(v) == Float32

v = kaiming_normal(n_in, n_out)
σ2 = sqrt(2/n_out)
@test 0.9σ2 < std(v) < 1.1σ2
@test eltype(v) == Float32
end
end
end

@testset "Params" begin