Use view for RNN gate slice extraction #1761
Conversation
@ToucheSir I'm happy to run some benchmarks on this change, but I'm out traveling at the moment, so it will be a few days before I can get to that.
@ToucheSir I ran some tests on mid-size LSTMs (2 layers, 64 neurons, batch size 256). Although modest, the proposal looks to bring some performance benefits on both CPU and GPU with this PR. I'm therefore all in to adopt this change.
Ok, my benchmarks are in too. The code, logs and raw data are all available at https://gist.github.com/ToucheSir/2fecbe99e5b304fed11e25e42c535cc3. Because there are so many dimensions to this, I've also created an interactive figure (partially for my own sanity) here for folks to peruse.
Here is the best-performing configuration (LSTM on GPU), and here is the worst (GRU on CPU). Breaking things down further, almost all of the performance gains come from a much more efficient forward pass. This includes a noticeable reduction in both allocation count and memory allocated. Conversely, any performance loss seems to be caused by a slower backwards pass. However, I'm not sure why this PR would be slower (sometimes, not always!) there. The allocations (count and KB) are identical across both versions, which makes sense given they use the exact same pullback. @mcabbott did you ever encounter this while working on the getindex or accum PRs for Zygote?
Force-pushed from b79fbb9 to 85e3054.
This was originally passed over in #907, but I don't find that argument particularly compelling as the return value is only ever used once. Any negative impact on caching is going to happen anyhow during the slice materialization, so we might as well just let the subsequent fused broadcasts handle said materialization for us while reducing allocations.
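For reference, the change under discussion is roughly the following (a sketch based on Flux's recurrent.jl, not the exact diff):

gate(h, n) = (1:h) .+ h*(n-1)                        # index range of the n-th gate

# before: slicing materializes a copy of the gate
gate(x::AbstractMatrix, h, n) = x[gate(h,n), :]

# after (this PR): return a view and let the subsequent fused broadcast
# materialize the slice while computing the activation
gate(x::AbstractMatrix, h, n) = view(x, gate(h,n), :)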
Force-pushed from 85e3054 to c9627c5.
I don't know why the view version would be slower there. The backward pass is going to allocate a ton, since it'll make a full copy (mostly zero) for every time you take a slice, and another to add each pair of those with accum. Of those two costs, the first is harder to avoid. One way might be to replace the individual calls to gate with something that extracts all the slices at once and has a single combined adjoint.
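As a small standalone illustration of that allocation pattern (not Flux code): each slice's pullback materializes a full-size, mostly-zero gradient for x, and those full arrays then get added together.

using Zygote

x = rand(8, 4)

# two "gates" sliced from the same array, as an RNN cell does;
# each slice's pullback allocates a full 8×4 (mostly zero) array for dx,
# and Zygote allocates again to accumulate the two into one result
loss(x) = sum(tanh.(x[1:4, :])) + sum(tanh.(x[5:8, :]))

dx, = gradient(loss, x)
size(dx)   # (8, 4): built up from two full-size intermediates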
julia> x = rand(2, 3);
julia> Meta.@lower @view x[:, 2]
:($(Expr(:thunk, CodeInfo(
@ none within `top-level scope'
1 ─ goto #3 if not true
2 ─ %2 = (view)(x, :, 2)
└── return %2
3 ─ return false
))))
julia> f(x) = sum(@view x[:, 1])
julia> @code_adjoint f(rand(3, 4))
Zygote.Adjoint(1: (%3, %4 :: Zygote.Context, %1, %2)
br 3 unless true
2:
%5 = Zygote._pullback(%4, view, %2, Main.:(:), 1)
%6 = Base.getindex(%5, 1)
%7 = Base.getindex(%5, 2)
br 4 (%6, 1)
3:
br 4 (false, 2)
4: (%8, %12 :: UInt8)
%9 = Zygote._pullback(%4, Main.sum, %8)
%10 = Base.getindex(%9, 1)
%11 = Base.getindex(%9, 2)
return %10, 1: (%1)
%2 = @12 !== 0x01
%3 = (@11)(%1)
%4 = Zygote.gradindex(%3, 2)
br 3 unless %2
br 2
2:
br 4 (nothing)
3:
%5 = (@7)(%4)
%6 = Zygote.gradindex(%5, 2)
br 4 (%6)
4: (%7)
%8 = Zygote.tuple(nothing, %7)
return %8)

I'm not sure either, perhaps more work for the compiler? Profiling showed that the lion's share of time was spent in the generated Pullback.
I was thinking about that yesterday as well. This monolithic kernel approach has been successfully done before, but it would require some reorganizing of the RNN cell internals to work. Having one in-place function and corresponding rrule (+ GPU methods) for all but the matmuls in https://github.com/FluxML/Flux.jl/blob/master/src/layers/recurrent.jl#L159-L165 would be very cool.
If I'm reading correctly, this is roughly the idea:

multigate(x::AbstractArray, h::Int) = ntuple(n -> gate(x,h,n), 4)
@adjoint multigate(x::AbstractArray, h::Int) =
  multigate(x, h), dy -> (vcat(dy...), nothing)

and then each cell calls multigate once instead of calling gate several times.
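For concreteness, with the multigate sketch above (and Flux's non-exported gate helper) a cell would pull all four slices out in a single call, for example:

using Flux: gate            # the slice helper used by multigate above

o = 3                       # hidden size
g = rand(4o, 5)             # stacked pre-activations Wi*x .+ Wh*h .+ b, batch of 5
input, forget, cell, output = multigate(g, o)   # one call, one combined adjoint
size.((input, forget, cell, output))            # four (3, 5) slices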
Agree there's scope for much more fusing.
Do the two matrix multiplications produce outputs of the same size? If so, then this is one way you could fuse things: https://github.com/FluxML/NNlib.jl/pull/346/files#diff-0e3febb41064ef9f892dd2b52c708e0497c548275cd543ec3f95cb947fa08b1cR6-R22
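Not necessarily what the linked PR does, but as one illustration of fusing the two products, 5-arg mul! can accumulate the second matmul into the first one's output buffer:

using LinearAlgebra

Wi, Wh = rand(4*3, 10), rand(4*3, 3)
x, h, b = rand(10, 5), rand(3, 5), rand(4*3)

g = Wi * x                   # first product allocates the buffer
mul!(g, Wh, h, true, true)   # second product accumulates in place: g .+= Wh*h
g .+= b                      # add the bias column-wise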
Force-pushed from a286d3d to b7df765.
Thanks, I'm running some benchmarks for multigate now. There are actually even more opportunities for fusion in most RNN cells. For the LSTM and GRUv1, everything after the first 2 matmuls is pointwise ops, so combining them into one kernel is perfectly possible. That could then be made to operate in-place, with gradients provided by ForwardDiff or Enzyme.
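As a rough sketch of that idea (illustrative names, not a proposed API), the pointwise part of an LSTM cell collapses into a couple of in-place broadcasts that a single rrule, or ForwardDiff/Enzyme, could then cover as one unit:

using NNlib: sigmoid

# gi, gf, gc, go are the four gate slices of Wi*x .+ Wh*h .+ b;
# c is the cell state (updated in place), h′ a preallocated output buffer
function lstm_pointwise!(h′, c, gi, gf, gc, go)
    @. c  = sigmoid(gf) * c + sigmoid(gi) * tanh(gc)
    @. h′ = sigmoid(go) * tanh(c)
    return h′, c
end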
Here's a quick attempt at fusing the broadcast f.(x) .* g.(y):

using Zygote, NNlib
prodcast(f::Function, x::AbstractArray, g::Function, y::AbstractArray) = @. f(x) * g(y)
a, b, c, d = (rand(100,100) for _ in 1:4);
@btime gradient(sum∘prodcast, tanh, $a, sigmoid, $b); # 216.792 μs (79 allocations: 549.22 KiB)
@btime gradient(sum∘prodcast, identity, $a, tanh, $b); # 147.500 μs (75 allocations: 392.86 KiB)
@btime copy($a); # 2.932 μs (2 allocations: 78.17 KiB)
import ChainRulesCore: rrule, derivatives_given_output
derivatives_given_output(_, ::typeof(identity), _) = ((true,),)
function rrule(::typeof(prodcast), f, x, g, y)
fx = f==identity ? x : f.(x) # will need this to avoid re-computing tanh in the gradient
gy = g==identity ? y : g.(y)
size(x) == size(y) || throw("sizes must match")
function uncast(Δraw)
Δ = unthunk(Δraw)
dx = first.(first.(derivatives_given_output.(fx, f, x))) .* gy .* Δ
dy = if g==identity
# first.(first.(derivatives_given_output.(gy, g, y))) .* fx .* Δ
fx .* Δ
else
# we are free to overwrite gy, although should really check eltypes
gy .= first.(first.(derivatives_given_output.(gy, g, y))) .* fx .* Δ
end
(NoTangent(), NoTangent(), dx, NoTangent(), dy)
end
fx .* gy, uncast
end
Zygote.refresh()
@btime gradient(sum∘prodcast, tanh, $a, sigmoid, $b); # 213.916 μs (83 allocations: 314.94 KiB) # -3 copies
@btime gradient(sum∘prodcast, identity, $a, tanh, $b); # 149.625 μs (79 allocations: 236.73 KiB)

With this rrule, both cases save a few copies (see the allocation numbers above). With ForwardDiff, you could have a fused forward pass, but I think you then end up storing an array of tuples containing the sensitivities, which is more memory than storing fx and gy. If you fused this with …
Ok, benchmarks with multigate are in.**

** There were a couple of runs where end-to-end performance on master was abnormally fast, way faster than the sum of the forward and backwards passes separately. This machine is not exactly a noise-free benchmarking environment, so I don't think they're significant.
Good to know, I assumed it could switch to an SoA layout but presumably that is not the case. In fairness, writing out the rrule by hand isn't particularly difficult either. I think the cost-benefit of retaining fused broadcasts is worth the extra complexity, so if there's enough interest I can look into filing a follow-up PR. Longer term it would be great if we could pull in Enzyme + KernelAbstractions for this, but that's a discussion for another thread ;)
Even with hand-written fusion, my point is that either you run tanh again on the backward pass or you store its output on the forward pass, so the overall memory/compute trade-off is much the same. What you could save with SoA might be the number of separate kernel launches, maybe, I'm not sure. On the CPU this is super-easy with StructArrays in fact. But certainly another PR. Nice that this one gets you 10-20%.
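To illustrate the StructArrays point (a hypothetical layout, not something Flux does today): a StructArray keeps each field of the element type in its own contiguous array, so value/derivative pairs can be stored SoA instead of as an array of tuples.

using StructArrays

struct ValGrad{T}
    val::T
    grad::T
end

vg = StructArray{ValGrad{Float64}}((rand(4), rand(4)))
vg.val     # contiguous vector of values
vg.grad    # contiguous vector of derivatives
vg[1]      # a ValGrad{Float64}, reconstructed on access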
Force-pushed from b7df765 to 2e0bb1d.
I know I'm late to the party - sorry for that. I ran some of the benchmarks from FluxBench, and got about 5% faster runtimes for backward passes on my CPU. I think that's pretty nice, considering that there is also other stuff going on in the models apart from calling RNNs. Unfortunately, I don't have access to a reasonable GPU to try it there.
The activations are separate; we shouldn't switch those out for different versions by default, I'd imagine.
Thanks all for the feedback and benchmarking, I think this is good to go. @mcabbott re: the benchmarking harness, I just created a new file https://gist.github.com/ToucheSir/2fecbe99e5b304fed11e25e42c535cc3#file-rnn-jl under https://github.com/FluxML/Flux.jl/tree/master/perf and ran only those tests locally. A more robust approach would be to use FluxBench for this, since it includes RNN models from FluxArchitectures, but I don't recall if that benchmarking can be made to run for a specific PR.
# AD-friendly helper for dividing monolithic RNN params into equally sized gates
multigate(x::AbstractArray, h, ::Val{N}) where N = ntuple(n -> gate(x,h,n), N)

@adjoint function multigate(x::AbstractArray, h, c)
Do we strictly need this adjoint?
It's the difference between being slower on certain configurations and being strictly faster across all configurations, c.f. before and after. The more calls to gate, the more pronounced the effect: note how GRU cells called gate 6-8 times and also regressed the most (on smaller input sizes) without multigate.
The better answer would be to see what part of gate regressed and fix that instead.
The change to gate itself was one line: https://github.com/FluxML/Flux.jl/pull/1761/files#diff-54816e9a4978b8c02648fdb29ebfd6d794452dbac8a28d0e84a5e2cc646a3fbfR4. Since view and getindex use the same adjoint, there's no reason the backwards pass should be slower (note that the forwards pass was consistently faster). Thus the only explanations seem to be a benchmarking artifact (note how this shows up only at smaller input sizes) and/or Zygote's compiler being unhappy for some reason (from profiling, almost all of the non-BLAS, non-activation self time is spent in the generated Pullback for both cases).

However, what it did expose is that calling gate multiple times, regardless of whether it uses view or slicing, was inefficient, as the adjoint would allocate a full-sized buffer for the original array on every call. multigate resolves this by only allocating once, thus reducing both memory and (accumulation) compute by a factor of the number of gates. Even if gate weren't using view, this would be a worthwhile optimization.
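A sketch of that combined adjoint (assuming the matrix case and Flux's gate index helper; not necessarily the exact code merged here): one full-size buffer is allocated and every gate's cotangent is accumulated into it in place.

using Zygote: @adjoint

@adjoint function multigate(x::AbstractMatrix, h, c::Val{N}) where N
  multigate(x, h, c), dy -> begin
    dx = zero(x)                      # single full-size buffer
    for (n, dyn) in enumerate(dy)
      dyn === nothing && continue     # skip gates whose outputs were unused
      dx[gate(h, n), :] .+= dyn       # accumulate this gate's cotangent
    end
    (dx, nothing, nothing)
  end
end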
Are we good to go here? Any more concerns/comments?
bors r+

Error on 1.7 is now:
1761: Use view for RNN gate slice extraction r=mcabbott a=ToucheSir This was originally passed over in #907, but I don't find the argument in that PR particularly compelling as the return value is only ever used once. Any negative impact on caching is going to happen anyhow during the slice materialization, so we might as well just let the subsequent fused broadcasts handle said materialization for us while reducing allocations. Pinging `@jeremiedb,` `@sdobber` and `@mkschleg` if they have any interesting benchmarks to run this on. Otherwise I'll try to get something working with https://github.com/FluxML/Flux.jl/blob/master/perf/bench_utils.jl locally. ### PR Checklist - [x] Tests are added - [ ] Entry in NEWS.md - [N/A] Documentation, if applicable - [N/A] API changes require approval from a committer (different from the author, if applicable) Co-authored-by: Brian Chen <ToucheSir@users.noreply.github.com>
Build failed:
I was trying to investigate that in #1808, hope to get back to it in a few days.
Once the build has completed, you can preview any updated documentation at this URL: https://fluxml.ai/Flux.jl/previews/PR1761/ in ~20 minutes
bors r+

Build succeeded: