Skip to content

Conversation

@darsnack
Copy link
Member

@darsnack darsnack commented Jan 25, 2022

Adds support for custom RNGs to Dropout and dropout. For dealing with the GPU, the RNG field is mapped to the corresponding default RNG for CUDA.jl when it is Random.default_rng(). All other RNGs will throw an error (as these are not supported by CUDA.jl). Still needs more tests.

Fixes #1617.

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable
  • API changes require approval from a committer (different from the author, if applicable)

@darsnack darsnack mentioned this pull request Jan 25, 2022
Copy link
Member

@ToucheSir ToucheSir left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See if these help with the tests. One outstanding question is how task local RNGs factor into this, thoughts from the floor welcome.

@darsnack
Copy link
Member Author

I was under the impression that default_rng was task local.

@ToucheSir
Copy link
Member

Do you mean CUDA.default_rng? I believe it's used for both the global RNG and task local ones.

@darsnack
Copy link
Member Author

Random.default_rng() will expand to default_rng(Threads.threadid()). Not sure about CUDA RNGs. Since using a CUDA RNG involves calling gpu and that must come after CUDA.device!(...), it's possible it works depending on CUDA.default_rng.

Probably the types for dispatch like Random._GLOBAL_RNG and CUDA.RNG are not the correct things to use.

@ToucheSir
Copy link
Member

ToucheSir commented Jan 26, 2022

The big headache is that neither of Base's default RNGs (global and task local) use an algorithm available in CUDA and vice versa. Random123.jl does have a Philox implementation, but expecting users to depend on it just to make their models GPU compatible is a stretch. The implication is that converting RNGs between devices completely destroys the internal state, whereas with parameters it keeps as much semantically meaningful internal state as possible.

@ablaom if I understand FluxML/MLJFlux.jl#166 correctly, you only need the ability to specify a consistent RNG for the whole model and not individual layers, right? One way to get around the challenges of storing the RNG in the layer would be to use Cassette-style contextual dispatch to hook default_rng. Then whenever Dropout calls default_rng(), it'll receive the user-specified RNG instead of the global or task local one. A refinement of this idea would be to define a higher-level function Flux/NNlib.rng_from_array(::AbstractArray) which Dropout calls. This can be overloaded for CuArrays and would also be safer to use contextual dispatch on. There are two questions with this approach: can Zygote handle it, and what if any new deps would Flux need to do it?

@darsnack
Copy link
Member Author

darsnack commented Jan 26, 2022

Just for clarity, are you suggesting (or some equivalent):

function dropout_mask(x, p; dims=:)
  y = rand!(rng_from_array(x), similar(x, _dropout_shape(x, dims)))
  y .= _dropout_kernel.(y, p, 1 - p)
  return y
end

@ablaom
Copy link

ablaom commented Jan 26, 2022

you only need the ability to specify a consistent RNG for the whole model and not individual layers, right?

Yes.

@darsnack
Copy link
Member Author

darsnack commented Jan 26, 2022

Currently, CUDA.jl only exposes a single RNG. I think it is reasonable that Random.default_rng and CUDA.default_rng can be swapped back and forth, since these are the default cases AKA the cases where users aren't specifying an RNG at all. Currently, this PR conservatively throws an error for any other type of RNG, since they only exist on the CPU side of things. We could also warn when mapping Random.default_rng to the GPU and vice versa.

Let's say the low level contextual approach is like

Flux.rng_from_array(x) = MersenneTwister(123)
m = Dropout(0.1)
m(rand(Float32, 10)) # works, uses MersenneTwister
m(CUDA.rand(10)) # errors due to Flux.rng_from_array(::CuArray) inside Flux.jl

Then you can have block like

m = Dropout(0.1)
withrng(MersenneTwister(123)) do
  m(rand(Float32, 10))
end

But I think this becomes trickier for packages that want the RNG on construction instead of the forward pass. Of course, if you have a wrapper around the model, then the wrapper forward pass could use withrng, but I don't know if that will always be the case.

@ToucheSir
Copy link
Member

Kyle beat me to it, but thoughts below.

Just for clarity, are you suggesting (or some equivalent):
...

I was thinking a couple layers higher. The addition of dropout(rng, x) in this PR make a lot of sense, so the change would be from this:

dropout(x, p; kwargs...) = dropout(Random.default_rng(), x, p; kwargs...)
dropout(x::CuArray, p; kwargs...) = dropout(CUDA.CURAND.default_rng(), x, p; kwargs...)

to this:

rng_from_array(_) = Random.default_rng()
rng_from_array(::CuArray) = CUDA.default_rng()

dropout(x, p; kwargs...) = dropout(rng_from_array(x), x, p; kwargs...)

