Increased tolerance, modified doctests
theabhirath committed Feb 19, 2022
1 parent f030521 commit 9032c8e
Showing 2 changed files with 14 additions and 11 deletions.
17 changes: 10 additions & 7 deletions src/utils.jl
@@ -195,20 +195,23 @@ kaiming_normal(rng::AbstractRNG; init_kwargs...) = (dims...; kwargs...) -> kaimi
truncated_normal([rng=GLOBAL_RNG], dims...; mean = 0, std = 1, lo = -2, hi = 2)
Return an `Array{Float32}` of size `dims` where each element is drawn from a truncated normal distribution.
-The numbers are distributed like `filter(x -> lo<=x<=hi, mean .+ std .* rand(10^6))`.
+The numbers are distributed like `filter(x -> lo<=x<=hi, mean .+ std .* rand(dims...))`.
The values are generated by sampling a Uniform(0, 1) (`rand()`) and then
applying the inverse CDF of the truncated normal distribution
(see the references for more info).
This method works best when `lo ≤ mean ≤ hi`.
# Examples
```jldoctest; setup = :(using Random; Random.seed!(0))
-julia> Flux.truncated_normal(3, 2)
-3×2 Matrix{Float32}:
- -0.0340547 -1.35207
- -0.22757 -0.793773
- -1.75771 1.01801
+```jldoctest
+julia> Flux.truncated_normal(3, 4) |> summary
+"3×4 Matrix{Float32}"
+
+julia> round.(extrema(Flux.truncated_normal(10^6)); digits=3)
+(-2.0f0, 2.0f0)
+
+julia> round(std(Flux.truncated_normal(10^6; lo = -100, hi = 100)))
+1.0f0
```
# References
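The docstring in this hunk describes inverse-CDF sampling: draw Uniform(0, 1) values, map them into the slice of the normal CDF that lies between `lo` and `hi`, then invert the CDF. A minimal sketch of that idea follows. It is not copied from Flux's source; `sketch_truncated_normal` is a hypothetical name, and the code assumes the SpecialFunctions.jl package is installed to provide `erf` and `erfinv`.

```julia
using Random
using SpecialFunctions: erf, erfinv  # assumption: SpecialFunctions.jl is available

# CDF of the standard normal distribution.
norm_cdf(x) = (1 + erf(x / sqrt(2))) / 2

# Hypothetical helper (not Flux's API): truncated normal via inverse-CDF sampling.
function sketch_truncated_normal(rng::AbstractRNG, dims::Integer...;
                                 mean = 0, std = 1, lo = -2, hi = 2)
    l = norm_cdf((lo - mean) / std)   # probability mass below `lo`
    u = norm_cdf((hi - mean) / std)   # probability mass below `hi`
    xs = rand(rng, Float32, dims...)  # Uniform(0, 1) draws
    map!(xs, xs) do p
        q = l + p * (u - l)            # rescale the draw into [l, u]
        z = sqrt(2) * erfinv(2q - 1)   # inverse CDF of the standard normal
        clamp(z * std + mean, lo, hi)  # shift, scale, and guard against round-off
    end
    return xs
end

v = sketch_truncated_normal(MersenneTwister(0), 10^5)
```

Because every uniform draw is mapped into `[l, u]` before inversion, no samples are rejected, which is why the approach behaves best when most of the normal's mass sits between `lo` and `hi`.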
8 changes: 4 additions & 4 deletions test/utils.jl
@@ -151,14 +151,14 @@ end
size = (100, 100, 100)
for (μ, σ, lo, hi) in [(0., 1, -2, 2), (0, 1, -4., 4)]
v = truncated_normal(size; mean = μ, std = σ, lo, hi)
-@test isapprox(mean(v), μ; atol = 1f-2)
-@test isapprox(minimum(v), lo; atol = 1f-2)
-@test isapprox(maximum(v), hi; atol = 1f-2)
+@test isapprox(mean(v), μ; atol = 1f-1)
+@test isapprox(minimum(v), lo; atol = 1f-1)
+@test isapprox(maximum(v), hi; atol = 1f-1)
@test eltype(v) == Float32
end
for (μ, σ, lo, hi) in [(6, 2, -100., 100), (7., 10, -100, 100)]
v = truncated_normal(size...; mean = μ, std = σ, lo, hi)
-@test isapprox(std(v), σ; atol = 1f-2)
+@test isapprox(std(v), σ; atol = 1f-1)
@test eltype(v) == Float32
end
end
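The test changes in this file only widen `atol` from `1f-2` to `1f-1`. As a small illustration of what that does: when `atol` is positive and `rtol` is not given, `isapprox` uses a relative tolerance of zero, so the check reduces to `abs(x - y) <= atol`. The value `0.03f0` below is a made-up statistic chosen for the example, not a number taken from a test run.

```julia
observed = 0.03f0  # hypothetical sample statistic, for illustration only
target   = 0.0f0

# With only `atol` set, the comparison is abs(observed - target) <= atol.
strict = isapprox(observed, target; atol = 1f-2)  # 0.03 > 0.01, so this fails
loose  = isapprox(observed, target; atol = 1f-1)  # 0.03 <= 0.1, so this passes
```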