Truncated normal initialisation for weights #1877

Merged: 12 commits, merged on Feb 19, 2022

Conversation

@theabhirath (Member) commented on Feb 17, 2022:

This PR adds a truncated normal method for initialising layer weights. This initialisation has been used in recent papers, including ConvNeXt (work is currently ongoing to implement a Julia version of it at FluxML/Metalhead.jl#119), and is also incorporated in PyTorch master. Unfortunately it brings an additional dependency, SpecialFunctions.jl, but thankfully that package seems to be quite lightweight.
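For reference, here is a minimal sketch of the inverse-CDF sampling approach discussed in the review below (hypothetical code, not the exact implementation merged here; the function name and keyword names are illustrative):

using SpecialFunctions: erf, erfinv

norm_cdf(x) = (1 + erf(x / √2)) / 2   # standard normal CDF

function truncated_normal_sketch(dims...; mean = 0, std = 1, lo = -2, hi = 2)
    # Map the truncation bounds through the normal CDF ...
    l, u = norm_cdf((lo - mean) / std), norm_cdf((hi - mean) / std)
    # ... sample uniformly between them ...
    xs = l .+ rand(dims...) .* (u - l)
    # ... and pull the samples back through the inverse CDF.
    ys = erfinv.(2 .* xs .- 1) .* (std * √2) .+ mean
    # The clamp guards against floating-point overshoot at the endpoints.
    return Float32.(clamp.(ys, lo, hi))
end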

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable

@theabhirath changed the title from "Trunc normal" to "Truncated normal initialisation for weights" on Feb 17, 2022
@theabhirath force-pushed the trunc-normal branch 2 times, most recently from 4534056 to 65164a8, on February 17, 2022 14:12
@darsnack (Member) left a comment:

This doesn't seem to be doing the resampling strategy like the docstring suggests. Instead, it is sampling from a uniform distribution and applying the inverse CDF for the truncated normal.

In theory, the clamp should only be necessary for floating point errors. Can we remove the clamp and see how often it fails (if at all)?
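For contrast, a hypothetical sketch of the resampling strategy the docstring describes (redraw any value that falls outside the bounds; the function name is made up):

function truncated_normal_rejection(dims...; mean = 0f0, std = 1f0, lo = -2f0, hi = 2f0)
    xs = randn(Float32, dims...) .* std .+ mean
    for i in eachindex(xs)
        while !(lo <= xs[i] <= hi)
            xs[i] = randn(Float32) * std + mean   # resample out-of-range draws
        end
    end
    return xs
end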

src/utils.jl (outdated), review comment on lines 315 to 319:
julia> Flux.truncated_normal(3, 2)
3×2 Matrix{Float32}:
-0.113785 -0.627307
-0.676033 0.198423
0.509005 -0.554339
Member:
How much does anyone learn from this? Maybe something like:

julia> truncated_normal(3, 4) |> summary
"3×4 Matrix{Float32}"

julia> round.(extrema(truncated_normal(10^6)); digits=3)
(-1.0f0, 1.0f0)

julia> using Plots

julia> histogram(truncated_normal(10^6; σ=2, a=-4, b=4); bins=-4:0.1:4, alpha=0.5)

julia> histogram!(randn(10^6); bins=-4:0.1:4, alpha=0.5)

Or even nicer, something inline with UnicodePlots.

Member:
UnicodePlots has a histogram function. Using that would be nice here.

@theabhirath (Member Author):
I'll whip something up. Is this something that should potentially extend to all the init layers?

@theabhirath (Member Author):
julia> histogram(reduce(vcat, truncated_normal(100, 100)))
                ┌                                        ┐ 
   [-2.0, -1.5) ┤███████▍ 448                              
   [-1.5, -1.0) ┤████████████████▎ 982                     
   [-1.0, -0.5) ┤█████████████████████████▍ 1538           
   [-0.5,  0.0) ┤████████████████████████████████▌ 1974    
   [ 0.0,  0.5) ┤██████████████████████████████████  2066  
   [ 0.5,  1.0) ┤█████████████████████████▋ 1556           
   [ 1.0,  1.5) ┤███████████████▋ 952                      
   [ 1.5,  2.0) ┤███████▉ 484                              
                └                                        ┘ 
                                 Frequency            

This looks pretty nice to me

Member:
reduce(vcat, ...) is just vec, but perhaps just make a vector to start with?

And maybe show something like truncated_normal(10^4, lo=0, hi=4) where the truncation is really obvious?

In NNlib, the graphs are currently not jldoctests, since Documenter got confused as to whether they were new code or not. Not testing this here is also one way to avoid random-number problems.

I'd include the line julia> using UnicodePlots since I think docstrings should load any packages besides Flux needed to reproduce them.
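Putting those suggestions together, the docstring example might then begin something like this (hypothetical; it assumes the lo/hi keyword names, and the plot output is omitted precisely because it would not be a jldoctest):

julia> using UnicodePlots

julia> histogram(Flux.truncated_normal(10^4; lo = 0, hi = 4))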

@theabhirath (Member Author):
Works. Given that this is something that needs a little more hashing out (as noted here), I think I'll open a separate PR addressing these concerns for all the init strategies.

@theabhirath (Member Author):

> This doesn't seem to be doing the resampling strategy like the docstring suggests. Instead, it is sampling from a uniform distribution and applying the inverse CDF for the truncated normal.

Yep, I started working on one approach and ended up with the other because I found it to be more common 🤦🏽‍♂️. I'll fix the docstring.

> In theory, the clamp should only be necessary for floating point errors. Can we remove the clamp and see how often it fails (if at all)?

I was wondering about this too. Mathematically it's well within the bounds unless some absolute sorcery is unleashed, but the PyTorch master version had it, and I was worried I would miss out on some edge cases.

@mcabbott (Member):

It seems OK to talk about resampling as an explanation of what the truncated normal distribution is. I guess it shouldn't imply that this is how the implementation works, but what algorithm randn uses is also not something we need to know to call it.

@theabhirath (Member Author):

Okay, the API and the docstring should be fine now. I'll come up with more specific test cases and then push this through for another round of reviews 😅.

@darsnack (Member) commented on Feb 17, 2022:

There are some review comments from my previous review related to the Float32 issue that don't seem resolved? I didn't see a decision to use Float64, so I have unmarked those comments as resolved. Do you mind taking a look again?

All the constants and rand should be Float32 when they are created instead of converting at the end (like discussed above). This does keep things open to someone setting the keyword arguments to Float64 and getting a Float64 result. But this is a standard application of Julia's type promotion system, so I personally don't think we should forbid it. I see that other initialization functions are inconsistent about this behavior. Perhaps we should settle in this PR what kind of behavior we want.

Even if the decision is to always return a Float32, it is cheaper to convert the scalar output of norm_cdf to Float32 than to convert the array at the end. Or follow Michael's suggestion above about writing back to the same array, which is cleaner.
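A rough sketch of the write-back idea, with l, u, mean, std, lo and hi as in the inverse-CDF sketch near the top of the thread (hypothetical code, not the merged version):

xs = rand(Float32, dims...)    # the only array allocation, already Float32
broadcast!(xs, xs) do x
    # Do the scalar maths in Float64 for accuracy ...
    y = erfinv(2 * (l + Float64(x) * (u - l)) - 1) * (std * √2) + mean
    # ... then convert each scalar and store it back into the Float32 array.
    Float32(clamp(y, lo, hi))
end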

@darsnack (Member):

> I was wondering about this too. Mathematically it's well within the bounds unless some absolute sorcery is unleashed, but the PyTorch master version had it, and I was worried I would miss out on some edge cases.

My worry is that we silently sample the wrong distribution since the support of the samples will appear correct. So if we keep it in, I'd like to know that it's only for numerical reasons first (by seeing the resulting histogram without it).

@theabhirath (Member Author):

> There are some review comments from my previous review related to the Float32 issue that don't seem resolved?

Oh, sorry about that. As I noted here, I did switch to Float32 initially, but the error was quite noticeable, so I decided to switch back to calculating the result as a Float64 and converting it at the end. I'll write back to the same array and see if that works out.

@theabhirath (Member Author) commented on Feb 18, 2022:

> My worry is that we silently sample the wrong distribution since the support of the samples will appear correct. So if we keep it in, I'd like to know that it's only for numerical reasons first (by seeing the resulting histogram without it).

There's a NaN obtained for some values without the clamp, so I think it's probably best left in? (For reference, the values are [μ, σ, lo, hi] = [7.0, 10, -100, 100].)
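One plausible mechanism, assuming the inverse-CDF implementation: with bounds that wide, the normal CDF saturates to exactly 1.0 in floating point, erfinv diverges at that endpoint, and the resulting Inf can turn into NaN in later arithmetic:

julia> using SpecialFunctions: erfinv

julia> erfinv(1.0)        # inverse CDF at a saturated endpoint
Inf

julia> 0.0 * erfinv(1.0)  # Inf propagating through arithmetic yields NaN
NaN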

@theabhirath (Member Author):

Writing back to the array seems to be working fine, and the error is under control too for now, so I think it should be okay. I've also changed the RNGs for the init layers to reflect this.

theabhirath and others added 4 commits February 18, 2022 08:16
Fix tests gaffe
Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
@theabhirath (Member Author):

> the error is under control too for now

Welp. CI shows there's an error because 3.98 is calculated instead of 4 (the tolerance for now is 0.01). Should I increase the tolerance? Also, the doctests keep failing and I'm not sure why, but the values they produce seem to differ from mine. I'm on Julia 1.8 on a Mac, if that matters.
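The failing assertion presumably resembles something like this (a hypothetical test, not necessarily the one in test/utils.jl):

using Test

xs = Flux.truncated_normal(10^6; lo = -4, hi = 4)
@test maximum(xs) ≈ 4 atol = 0.01   # flaky: the sample maximum need not reach the bound
@test all(-4 .<= xs .<= 4)          # a more robust check of the support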

Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
@mcabbott (Member):

I don't think random numbers are stable between Julia versions. They may only be tested on one version, but better, most of the doctests are written never to print them, and to show the size or a summary etc. instead.

@theabhirath (Member Author):

> I don't think random numbers are stable between Julia versions. They may only be tested on one version, but better, most of the doctests are written never to print them, and to show the size or a summary etc. instead.

Should I fix this for the other init layers as well? And change them to have plots too?

@darsnack (Member) left a comment:
A couple of changes, but it otherwise looks good to me.

@darsnack (Member):

> Should I fix this for the other init layers as well? And change them to have plots too?

No, I think we can do that in a separate PR so that we can merge this one. Even for plots, we will want to use StableRNGs in our examples to make sure very minor changes don't alter the histograms. For this PR, I think a summary is fine, plus maybe some kind of filter example that shows all the values are in the given range.
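Such a filter example might look like this (hypothetical doctest, using the default bounds of -2 and 2):

julia> xs = Flux.truncated_normal(100, 100);

julia> all(-2 .<= xs .<= 2)
true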

theabhirath and others added 2 commits February 19, 2022 08:33
Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
@theabhirath force-pushed the trunc-normal branch 3 times, most recently from 9032c8e to 21ec5f4, on February 19, 2022 04:00
Multiple modifications to docstrings
Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>
@codecov-commenter commented on Feb 19, 2022:

Codecov Report

Merging #1877 (423a996) into master (13a65be) will decrease coverage by 0.00%.
The diff coverage is 90.47%.


@@            Coverage Diff             @@
##           master    #1877      +/-   ##
==========================================
- Coverage   86.03%   86.03%   -0.01%     
==========================================
  Files          19       19              
  Lines        1411     1425      +14     
==========================================
+ Hits         1214     1226      +12     
- Misses        197      199       +2     
Impacted Files Coverage Δ
src/utils.jl 93.46% <90.47%> (-0.79%) ⬇️


Co-authored-by: Kyle Daruwalla <daruwalla.k.public@icloud.com>