Skip to content

Commit

Permalink
Fixed RNGs, updated tests
Browse files Browse the repository at this point in the history
  • Loading branch information
theabhirath committed Feb 18, 2022
1 parent 0012473 commit 65030e7
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 52 deletions.
98 changes: 50 additions & 48 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ epseltype(x) = eps(float(eltype(x)))
Create an instance of the RNG most appropriate for `x`.
The current defaults are:
- `x isa AbstractArray`
- Julia version is < 1.7: `Random.GLOBAL_RNG`
- Julia version is < 1.7: `rng_from_array()`
- Julia version is >= 1.7: `Random.default_rng()`
- `x isa CuArray`: `CUDA.default_rng()`
When `x` is unspecified, it is assumed to be a `AbstractArray`.
Expand All @@ -49,7 +49,7 @@ rng_from_array(::CuArray) = CUDA.default_rng()
if VERSION >= v"1.7"
rng_from_array() = Random.default_rng()
else
rng_from_array() = Random.GLOBAL_RNG
rng_from_array() = rng_from_array()
end

"""
Expand Down Expand Up @@ -81,7 +81,7 @@ julia> Flux.glorot_uniform(2, 3)
[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010.
"""
glorot_uniform(rng::AbstractRNG, dims...) = (rand(rng, Float32, dims...) .- 0.5f0) .* sqrt(24.0f0 / sum(nfan(dims...)))
glorot_uniform(dims...) = glorot_uniform(Random.GLOBAL_RNG, dims...)
glorot_uniform(dims...) = glorot_uniform(rng_from_array(), dims...)
glorot_uniform(rng::AbstractRNG) = (dims...) -> glorot_uniform(rng, dims...)

"""
Expand Down Expand Up @@ -114,7 +114,7 @@ julia> Flux.glorot_normal(3, 2)
[1] Glorot, Xavier, and Yoshua Bengio. "Understanding the difficulty of training deep feedforward neural networks." _Proceedings of the thirteenth international conference on artificial intelligence and statistics_. 2010.
"""
glorot_normal(rng::AbstractRNG, dims...) = randn(rng, Float32, dims...) .* sqrt(2.0f0 / sum(nfan(dims...)))
glorot_normal(dims...) = glorot_normal(Random.GLOBAL_RNG, dims...)
glorot_normal(dims...) = glorot_normal(rng_from_array(), dims...)
glorot_normal(rng::AbstractRNG) = (dims...) -> glorot_normal(rng, dims...)

"""
Expand Down Expand Up @@ -151,7 +151,7 @@ function kaiming_uniform(rng::AbstractRNG, dims...; gain = √2)
return (rand(rng, Float32, dims...) .- 0.5f0) .* 2bound
end

kaiming_uniform(dims...; kwargs...) = kaiming_uniform(Random.GLOBAL_RNG, dims...; kwargs...)
kaiming_uniform(dims...; kwargs...) = kaiming_uniform(rng_from_array(), dims...; kwargs...)
kaiming_uniform(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_uniform(rng, dims...; init_kwargs..., kwargs...)

"""
Expand Down Expand Up @@ -188,9 +188,50 @@ function kaiming_normal(rng::AbstractRNG, dims...; gain = √2f0)
return randn(rng, Float32, dims...) .* std
end

kaiming_normal(dims...; kwargs...) = kaiming_normal(Random.GLOBAL_RNG, dims...; kwargs...)
kaiming_normal(dims...; kwargs...) = kaiming_normal(rng_from_array(), dims...; kwargs...)
kaiming_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaiming_normal(rng, dims...; init_kwargs..., kwargs...)

"""
truncated_normal([rng=GLOBAL_RNG], dims...; mean = 0, std = 1, lo = -2., hi = 2.)
Return an `Array` of size `dims` where each element is drawn from a truncated normal distribution.
The values are generated by using a truncated uniform distribution and then using the inverse CDF
for the normal distribution. The method used for generating the random values works best when
`lo ≤ mean ≤ hi`.
# Examples
```jldoctest; setup = :(using Random; Random.seed!(0))
julia> Flux.truncated_normal(3, 2)
3×2 Matrix{Float32}:
-0.0340547 -1.35207
-0.22757 -0.793773
-1.75771 1.01801
```
# References
[1] Burkardt, John. "The Truncated Normal Distribution"
[PDF](https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf).
Department of Scientific Computing website.
"""
function truncated_normal(rng::AbstractRNG, dims...; mean = 0, std = 1, lo = -2, hi = 2)
norm_cdf(x) = 0.5 * (1 + erf(x/√2))
if (mean < lo - 2 * std) || (mean > hi + 2 * std)
@warn "Mean is more than 2 std from [a, b] in truncated_normal. The distribution of values may be incorrect." maxlog=1
end
l = norm_cdf((lo - mean) / std)
u = norm_cdf((hi - mean) / std)
xs = rand(rng, Float32, dims...)
broadcast!(xs, xs) do x
x = x * 2(u - l) + (2l - 1)
x = erfinv(x)
x = clamp.(x * std * 2f0 + mean, lo, hi)
end
return xs
end

truncated_normal(dims...; kwargs...) = truncated_normal(rng_from_array(), dims...; kwargs...)
truncated_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> truncated_normal(rng, dims...; init_kwargs..., kwargs...)

