Add RNG support for Dropout/AlphaDropout #1849
Conversation
See if these help with the tests. One outstanding question is how task-local RNGs factor into this; thoughts from the floor are welcome.
I was under the impression that …
Do you mean …
Probably the types for dispatch like …
The big headache is that neither of Base's default RNGs (global and task-local) uses an algorithm available in CUDA, and vice versa. Random123.jl does have a Philox implementation, but expecting users to depend on it just to make their models GPU compatible is a stretch. The implication is that converting RNGs between devices completely destroys the internal state, whereas with parameters we keep as much semantically meaningful internal state as possible. @ablaom if I understand FluxML/MLJFlux.jl#166 correctly, you only need the ability to specify a consistent RNG for the whole model and not individual layers, right? One way to get around the challenges of storing the RNG in the layer would be to use Cassette-style contextual dispatch to hook …
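Not code from this PR, just a sketch of what that conversion rule implies; `rng_to_gpu` and `rng_to_cpu` are hypothetical names:

```julia
using Random, CUDA

# Hypothetical helpers: CPU and GPU generators use different algorithms, so the
# state cannot be translated; the best we can do is hand back the target
# device's default generator and accept that the old state is lost.
rng_to_gpu(::Random.AbstractRNG) = CUDA.default_rng()
rng_to_cpu(::CUDA.RNG) = Random.default_rng()

# Parameters, by contrast, keep their values exactly across the transfer:
W = rand(Float32, 3, 3)
W_gpu = cu(W)  # same values, now stored on the GPU
```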
Just for clarity, are you suggesting this (or some equivalent)?

```julia
function dropout_mask(x, p; dims=:)
    y = rand!(rng_from_array(x), similar(x, _dropout_shape(x, dims)))
    y .= _dropout_kernel.(y, p, 1 - p)
    return y
end
```
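For illustration only (assuming `rng_from_array` falls back to the default CPU RNG for plain arrays), that sketch would be used as:

```julia
x = rand(Float32, 3, 4)
mask = dropout_mask(x, 0.5)   # RNG is picked from typeof(x) via rng_from_array
size(mask) == size(x)         # true, since dims = : keeps the full shape
```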
Yes.
Currently, CUDA.jl only exposes a single RNG. I think it is reasonable that … Let's say the low-level contextual approach looks like:

```julia
Flux.rng_from_array(x) = MersenneTwister(123)

m = Dropout(0.1)
m(rand(Float32, 10)) # works, uses MersenneTwister
m(CUDA.rand(10))     # errors due to Flux.rng_from_array(::CuArray) inside Flux.jl
```

Then you can have a block like:

```julia
m = Dropout(0.1)
withrng(MersenneTwister(123)) do
    m(rand(Float32, 10))
end
```

But I think this becomes trickier for packages that want the RNG on construction instead of the forward pass. Of course, if you have a wrapper around the model, then the wrapper forward pass could use …
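One way such a `withrng` block could be implemented (not part of this PR; the `:flux_rng` key and `current_rng` helper are hypothetical) is via task-local storage:

```julia
using Random

# Run `f()` with `rng` registered in the current task's local storage,
# supporting the do-block syntax shown above.
withrng(f, rng::Random.AbstractRNG) = task_local_storage(f, :flux_rng, rng)

# Layers would then look the RNG up at call time instead of storing it in a field.
current_rng() = get(task_local_storage(), :flux_rng, Random.default_rng())
```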
Kyle beat me to it, but thoughts below.
I was thinking a couple of layers higher. The addition of `rng_from_array` would change this:

```julia
dropout(x, p; kwargs...) = dropout(Random.default_rng(), x, p; kwargs...)
dropout(x::CuArray, p; kwargs...) = dropout(CUDA.CURAND.default_rng(), x, p; kwargs...)
```

to this:

```julia
rng_from_array(_) = Random.default_rng()
rng_from_array(::CuArray) = CUDA.default_rng()

dropout(x, p; kwargs...) = dropout(rng_from_array(x), x, p; kwargs...)
```

Likewise for the contextual override:

```julia
using Cassette

Cassette.@context RNGOverrideCtx
# Return the RNG carried in the context's metadata instead of the default one.
Cassette.overdub(ctx::RNGOverrideCtx, ::typeof(Flux.rng_from_array), _) = ctx.metadata

with_rng_override(f, rng) = Cassette.overdub(RNGOverrideCtx(metadata=rng), f)

# ...
model, x, rng = ...  # placeholder
with_rng_override(rng) do
    model(x)
end
```
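As a quick illustrative check of that dispatch (assuming CUDA.jl is loaded and functional):

```julia
rng_from_array(rand(Float32, 4))        # -> Random.default_rng()
rng_from_array(CUDA.rand(Float32, 4))   # -> CUDA.default_rng()
```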
To clarify, are you referring to any instance of the types those functions return or to the returned instances themselves? The latter would be ideal, but it's not clear to me whether the return value of …
That implementation makes sense to me, and it would be safer. One thing we could do to the current implementation is:

```julia
dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
dropout_mask(rng, x::CuArray, p; kwargs...) =
    throw(ArgumentError("dropout_mask only supports CUDA.RNG when x isa CuArray."))
dropout_mask(rng, x, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)

function _dropout_mask(rng, x, p; dims=:)
    y = rand!(rng, similar(x, _dropout_shape(x, dims)))
    y .= _dropout_kernel.(y, p, 1 - p)
    return y
end
```

I think that will raise both approaches to the same level of safety, without considering task-related issues.
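A sketch of how those guards would behave at the call site, assuming `CUDA.default_rng()` returns a `CUDA.RNG`:

```julia
using Random, CUDA

x_gpu = CUDA.rand(Float32, 10)
x_cpu = rand(Float32, 10)

dropout_mask(CUDA.default_rng(), x_gpu, 0.5)   # GPU RNG + GPU array: dispatches to _dropout_mask
dropout_mask(MersenneTwister(0), x_gpu, 0.5)   # CPU RNG + GPU array: throws the ArgumentError
dropout_mask(MersenneTwister(0), x_cpu, 0.5)   # CPU RNG + CPU array: fine
```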
The latter, I think, though I am similarly unsure about the task system and its interaction with the RNGs.
Okay, now only any … Trying to move …
Seems like …
```julia
    @test 40 < sum(evalwgrad(m, x)) < 130
else
    # FIXME: this breaks spuriously for MersenneTwister
    @test_skip 40 < sum(evalwgrad(m, x)) < 130
```
This seems reasonable to me; I opened #1851 for a proper test. Anything left, assuming tests pass?
It looks like we'll need a …
😱 indeed. I thought the current code would be generic enough to handle 1.6, but this is as good an opportunity as any to introduce the cleaner API.
Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com> Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
Skip AlphaDropout tests for custom RNG.
Codecov Report

```diff
@@            Coverage Diff             @@
##           master    #1849      +/-   ##
==========================================
- Coverage   73.88%   73.85%   -0.04%
==========================================
  Files          28       28
  Lines        1662     1683      +21
==========================================
+ Hits         1228     1243      +15
- Misses        434      440       +6
```

Continue to review the full report at Codecov.
LGTM
Adds support for custom RNGs to `Dropout` and `dropout`. For dealing with the GPU, the RNG field is mapped to the corresponding default RNG for CUDA.jl when it is `Random.default_rng()`. All other RNGs will throw an error (as these are not supported by CUDA.jl). Still needs more tests.

Fixes #1617.
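A rough sketch of that mapping rule (the `_rng_to_gpu` helper is hypothetical, and this assumes Julia ≥ 1.7, where `Random.default_rng()` is a `TaskLocalRNG`):

```julia
using Random, CUDA

_rng_to_gpu(::Random.TaskLocalRNG) = CUDA.default_rng()   # default CPU RNG -> CUDA's default RNG
_rng_to_gpu(rng::CUDA.RNG) = rng                          # already a GPU RNG, keep it
_rng_to_gpu(rng::Random.AbstractRNG) =                    # any other custom RNG is unsupported
    throw(ArgumentError("$(typeof(rng)) is not supported on the GPU"))
```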
PR Checklist