Truncated normal initialisation for weights #1877
Conversation
Force-pushed from 4534056 to 65164a8
This doesn't seem to be doing the resampling strategy like the docstring suggests. Instead, it is sampling from a uniform distribution and applying the inverse CDF for the truncated normal.
In theory, the clamp should only be necessary for floating-point errors. Can we remove the clamp and see how often it fails (if at all)?
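For concreteness, here is a minimal sketch of the approach described above: a uniform draw pushed through the inverse CDF of the truncated normal, followed by a clamp. The names (`truncated_normal_sketch`, `norm_cdf`) and defaults are illustrative assumptions, not the PR's actual code.

```julia
using SpecialFunctions  # provides erf / erfinv (the extra dependency mentioned in the PR)

norm_cdf(x) = (1 + erf(x / √2)) / 2   # standard normal CDF

function truncated_normal_sketch(dims...; mean=0f0, std=1f0, lo=-2f0, hi=2f0)
    l, u = norm_cdf((lo - mean) / std), norm_cdf((hi - mean) / std)
    xs = rand(Float64, dims...)                                    # uniform draws on [0, 1)
    ys = @. mean + std * √2 * erfinv(2 * (l + (u - l) * xs) - 1)   # inverse CDF of the truncated normal
    return Float32.(clamp.(ys, lo, hi))                            # clamp should only catch floating-point error
end
```

With this formulation the samples already lie in [lo, hi] up to floating-point error, which is why the review asks whether the clamp can be dropped.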
src/utils.jl
Outdated
julia> Flux.truncated_normal(3, 2)
3×2 Matrix{Float32}:
 -0.113785  -0.627307
 -0.676033   0.198423
  0.509005  -0.554339
How much does anyone learn from this? Maybe something like:
julia> truncated_normal(3, 4) |> summary
"3×4 Matrix{Float32}"
julia> round.(extrema(truncated_normal(10^6)); digits=3)
(-1.0f0, 1.0f0)
julia> using Plots
julia> histogram(truncated_normal(10^6; σ=2, a=-4, b=4); bins=-4:0.1:4, alpha=0.5)
julia> histogram!(randn(10^6); bins=-4:0.1:4, alpha=0.5)
Or even nicer, something inline with UnicodePlots.
UnicodePlots has a histogram function. Using that would be nice here.
I'll whip something up. Is this something that should potentially extend to all the init layers?
julia> histogram(reduce(vcat, truncated_normal(100, 100)))
┌ ┐
[-2.0, -1.5) ┤███████▍ 448
[-1.5, -1.0) ┤████████████████▎ 982
[-1.0, -0.5) ┤█████████████████████████▍ 1538
[-0.5, 0.0) ┤████████████████████████████████▌ 1974
[ 0.0, 0.5) ┤██████████████████████████████████ 2066
[ 0.5, 1.0) ┤█████████████████████████▋ 1556
[ 1.0, 1.5) ┤███████████████▋ 952
[ 1.5, 2.0) ┤███████▉ 484
└ ┘
Frequency
This looks pretty nice to me
reduce(vcat, ...) is just vec, but perhaps just make a vector to start with? And maybe show something like truncated_normal(10^4, lo=0, hi=4), where the truncation is really obvious?
In NNlib, the graphs are currently not jldoctests, since Documenter got confused as to whether they were new code or not. Not testing this here is also one way to avoid random-number problems.
I'd include the line julia> using UnicodePlots, since I think docstrings should load any packages besides Flux needed to reproduce them.
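Putting those suggestions together, the docstring example might look roughly like this (the lo/hi keyword names follow the review discussion and may not match the final API):

```julia
julia> using UnicodePlots

julia> histogram(Flux.truncated_normal(10^4; lo = 0, hi = 4))  # one-sided truncation, so the cut-off is obvious
```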
Works. Given that this is something that needs a little more hashing out (as noted here), I think I'll try and open up a separate PR addressing these concerns for all the init strategies
Yep, I started working on one approach and ended up with another because I found it to be more common 🤦🏽‍♂️ I'll fix the docstring.
I was wondering about this too. Mathematically it's well within the bounds unless some absolute sorcery is unleashed, but the PyTorch master version had it and I was worried I would miss out on some edge cases.
It seems OK to talk about resampling as an explanation of what the truncated normal distribution is. I guess it shouldn't imply that this is how the implementation works, but what algorithm …
Okay, the API and the docstring should be fine now. I'll come up with more specific cases for tests and then push this through for another round of reviews 😅
There are some review comments from my previous review related to the … All the constants and …
My worry is that we silently sample the wrong distribution since the support of the samples will appear correct. So if we keep it in, I'd like to know that it's only for numerical reasons first (by seeing the resulting histogram without it).
Oh, sorry about that. As I noted here, I did switch to Float32 initially, but the error was quite noticeable, so I decided to switch back to calculating the result as a Float64 and converting it at the end. I'll write back to the same array and see if that works out.
There's a NaN obtained for some values without the clamp, so I think it's probably best left in? (For reference, the values are [μ, σ, lo, hi] = [7.0, 10, -100, 100].)
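For what it's worth, a tiny hedged illustration of why the boundary misbehaves for those values: the normal CDF saturates to exactly 1.0 in Float64, and erfinv at the boundary returns Inf (the reported NaN presumably comes from a similar saturation elsewhere in the computation).

```julia
using SpecialFunctions

norm_cdf(x) = (1 + erf(x / √2)) / 2

μ, σ, lo, hi = 7.0, 10.0, -100.0, 100.0
u = norm_cdf((hi - μ) / σ)   # rounds to exactly 1.0 in Float64
erfinv(2u - 1)               # Inf — without a clamp this propagates into the samples
```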
Writing back to the array seems to be working fine, and the error is under control for now, so I think it should be okay. I've also changed the RNGs for the init layers to reflect this.
Fix tests gaffe
Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
Force-pushed from b42d018 to 65030e7
Force-pushed from 65030e7 to 46feb85
Welp. CI shows there's an error because 3.98 is calculated instead of 4 (the tolerance for now is 0.01). Should I increase the tolerance? Also, the doctests keep failing and I'm not sure why, but the values it gets seem to be different from mine. I'm on Julia 1.8 on a Mac, if that matters.
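For reference, the failing check sounds like an extremum test of roughly this shape (keyword names are assumed from the [μ, σ, lo, hi] naming above; the PR's actual test may differ), where loosening atol would make it pass:

```julia
using Test, Flux

xs = Flux.truncated_normal(10^6; σ = 2, lo = -4, hi = 4)
@test maximum(xs) ≈ 4 atol = 0.05   # an atol of 0.01 is tight enough to fail when 3.98 comes out
@test minimum(xs) ≈ -4 atol = 0.05
```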
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
I don't think random numbers are stable between Julia versions. They may only be tested on one version, but better, most of the doctests are written never to print them, always showing the size of the result instead.
Should I fix this for the other init layers as well? And change them to have plots too?
A couple of changes, but it otherwise looks good to me.
No, I think we can do that in a separate PR so that we can merge this one. Even for plots, we will want to use StableRNGs in our examples to make sure very minor changes don't happen to the histograms. For this PR, I think a summary is fine, and maybe some kind of …
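A hedged sketch of the StableRNGs suggestion for that follow-up PR, assuming the init functions accept an RNG as their first argument like Flux's other initialisers:

```julia
using Flux, StableRNGs

rng = StableRNG(17)                              # fixed, version-stable RNG for doctests
Flux.truncated_normal(rng, 3, 4) |> summary      # "3×4 Matrix{Float32}"
```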
Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
Force-pushed from 9032c8e to 21ec5f4
Multiple modifications to docstrings
Force-pushed from 21ec5f4 to 8403e31
Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
Codecov Report
@@            Coverage Diff             @@
##           master   #1877      +/-   ##
==========================================
- Coverage   86.03%   86.03%   -0.01%
==========================================
  Files          19       19
  Lines        1411     1425      +14
==========================================
+ Hits         1214     1226      +12
- Misses        197      199       +2
Continue to review full report at Codecov.
Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
This PR adds a truncated normal initialisation method for the weights of layers. It has been used in recent papers including ConvNeXt (work is currently ongoing to implement a Julia version at FluxML/Metalhead.jl#119) and is also incorporated in PyTorch master. Unfortunately this brings an additional dependency in the form of SpecialFunctions.jl, but thankfully it seems to be quite lightweight.
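As a usage sketch (not taken from the PR itself), the new initialiser would plug into layers through the usual init keyword:

```julia
using Flux

layer = Dense(3, 4; init = Flux.truncated_normal)   # weights drawn from a truncated normal
```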
PR Checklist