Likewise for (d::Dropout)(x). Then the override would look like:

using Cassette

Cassette.@context RNGOverrideCtx
Cassette.overdub(ctx::RNGOverrideCtx, ::typeof(Flux.rng_from_array), _) = rng.metadata

with_rng_override(f, rng) = Cassette.overdub(RNGOverrideCtx(metadata=rng), f)

...
model, x, rng = ...
with_rng_override(rng) do
  model(x)
end

@ToucheSir
Copy link
Member

I think it is reasonable that Random.default_rng and CUDA.default_rng can be swapped back and forth, since these are the default cases AKA the cases where users aren't specifying an RNG at all.

To clarify, are you referring to any instance of the types those functions return or to the returned instances themselves? The latter would be ideal, but it's not clear to me whether the return value of default_rng is consistent (that is, ===) over time. I also don't understand enough about the task system to know whether moving a model constructed with the global RNG to a task would affect performance down the line, as IIRC the global RNG requires locking to be thread-safe.

@darsnack
Copy link
Member Author

That implementation makes sense to me, and it would be safer. One thing we could do to the current implementation is

dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
dropout_mask(rng, x::CuArray, p; kwargs...) = ArgumentError("dropout_mask only supports CUDA.RNG when x isa CuArray.")
dropout_mask(rng, x, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
function _dropout_mask(rng, x, p; dims=:)
  y = rand!(rng, similar(x, _dropout_shape(x, dims)))
  y .= _dropout_kernel.(y, p, 1 - p)
  return y
end

I think that will raise both approaches to the same level of safety without considering task-related issues.

To clarify, are you referring to any instance of the types those functions return or to the returned instances themselves?

The latter, I think though I am similarly unsure about the task system and the interaction with the RNGs.

@darsnack
Copy link
Member Author

darsnack commented Jan 26, 2022

Okay, now only any TaskLocalRNG and CUDA.RNG are considered "swappable" though the state might be destroyed.

Trying to move Random.GLOBAL_RNG will work for Julia < 1.7 right now, and we could add a warning too. Presumably, you wouldn't be using this with tasks, so you'd be happy to accept CUDA.RNG.

@darsnack
Copy link
Member Author

Seems like AlphaDropout is extremely sensitive to the RNG. Do we still want to include it in this PR?

@test 40 < sum(evalwgrad(m, x)) < 130
else
# FIXME: this breaks spuriously for MersenneTwister
@test_skip 40 < sum(evalwgrad(m, x)) < 130
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems reasonable to me, opened #1851 for a proper test. Anything left assuming tests pass?

@darsnack
Copy link
Member Author

Screen Shot 2022-01-27 at 11 15 06 AM

😱

@darsnack
Copy link
Member Author

It looks like we'll need a rng_from_array anyways to avoid making the codebase littered with if VERSION ...

@ToucheSir
Copy link
Member

😱indeed. I thought the current code would be generic enough to handle 1.6, but it's a good enough opportunity as any to introduce the cleaner API.

@darsnack darsnack force-pushed the darsnack/dropout-rng branch from 556b8bc to 2385358 Compare January 27, 2022 18:05
@darsnack darsnack changed the title Add RNG support for Dropout Add RNG support for Dropout/AlphaDropout Jan 27, 2022
@darsnack darsnack force-pushed the darsnack/dropout-rng branch from 8ffba95 to 15ac6cd Compare January 27, 2022 21:40
@darsnack darsnack requested a review from ToucheSir January 27, 2022 22:04
@codecov-commenter
Copy link

Codecov Report

Merging #1849 (15ac6cd) into master (ef04fda) will decrease coverage by 0.03%.
The diff coverage is 80.00%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #1849      +/-   ##
==========================================
- Coverage   73.88%   73.85%   -0.04%     
==========================================
  Files          28       28              
  Lines        1662     1683      +21     
==========================================
+ Hits         1228     1243      +15     
- Misses        434      440       +6     
Impacted Files Coverage Δ
src/utils.jl 87.34% <50.00%> (-0.93%) ⬇️
src/functor.jl 86.66% <77.77%> (-1.80%) ⬇️
src/layers/normalise.jl 82.38% <86.36%> (-0.28%) ⬇️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update ef04fda...15ac6cd. Read the comment docs.

Copy link
Member

@ToucheSir ToucheSir left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@darsnack darsnack merged commit 8d3b8d3 into FluxML:master Jan 27, 2022
@darsnack darsnack deleted the darsnack/dropout-rng branch January 27, 2022 22:10
@mcabbott mcabbott mentioned this pull request Jan 8, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Allow specification of RNG in Dropout

5 participants