diff --git a/src/utils.jl b/src/utils.jl index 35eec9705b..1e8f1e2270 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -195,7 +195,7 @@ kaiming_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaimi truncated_normal([rng=GLOBAL_RNG], dims...; mean = 0, std = 1, lo = -2, hi = 2) Return an `Array{Float32}` of size `dims` where each element is drawn from a truncated normal distribution. -The numbers are distributed like `filter(x -> lo<=x<=hi, mean .+ std .* rand(10^6))`. +The numbers are distributed like `filter(x -> lo<=x<=hi, mean .+ std .* rand(dims...))`. The values are generated by sampling a Uniform(0, 1) (`rand()`) and then applying the inverse CDF of the truncated normal distribution @@ -203,12 +203,15 @@ applying the inverse CDF of the truncated normal distribution This method works best when `lo ≤ mean ≤ hi`. # Examples -```jldoctest; setup = :(using Random; Random.seed!(0)) -julia> Flux.truncated_normal(3, 2) -3×2 Matrix{Float32}: - -0.0340547 -1.35207 - -0.22757 -0.793773 - -1.75771 1.01801 +```jldoctest; setup = :(using Statistics) +julia> Flux.truncated_normal(3, 4) |> summary +"3×4 Matrix{Float32}" + +julia> round.(extrema(Flux.truncated_normal(10^6)); digits=3) +(-2.0f0, 2.0f0) + +julia> round(std(Flux.truncated_normal(10^6; lo = -100, hi = 100))) +1.0f0 ``` # References diff --git a/test/utils.jl b/test/utils.jl index 4481d7f607..3f4efe24ba 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -151,14 +151,14 @@ end size = (100, 100, 100) for (μ, σ, lo, hi) in [(0., 1, -2, 2), (0, 1, -4., 4)] v = truncated_normal(size; mean = μ, std = σ, lo, hi) - @test isapprox(mean(v), μ; atol = 1f-2) - @test isapprox(minimum(v), lo; atol = 1f-2) - @test isapprox(maximum(v), hi; atol = 1f-2) + @test isapprox(mean(v), μ; atol = 1f-1) + @test isapprox(minimum(v), lo; atol = 1f-1) + @test isapprox(maximum(v), hi; atol = 1f-1) @test eltype(v) == Float32 end for (μ, σ, lo, hi) in [(6, 2, -100., 100), (7., 10, -100, 100)] v = truncated_normal(size...; mean = μ, std = σ, lo, hi) - @test isapprox(std(v), σ; atol = 1f-2) + @test isapprox(std(v), σ; atol = 1f-1) @test eltype(v) == Float32 end end