"""
orthogonal([rng=GLOBAL_RNG], dims...; gain = 1)
Expand Down Expand Up @@ -255,7 +296,7 @@ function orthogonal(rng::AbstractRNG, d1::Integer, ds::Integer...; kwargs...)
return reshape(orthogonal(rng, rows, cols; kwargs...), dims)
end

orthogonal(dims::Integer...; kwargs...) = orthogonal(Random.GLOBAL_RNG, dims...; kwargs...)
orthogonal(dims::Integer...; kwargs...) = orthogonal(rng_from_array(), dims...; kwargs...)
orthogonal(rng::AbstractRNG; init_kwargs...) = (dims::Integer...; kwargs...) -> orthogonal(rng, dims...; init_kwargs..., kwargs...)

"""
Expand Down Expand Up @@ -299,48 +340,9 @@ function sparse_init(rng::AbstractRNG, dims...; sparsity, std = 0.01)
return mapslices(shuffle, sparse_array, dims=1)
end

sparse_init(dims...; kwargs...) = sparse_init(Random.GLOBAL_RNG, dims...; kwargs...)
sparse_init(dims...; kwargs...) = sparse_init(rng_from_array(), dims...; kwargs...)
sparse_init(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> sparse_init(rng, dims...; init_kwargs..., kwargs...)

"""
truncated_normal([rng=GLOBAL_RNG], dims...; mean = 0, std = 1, lo = -2., hi = 2.)
Return an `Array` of size `dims` where each element is drawn from a truncated normal distribution.
The values are generated by using a truncated uniform distribution and then using the inverse CDF
for the normal distribution. The method used for generating the random values works best when
`lo ≤ mean ≤ hi`.
# Examples
```jldoctest; setup = :(using Random; Random.seed!(0))
julia> Flux.truncated_normal(3, 2)
3×2 Matrix{Float32}:
-0.113785 -0.627307
-0.676033 0.198423
0.509005 -0.554339
```
# References
[1] Burkardt, John. "The Truncated Normal Distribution"
[PDF](https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf).
Department of Scientific Computing website.
"""
function truncated_normal(rng::AbstractRNG, dims...; mean = 0, std = 1, lo = -2, hi = 2)
norm_cdf(x) = 0.5 * (1 + erf(x/√2))
if (mean < lo - 2 * std) || (mean > hi + 2 * std)
@warn "Mean is more than 2 std from [a, b] in truncated_normal. The distribution of values may be incorrect." maxlog=1
end
l = norm_cdf((lo - mean) / std)
u = norm_cdf((hi - mean) / std)
x = rand(rng, dims...) * 2(u - l) .+ (2l - 1)
x = erfinv.(x)
x = f32(x .* std * 2 .+ mean)
return x
end

truncated_normal(dims::Integer...; kwargs...) = truncated_normal(Random.GLOBAL_RNG, dims...; kwargs...)
truncated_normal(dims) = truncated_normal(Random.GLOBAL_RNG, dims...)
truncated_normal(rng::AbstractRNG; init_kwargs...) = (dims::Integer...; kwargs...) -> truncated_normal(rng, dims...; init_kwargs..., kwargs...)

"""
identity_init([rng=GLOBAL_RNG], dims...; gain=1, shift=0)
Expand Down Expand Up @@ -422,7 +424,7 @@ function identity_init(dims...; gain=1, shift=0)
end

identity_init(::AbstractRNG, dims...; kwargs...) = identity_init(dims...; kwargs...)
identity_init(; init_kwargs...) = identity_init(Random.GLOBAL_RNG; init_kwargs...)
identity_init(; init_kwargs...) = identity_init(rng_from_array(); init_kwargs...)
identity_init(rng::AbstractRNG; init_kwargs...) = (args...;kwargs...) -> identity_init(rng, args...; init_kwargs..., kwargs...)

ones32(dims...) = Base.ones(Float32, dims...)
Expand Down
14 changes: 10 additions & 4 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ using Flux
using Flux: throttle, nfan, glorot_uniform, glorot_normal,
kaiming_normal, kaiming_uniform, orthogonal, truncated_normal,
sparse_init, stack, unstack, Zeros, batch, unbatch,
unsqueeze
unsqueeze, params
using StatsBase: var, std
using Statistics, LinearAlgebra
using Random
using Test

Expand Down Expand Up @@ -149,10 +150,15 @@ end
@testset "truncated_normal" begin
size = (100, 100, 100)
for (μ, σ, lo, hi) in [(0., 1, -2, 2), (0, 1, -4., 4)]
v = truncated_normal(size; mean = μ, std = σ, lo, hi)
@test isapprox(mean(v), μ; atol = 1f-2)
@test isapprox(minimum(v), lo; atol = 1f-2)
@test isapprox(maximum(v), hi; atol = 1f-2)
@test eltype(v) == Float32
end
for (μ, σ, lo, hi) in [(6, 2, -100., 100), (7., 10, -100, 100)]
v = truncated_normal(size...; mean = μ, std = σ, lo, hi)
@test isapprox(mean(v), μ; atol = 1e-2)
@test isapprox(minimum(v), lo; atol = 1e-2)
@test isapprox(maximum(v), hi; atol = 1e-2)
@test isapprox(std(v), σ; atol = 1f-2)
@test eltype(v) == Float32
end
end
Expand Down

0 comments on commit 65030e7

Please sign in to comment.