Mark initialisations nograd, restrict signatures (#1908)
* mark init functions non-diff

* add a testing loop

* restore a truncated_normal eltype test, and test its defaults too

* tweaks

* restrict all to integer... size
mcabbott authored Mar 20, 2022
1 parent 2aa2a26 commit b6dbefb
Showing 2 changed files with 90 additions and 47 deletions.
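For orientation, a minimal sketch of the behaviour the changes below enforce, assuming Flux at this commit (with Zygote providing `gradient`); the sizes and values are illustrative only:

```julia
using Flux, Zygote
using Random: MersenneTwister

# Sizes must now be given as separate Integer arguments:
W = Flux.glorot_uniform(3, 4)          # 3×4 Float32 matrix
# Flux.glorot_uniform((3, 4))          # passing a tuple now throws a MethodError

# Currying with an RNG and keyword defaults still works:
init = Flux.kaiming_uniform(MersenneTwister(1); gain = 2)
size(init(3, 4))                       # (3, 4)

# The initialisers are marked @non_differentiable, so they contribute no gradient:
g = gradient(x -> sum(x .* Flux.glorot_uniform(3, 4)), 5.0)
g[1] isa Number                        # true
```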
58 changes: 37 additions & 21 deletions src/utils.jl
@@ -80,10 +80,12 @@ julia> Flux.glorot_uniform(2, 3)
[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010.
"""
-glorot_uniform(rng::AbstractRNG, dims...) = (rand(rng, Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / sum(nfan(dims...)))
-glorot_uniform(dims...) = glorot_uniform(rng_from_array(), dims...)
+glorot_uniform(rng::AbstractRNG, dims::Integer...) = (rand(rng, Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / sum(nfan(dims...)))
+glorot_uniform(dims::Integer...) = glorot_uniform(rng_from_array(), dims...)
glorot_uniform(rng::AbstractRNG) = (dims...) -> glorot_uniform(rng, dims...)

+ChainRulesCore.@non_differentiable glorot_uniform(::Any...)

"""
glorot_normal([rng=GLOBAL_RNG], dims...)
@@ -113,10 +115,12 @@ julia> Flux.glorot_normal(3, 2)
[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010.
"""
-glorot_normal(rng::AbstractRNG, dims...) = randn(rng, Float32, dims...) .* sqrt(2.0f0 / sum(nfan(dims...)))
-glorot_normal(dims...) = glorot_normal(rng_from_array(), dims...)
+glorot_normal(rng::AbstractRNG, dims::Integer...) = randn(rng, Float32, dims...) .* sqrt(2.0f0 / sum(nfan(dims...)))
+glorot_normal(dims::Integer...) = glorot_normal(rng_from_array(), dims...)
glorot_normal(rng::AbstractRNG) = (dims...) -> glorot_normal(rng, dims...)

+ChainRulesCore.@non_differentiable glorot_normal(::Any...)

"""
kaiming_uniform([rng=GLOBAL_RNG], dims...; gain = √2)
@@ -146,14 +150,16 @@ julia> Flux.kaiming_uniform(3, 2)
[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on imagenet classification." _Proceedings of the IEEE international conference on computer vision_. 2015.
"""
-function kaiming_uniform(rng::AbstractRNG, dims...; gain = √2)
+function kaiming_uniform(rng::AbstractRNG, dims::Integer...; gain = √2)
bound = Float32(√3 * gain / sqrt(first(nfan(dims...)))) # fan_in
return (rand(rng, Float32, dims...) .- 0.5f0) .* 2bound
end

-kaiming_uniform(dims...; kwargs...) = kaiming_uniform(rng_from_array(), dims...; kwargs...)
+kaiming_uniform(dims::Integer...; kwargs...) = kaiming_uniform(rng_from_array(), dims...; kwargs...)
kaiming_uniform(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_uniform(rng, dims...; init_kwargs..., kwargs...)

+ChainRulesCore.@non_differentiable kaiming_uniform(::Any...)

"""
kaiming_normal([rng=GLOBAL_RNG], dims...; gain = √2)
@@ -183,14 +189,16 @@ julia> Flux.kaiming_normal(3, 2)
[1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on imagenet classification." _Proceedings of the IEEE international conference on computer vision_. 2015.
"""
-function kaiming_normal(rng::AbstractRNG, dims...; gain = √2f0)
+function kaiming_normal(rng::AbstractRNG, dims::Integer...; gain = √2f0)
std = Float32(gain / sqrt(first(nfan(dims...)))) # fan_in
return randn(rng, Float32, dims...) .* std
end

-kaiming_normal(dims...; kwargs...) = kaiming_normal(rng_from_array(), dims...; kwargs...)
+kaiming_normal(dims::Integer...; kwargs...) = kaiming_normal(rng_from_array(), dims...; kwargs...)
kaiming_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_normal(rng, dims...; init_kwargs..., kwargs...)

+ChainRulesCore.@non_differentiable kaiming_normal(::Any...)

"""
truncated_normal([rng=GLOBAL_RNG], dims...; mean = 0, std = 1, lo = -2, hi = 2)
@@ -221,7 +229,7 @@ julia> round(std(Flux.truncated_normal(10^6; lo = -100, hi = 100)))
[PDF](https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf).
Department of Scientific Computing website.
"""
-function truncated_normal(rng::AbstractRNG, dims...; mean = 0, std = 1, lo = -2, hi = 2)
+function truncated_normal(rng::AbstractRNG, dims::Integer...; mean = 0, std = 1, lo = -2, hi = 2)
norm_cdf(x) = 0.5 * (1 + erf(x/√2))
if (mean < lo - 2 * std) || (mean > hi + 2 * std)
@warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." maxlog=1
@@ -237,9 +245,11 @@ function truncated_normal(rng::AbstractRNG, dims...; mean = 0, std = 1, lo = -2,
return xs
end

-truncated_normal(dims...; kwargs...) = truncated_normal(rng_from_array(), dims...; kwargs...)
+truncated_normal(dims::Integer...; kwargs...) = truncated_normal(rng_from_array(), dims...; kwargs...)
truncated_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> truncated_normal(rng, dims...; init_kwargs..., kwargs...)

+ChainRulesCore.@non_differentiable truncated_normal(::Any...)

"""
orthogonal([rng=GLOBAL_RNG], dims...; gain = 1)
@@ -307,6 +317,8 @@ end
orthogonal(dims::Integer...; kwargs...) = orthogonal(rng_from_array(), dims...; kwargs...)
orthogonal(rng::AbstractRNG; init_kwargs...) = (dims::Integer...; kwargs...) -> orthogonal(rng, dims...; init_kwargs..., kwargs...)

+ChainRulesCore.@non_differentiable orthogonal(::Any...)

"""
sparse_init([rng=GLOBAL_RNG], dims...; sparsity, std = 0.01)
@@ -336,7 +348,7 @@ julia> Flux.sparse_init(3, 2, sparsity=0.1)
[1] Martens, J, "Deep learning via Hessian-free optimization" _Proceedings of the 27th International Conference on International Conference on Machine Learning_. 2010.
"""
-function sparse_init(rng::AbstractRNG, dims...; sparsity, std = 0.01)
+function sparse_init(rng::AbstractRNG, dims::Integer...; sparsity, std = 0.01)
if length(dims) != 2
throw(ArgumentError("Only 2-dimensional outputs are supported for sparse initialization."))
end
@@ -348,9 +360,11 @@ function sparse_init(rng::AbstractRNG, dims...; sparsity, std = 0.01)
return mapslices(shuffle, sparse_array, dims=1)
end

-sparse_init(dims...; kwargs...) = sparse_init(rng_from_array(), dims...; kwargs...)
+sparse_init(dims::Integer...; kwargs...) = sparse_init(rng_from_array(), dims...; kwargs...)
sparse_init(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> sparse_init(rng, dims...; init_kwargs..., kwargs...)

+ChainRulesCore.@non_differentiable sparse_init(::Any...)

"""
identity_init([rng=GLOBAL_RNG], dims...; gain=1, shift=0)
@@ -415,30 +429,32 @@ julia> Flux.identity_init(3,3,2,2)
```
"""
# Assume bias
-identity_init(cols; gain=1, shift=0) = zeros32(cols)
+identity_init(cols::Integer; gain=1, shift=0) = zeros32(cols)

# Assume matrix multiplication
-identity_init(rows, cols; gain=1, shift=0) = circshift(Matrix{Float32}(I * gain, rows,cols), shift)
+identity_init(rows::Integer, cols::Integer; gain=1, shift=0) = circshift(Matrix{Float32}(I * gain, rows,cols), shift)

# Assume convolution
-function identity_init(dims...; gain=1, shift=0)
+function identity_init(dims::Integer...; gain=1, shift=0)
nin, nout = dims[end-1], dims[end]
centers = map(d -> cld(d, 2), dims[1:end-2])
-weights = zeros32(dims)
+weights = zeros32(dims...)
for i in 1:min(nin,nout)
weights[centers..., i, i] = gain
end
return circshift(weights, shift)
end

-identity_init(::AbstractRNG, dims...; kwargs...) = identity_init(dims...; kwargs...)
+identity_init(::AbstractRNG, dims::Integer...; kwargs...) = identity_init(dims...; kwargs...)
identity_init(; init_kwargs...) = identity_init(rng_from_array(); init_kwargs...)
identity_init(rng::AbstractRNG; init_kwargs...) = (args...;kwargs...) -> identity_init(rng, args...; init_kwargs..., kwargs...)

-ones32(dims...) = Base.ones(Float32, dims...)
-zeros32(dims...) = Base.zeros(Float32, dims...)
-rand32(dims...) = Base.rand(Float32, dims...)
-randn32(dims...) = Base.randn(Float32, dims...)
+ChainRulesCore.@non_differentiable identity_init(::Any...)

+ones32(dims::Integer...) = Base.ones(Float32, dims...)
+zeros32(dims::Integer...) = Base.zeros(Float32, dims...)
+rand32(dims::Integer...) = Base.rand(Float32, dims...)
+randn32(dims::Integer...) = Base.randn(Float32, dims...)

"""
create_bias(weights, bias, size...)
79 changes: 53 additions & 26 deletions test/utils.jl
@@ -1,7 +1,7 @@
using Flux
using Flux: throttle, nfan, glorot_uniform, glorot_normal,
kaiming_normal, kaiming_uniform, orthogonal, truncated_normal,
-sparse_init, stack, unstack, batch, unbatch,
+sparse_init, identity_init, stack, unstack, batch, unbatch,
unsqueeze, params, loadparams!
using StatsBase: var, std
using Statistics, LinearAlgebra
@@ -69,6 +69,37 @@ end
@test nfan(2, 3, 4, 50, 60) == (2 * 3 * 4 * 50, 2 * 3 * 4 * 60) #For 3D Conv layer
end

@testset "Basics: $init" for init in [
glorot_uniform, glorot_normal,
kaiming_uniform, kaiming_normal,
orthogonal,
sparse_init,
truncated_normal,
identity_init,
]
if init == sparse_init
init = (args...) -> sparse_init(args...; sparsity=0.5)
else
# sparse_init is the only one which accepts only matrices:
@test size(init(3)) == (3,)
@test size(init(3, 4, 5)) == (3, 4, 5)
end
@test size(init(3, 4)) == (3, 4)
# only init(size...) is accepted:
@test_throws MethodError size(init((3, 4, 5))) == (3, 4, 5)

# rng, and currying:
@test size(init(MersenneTwister(1), 3, 4)) == (3, 4)
closure = init(MersenneTwister(1))
@test size(closure(3, 4)) == (3, 4)

# eltype, default Float32
@test eltype(init(3, 4)) == Float32

# @non_differentiable
@test gradient(x -> sum(x .* init(3, 4)), 5.0)[1] isa Number
end

@testset "glorot" begin
# glorot_uniform and glorot_normal should both yield a kernel with
# variance ≈ 2/(fan_in + fan_out)
@@ -78,7 +109,6 @@ end
fan_in, fan_out = nfan(dims...)
σ2 = 2 / (fan_in + fan_out)
@test 0.9σ2 < var(v) < 1.1σ2
-@test eltype(v) == Float32
end
end
end
@@ -91,12 +121,10 @@ end
σ2 = sqrt(6/n_out)
@test -1σ2 < minimum(v) < -0.9σ2
@test 0.9σ2 < maximum(v) < 1σ2
-@test eltype(v) == Float32

v = kaiming_normal(n_in, n_out)
σ2 = sqrt(2/n_out)
@test 0.9σ2 < std(v) < 1.1σ2
-@test eltype(v) == Float32
end
end

@@ -125,54 +153,53 @@ end
@test_throws ArgumentError sparse_init(100, 100, 100, sparsity=0.1)
v = sparse_init(100, 100, sparsity=-0.1)
@test sum(v .== 0) == 0
-@test eltype(v) == Float32
v = sparse_init(100, 100, sparsity=1.1)
@test sum(v .== 0) == length(v)
-@test eltype(v) == Float32

for (n_in, n_out, sparsity, σ) in [(100, 100, 0.25, 0.1), (100, 400, 0.75, 0.01)]
expected_zeros = ceil(Integer, n_in * sparsity)
v = sparse_init(n_in, n_out, sparsity=sparsity, std=σ)
@test all([sum(v[:,col] .== 0) == expected_zeros for col in 1:n_out])
@test 0.9 * σ < std(v[v .!= 0]) < 1.1 * σ
-@test eltype(v) == Float32
end
end

@testset "truncated_normal" begin
-size = (100, 100, 100)
-for (μ, σ, lo, hi) in [(0., 1, -2, 2), (0, 1, -4., 4)]
-v = truncated_normal(size; mean = μ, std = σ, lo, hi)
+m = truncated_normal(100, 100)
+@test minimum(m) ≈ -2 atol = 0.05 # default arguments
+@test maximum(m) ≈ 2 atol = 0.05
+@test mean(m) ≈ 0 atol = 0.1

+size100 = (100, 100, 100)
+for (μ, σ, lo, hi) in [(0.0, 1, -2, 3), (1, 2, -4.0, 5.0)]
+v = truncated_normal(size100...; mean = μ, std = σ, lo, hi)
@test isapprox(mean(v), μ; atol = 1f-1)
-@test isapprox(minimum(v), lo; atol = 1f-1)
-@test isapprox(maximum(v), hi; atol = 1f-1)
-@test eltype(v) == Float32
+@test isapprox(minimum(v), lo; atol = 1f-2)
+@test isapprox(maximum(v), hi; atol = 1f-2)
+@test eltype(v) == Float32 # despite some Float64 arguments
end
-for (μ, σ, lo, hi) in [(6, 2, -100., 100), (7., 10, -100, 100)]
-v = truncated_normal(size...; mean = μ, std = σ, lo, hi)
+for (μ, σ, lo, hi) in [(6, 2, -100.0, 100), (-7.0, 10, -100, 100)]
+v = truncated_normal(size100...; mean = μ, std = σ, lo, hi)
@test isapprox(mean(v), μ; atol = 1f-1)
@test isapprox(std(v), σ; atol = 1f-1)
@test eltype(v) == Float32
end
end

@testset "partial_application" begin
big = 1e9

partial_ku = kaiming_uniform(gain=big)
@test maximum(partial_ku(8, 8)) > big / 2
@test maximum(partial_ku(8, 8, gain=1)) < big / 2
@testset "Partial application" begin
partial_ku = kaiming_uniform(gain=1e9)
@test maximum(partial_ku(8, 8)) > 1e9 / 2
@test maximum(partial_ku(8, 8, gain=1)) < 1e9 / 2

-partial_kn = kaiming_normal(gain=big)
-@test maximum(partial_kn(8, 8)) > big / 2
-@test maximum(partial_kn(8, 8, gain=1)) < big / 2
+partial_kn = kaiming_normal(gain=1e9)
+@test maximum(partial_kn(8, 8)) > 1e9 / 2
+@test maximum(partial_kn(8, 8, gain=1)) < 1e9 / 2

partial_si = sparse_init(sparsity=1)
@test maximum(partial_si(8, 8)) == 0
@test maximum(partial_si(8, 8, sparsity=0)) > 0
end

@testset "identity_init" begin
-import Flux: identity_init

@testset "Basic" begin
partial = identity_init(gain=3)
