From b6dbefb2107bda97d5d288f9d14554cc4fc7c8ca Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sun, 20 Mar 2022 08:13:05 -0500
Subject: [PATCH] Mark initialisations nograd, restrict signatures (#1908)

* mark init functions non-diff

* add a testing loop

* restore a truncated_normal eltype test, and test its defaults too

* tweaks

* restrict all to integer... size
---
 src/utils.jl  | 58 +++++++++++++++++++++++--------------
 test/utils.jl | 79 ++++++++++++++++++++++++++++++++++-----------------
 2 files changed, 90 insertions(+), 47 deletions(-)

diff --git a/src/utils.jl b/src/utils.jl
index 0c1876b315..e93b83a89b 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -80,10 +80,12 @@ julia> Flux.glorot_uniform(2, 3)
 [1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010.
 """
-glorot_uniform(rng::AbstractRNG, dims...) = (rand(rng, Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / sum(nfan(dims...)))
-glorot_uniform(dims...) = glorot_uniform(rng_from_array(), dims...)
+glorot_uniform(rng::AbstractRNG, dims::Integer...) = (rand(rng, Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / sum(nfan(dims...)))
+glorot_uniform(dims::Integer...) = glorot_uniform(rng_from_array(), dims...)
 glorot_uniform(rng::AbstractRNG) = (dims...) -> glorot_uniform(rng, dims...)
 
+ChainRulesCore.@non_differentiable glorot_uniform(::Any...)
+
 """
     glorot_normal([rng=GLOBAL_RNG], dims...)
@@ -113,10 +115,12 @@ julia> Flux.glorot_normal(3, 2)
 [1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010.
 """
-glorot_normal(rng::AbstractRNG, dims...) = randn(rng, Float32, dims...) .* sqrt(2.0f0 / sum(nfan(dims...)))
-glorot_normal(dims...) = glorot_normal(rng_from_array(), dims...)
+glorot_normal(rng::AbstractRNG, dims::Integer...) = randn(rng, Float32, dims...) .* sqrt(2.0f0 / sum(nfan(dims...)))
+glorot_normal(dims::Integer...) = glorot_normal(rng_from_array(), dims...)
 glorot_normal(rng::AbstractRNG) = (dims...) -> glorot_normal(rng, dims...)
 
+ChainRulesCore.@non_differentiable glorot_normal(::Any...)
+
 """
     kaiming_uniform([rng=GLOBAL_RNG], dims...; gain = √2)
@@ -146,14 +150,16 @@ julia> Flux.kaiming_uniform(3, 2)
 [1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on imagenet classification." _Proceedings of the IEEE international conference on computer vision_. 2015.
 """
-function kaiming_uniform(rng::AbstractRNG, dims...; gain = √2)
+function kaiming_uniform(rng::AbstractRNG, dims::Integer...; gain = √2)
   bound = Float32(√3 * gain / sqrt(first(nfan(dims...)))) # fan_in
   return (rand(rng, Float32, dims...) .- 0.5f0) .* 2bound
 end
 
-kaiming_uniform(dims...; kwargs...) = kaiming_uniform(rng_from_array(), dims...; kwargs...)
+kaiming_uniform(dims::Integer...; kwargs...) = kaiming_uniform(rng_from_array(), dims...; kwargs...)
 kaiming_uniform(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_uniform(rng, dims...; init_kwargs..., kwargs...)
 
+ChainRulesCore.@non_differentiable kaiming_uniform(::Any...)
+
 """
     kaiming_normal([rng=GLOBAL_RNG], dims...; gain = √2)
@@ -183,14 +189,16 @@ julia> Flux.kaiming_normal(3, 2)
 [1] He, Kaiming, et al. "Delving deep into rectifiers: Surpassing human-level performance on imagenet classification." _Proceedings of the IEEE international conference on computer vision_. 2015.
 """
-function kaiming_normal(rng::AbstractRNG, dims...; gain = √2f0)
+function kaiming_normal(rng::AbstractRNG, dims::Integer...; gain = √2f0)
   std = Float32(gain / sqrt(first(nfan(dims...)))) # fan_in
   return randn(rng, Float32, dims...) .* std
 end
 
-kaiming_normal(dims...; kwargs...) = kaiming_normal(rng_from_array(), dims...; kwargs...)
+kaiming_normal(dims::Integer...; kwargs...) = kaiming_normal(rng_from_array(), dims...; kwargs...)
 kaiming_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_normal(rng, dims...; init_kwargs..., kwargs...)
 
+ChainRulesCore.@non_differentiable kaiming_normal(::Any...)
+
 """
     truncated_normal([rng=GLOBAL_RNG], dims...; mean = 0, std = 1, lo = -2, hi = 2)
@@ -221,7 +229,7 @@ julia> round(std(Flux.truncated_normal(10^6; lo = -100, hi = 100)))
 [PDF](https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf). Department of Scientific Computing website.
 """
-function truncated_normal(rng::AbstractRNG, dims...; mean = 0, std = 1, lo = -2, hi = 2)
+function truncated_normal(rng::AbstractRNG, dims::Integer...; mean = 0, std = 1, lo = -2, hi = 2)
   norm_cdf(x) = 0.5 * (1 + erf(x/√2))
   if (mean < lo - 2 * std) || (mean > hi + 2 * std)
     @warn "Mean is more than 2 std outside the limits in truncated_normal, so the distribution of values may be inaccurate." maxlog=1
@@ -237,9 +245,11 @@ function truncated_normal(rng::AbstractRNG, dims...; mean = 0, std = 1, lo = -2,
   return xs
 end
 
-truncated_normal(dims...; kwargs...) = truncated_normal(rng_from_array(), dims...; kwargs...)
+truncated_normal(dims::Integer...; kwargs...) = truncated_normal(rng_from_array(), dims...; kwargs...)
 truncated_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> truncated_normal(rng, dims...; init_kwargs..., kwargs...)
 
+ChainRulesCore.@non_differentiable truncated_normal(::Any...)
+
 """
     orthogonal([rng=GLOBAL_RNG], dims...; gain = 1)
@@ -307,6 +317,8 @@ end
 orthogonal(dims::Integer...; kwargs...) = orthogonal(rng_from_array(), dims...; kwargs...)
 orthogonal(rng::AbstractRNG; init_kwargs...) = (dims::Integer...; kwargs...) -> orthogonal(rng, dims...; init_kwargs..., kwargs...)
 
+ChainRulesCore.@non_differentiable orthogonal(::Any...)
+
 """
     sparse_init([rng=GLOBAL_RNG], dims...; sparsity, std = 0.01)
@@ -336,7 +348,7 @@ julia> Flux.sparse_init(3, 2, sparsity=0.1)
 [1] Martens, J, "Deep learning via Hessian-free optimization" _Proceedings of the 27th International Conference on International Conference on Machine Learning_. 2010.
 """
-function sparse_init(rng::AbstractRNG, dims...; sparsity, std = 0.01)
+function sparse_init(rng::AbstractRNG, dims::Integer...; sparsity, std = 0.01)
   if length(dims) != 2
     throw(ArgumentError("Only 2-dimensional outputs are supported for sparse initialization."))
   end
@@ -348,9 +360,11 @@ function sparse_init(rng::AbstractRNG, dims...; sparsity, std = 0.01)
   return mapslices(shuffle, sparse_array, dims=1)
 end
 
-sparse_init(dims...; kwargs...) = sparse_init(rng_from_array(), dims...; kwargs...)
+sparse_init(dims::Integer...; kwargs...) = sparse_init(rng_from_array(), dims...; kwargs...)
 sparse_init(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> sparse_init(rng, dims...; init_kwargs..., kwargs...)
 
+ChainRulesCore.@non_differentiable sparse_init(::Any...)
+
 """
     identity_init([rng=GLOBAL_RNG], dims...; gain=1, shift=0)
@@ -415,30 +429,32 @@ julia> Flux.identity_init(3,3,2,2)
 ```
 """
 # Assume bias
-identity_init(cols; gain=1, shift=0) = zeros32(cols)
+identity_init(cols::Integer; gain=1, shift=0) = zeros32(cols)
 
 # Assume matrix multiplication
-identity_init(rows, cols; gain=1, shift=0) = circshift(Matrix{Float32}(I * gain, rows,cols), shift)
+identity_init(rows::Integer, cols::Integer; gain=1, shift=0) = circshift(Matrix{Float32}(I * gain, rows,cols), shift)
 
 # Assume convolution
-function identity_init(dims...; gain=1, shift=0)
+function identity_init(dims::Integer...; gain=1, shift=0)
   nin, nout = dims[end-1], dims[end]
   centers = map(d -> cld(d, 2), dims[1:end-2])
-  weights = zeros32(dims)
+  weights = zeros32(dims...)
   for i in 1:min(nin,nout)
     weights[centers..., i, i] = gain
   end
   return circshift(weights, shift)
 end
 
-identity_init(::AbstractRNG, dims...; kwargs...) = identity_init(dims...; kwargs...)
+identity_init(::AbstractRNG, dims::Integer...; kwargs...) = identity_init(dims...; kwargs...)
 identity_init(; init_kwargs...) = identity_init(rng_from_array(); init_kwargs...)
 identity_init(rng::AbstractRNG; init_kwargs...) = (args...;kwargs...) -> identity_init(rng, args...; init_kwargs..., kwargs...)
 
-ones32(dims...) = Base.ones(Float32, dims...)
-zeros32(dims...) = Base.zeros(Float32, dims...)
-rand32(dims...) = Base.rand(Float32, dims...)
-randn32(dims...) = Base.randn(Float32, dims...)
+ChainRulesCore.@non_differentiable identity_init(::Any...)
+
+ones32(dims::Integer...) = Base.ones(Float32, dims...)
+zeros32(dims::Integer...) = Base.zeros(Float32, dims...)
+rand32(dims::Integer...) = Base.rand(Float32, dims...)
+randn32(dims::Integer...) = Base.randn(Float32, dims...)
 
 """
     create_bias(weights, bias, size...)

diff --git a/test/utils.jl b/test/utils.jl
index c75780e360..d04fd85e06 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -1,7 +1,7 @@
 using Flux
 using Flux: throttle, nfan, glorot_uniform, glorot_normal,
   kaiming_normal, kaiming_uniform, orthogonal, truncated_normal,
-  sparse_init, stack, unstack, batch, unbatch,
+  sparse_init, identity_init, stack, unstack, batch, unbatch,
   unsqueeze, params, loadparams!
 using StatsBase: var, std
 using Statistics, LinearAlgebra
@@ -69,6 +69,37 @@ end
     @test nfan(2, 3, 4, 50, 60) == (2 * 3 * 4 * 50, 2 * 3 * 4 * 60) #For 3D Conv layer
   end
 
+  @testset "Basics: $init" for init in [
+    glorot_uniform, glorot_normal,
+    kaiming_uniform, kaiming_normal,
+    orthogonal,
+    sparse_init,
+    truncated_normal,
+    identity_init,
+  ]
+    if init == sparse_init
+      init = (args...) -> sparse_init(args...; sparsity=0.5)
+    else
+      # sparse_init is the only one which accepts only matrices:
+      @test size(init(3)) == (3,)
+      @test size(init(3, 4, 5)) == (3, 4, 5)
+    end
+    @test size(init(3, 4)) == (3, 4)
+    # only init(size...) is accepted:
+    @test_throws MethodError size(init((3, 4, 5))) == (3, 4, 5)
+
+    # rng, and currying:
+    @test size(init(MersenneTwister(1), 3, 4)) == (3, 4)
+    closure = init(MersenneTwister(1))
+    @test size(closure(3, 4)) == (3, 4)
+
+    # eltype, default Float32
+    @test eltype(init(3, 4)) == Float32
+
+    # @non_differentiable
+    @test gradient(x -> sum(x .* init(3, 4)), 5.0)[1] isa Number
+  end
+
   @testset "glorot" begin
     # glorot_uniform and glorot_normal should both yield a kernel with
     # variance ≈ 2/(fan_in + fan_out)
@@ -78,7 +109,6 @@ end
       fan_in, fan_out = nfan(dims...)
       σ2 = 2 / (fan_in + fan_out)
       @test 0.9σ2 < var(v) < 1.1σ2
-      @test eltype(v) == Float32
     end
   end
 end
@@ -91,12 +121,10 @@ end
       σ2 = sqrt(6/n_out)
       @test -1σ2 < minimum(v) < -0.9σ2
       @test 0.9σ2 < maximum(v) < 1σ2
-      @test eltype(v) == Float32
 
       v = kaiming_normal(n_in, n_out)
       σ2 = sqrt(2/n_out)
       @test 0.9σ2 < std(v) < 1.1σ2
-      @test eltype(v) == Float32
     end
   end
 
@@ -125,46 +153,46 @@ end
     @test_throws ArgumentError sparse_init(100, 100, 100, sparsity=0.1)
     v = sparse_init(100, 100, sparsity=-0.1)
     @test sum(v .== 0) == 0
-    @test eltype(v) == Float32
     v = sparse_init(100, 100, sparsity=1.1)
     @test sum(v .== 0) == length(v)
-    @test eltype(v) == Float32
 
     for (n_in, n_out, sparsity, σ) in [(100, 100, 0.25, 0.1), (100, 400, 0.75, 0.01)]
       expected_zeros = ceil(Integer, n_in * sparsity)
       v = sparse_init(n_in, n_out, sparsity=sparsity, std=σ)
       @test all([sum(v[:,col] .== 0) == expected_zeros for col in 1:n_out])
       @test 0.9 * σ < std(v[v .!= 0]) < 1.1 * σ
-      @test eltype(v) == Float32
     end
   end
 
   @testset "truncated_normal" begin
-    size = (100, 100, 100)
-    for (μ, σ, lo, hi) in [(0., 1, -2, 2), (0, 1, -4., 4)]
-      v = truncated_normal(size; mean = μ, std = σ, lo, hi)
+    m = truncated_normal(100, 100)
+    @test minimum(m) ≈ -2 atol = 0.05  # default arguments
+    @test maximum(m) ≈ 2 atol = 0.05
+    @test mean(m) ≈ 0 atol = 0.1
+
+    size100 = (100, 100, 100)
+    for (μ, σ, lo, hi) in [(0.0, 1, -2, 3), (1, 2, -4.0, 5.0)]
+      v = truncated_normal(size100...; mean = μ, std = σ, lo, hi)
       @test isapprox(mean(v), μ; atol = 1f-1)
-      @test isapprox(minimum(v), lo; atol = 1f-1)
-      @test isapprox(maximum(v), hi; atol = 1f-1)
-      @test eltype(v) == Float32
+      @test isapprox(minimum(v), lo; atol = 1f-2)
+      @test isapprox(maximum(v), hi; atol = 1f-2)
+      @test eltype(v) == Float32  # despite some Float64 arguments
     end
-    for (μ, σ, lo, hi) in [(6, 2, -100., 100), (7., 10, -100, 100)]
-      v = truncated_normal(size...; mean = μ, std = σ, lo, hi)
+    for (μ, σ, lo, hi) in [(6, 2, -100.0, 100), (-7.0, 10, -100, 100)]
+      v = truncated_normal(size100...; mean = μ, std = σ, lo, hi)
+      @test isapprox(mean(v), μ; atol = 1f-1)
       @test isapprox(std(v), σ; atol = 1f-1)
-      @test eltype(v) == Float32
     end
   end
 
-  @testset "partial_application" begin
-    big = 1e9
-
-    partial_ku = kaiming_uniform(gain=big)
-    @test maximum(partial_ku(8, 8)) > big / 2
-    @test maximum(partial_ku(8, 8, gain=1)) < big / 2
+  @testset "Partial application" begin
+    partial_ku = kaiming_uniform(gain=1e9)
+    @test maximum(partial_ku(8, 8)) > 1e9 / 2
+    @test maximum(partial_ku(8, 8, gain=1)) < 1e9 / 2
 
-    partial_kn = kaiming_normal(gain=big)
-    @test maximum(partial_kn(8, 8)) > big / 2
-    @test maximum(partial_kn(8, 8, gain=1)) < big / 2
+    partial_kn = kaiming_normal(gain=1e9)
+    @test maximum(partial_kn(8, 8)) > 1e9 / 2
+    @test maximum(partial_kn(8, 8, gain=1)) < 1e9 / 2
 
     partial_si = sparse_init(sparsity=1)
     @test maximum(partial_si(8, 8)) == 0
@@ -172,7 +200,6 @@ end
 
   @testset "identity_init" begin
-    import Flux: identity_init
     @testset "Basic" begin
       partial = identity_init(gain=3)