Faster generic broadcasting #1001
dxs = ntuple(len) do i
    collapse_nothings(map(StaticGetter{i}(), dxs_zip))
end
(nothing, accum_sum(dxs[1]), map(unbroadcast, args, Base.tail(dxs))...)
Clearly map(StaticGetter{i}(), dxs_zip) should really be fused with unbroadcast, possibly into mapreduce.
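To make the suggestion concrete, a hypothetical sketch of that kind of fusion (helper names are mine, not Zygote internals; the real reduction would go through accum_sum / unbroadcast rather than +):

# Two passes: materialise the array of i-th components, then reduce it.
two_pass(dxs_zip, i) = reduce(+, map(t -> t[i], dxs_zip))

# One pass via mapreduce: extract and reduce together, no intermediate array.
one_pass(dxs_zip, i) = mapreduce(t -> t[i], +, dxs_zip)

two_pass([(1, 10), (2, 20), (3, 30)], 2) == one_pass([(1, 10), (2, 20), (3, 30)], 2)  # both give 60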
Some comments but basically seems all sensible.
Once tests pass, merge when happy.
src/lib/broadcast.jl
eltype(out) <: Dual || return (out, _ -> nothing)
y = map(x -> x.value, out)
_back(ȳ, i) = unbroadcast(args[i], ((a, b) -> a*b.partials[i]).(ȳ, out))
back(ȳ) = ntuple(i -> _back(ȳ, i), N)
_back(ȳ, geti) = unbroadcast(geti(args), ((a, b) -> a * geti(b.partials)).(ȳ, out))
back(ȳ) = ntuple(i -> _back(ȳ, StaticGetter{i}()), N)
Change for CuArrays, not sure this matters. CI doesn't seem to mind. Pushed to try out on another machine, a bit later.
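For context, a rough sketch of what a StaticGetter-style extractor does (a sketch only, with a made-up name; see Zygote's source for the real definition): putting the index in the type domain avoids closing over i, which tends to suit GPU kernel compilation better.

# Sketch of a StaticGetter-like callable: the index lives in the type, so the
# object itself is a singleton carrying no captured data.
struct MyStaticGetter{i} end
(::MyStaticGetter{i})(v) where {i} = v[i]

map(MyStaticGetter{2}(), [(1, 10), (2, 20)])  # [10, 20]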
Some GPU times: no harm, but no help either.
julia> xs = cu(randn(100_000));
julia> @btime CUDA.@sync Zygote.gradient(x -> sum(cbrt.(x)), $xs);
119.805 μs (375 allocations: 10.56 KiB)
119.872 μs (377 allocations: 10.61 KiB) # this PR
julia> @btime CUDA.@sync Zygote.gradient(x -> sum(tanh.(x)), $xs); # has its own gradient
99.938 μs (342 allocations: 10.81 KiB)
julia> @btime CUDA.@sync Zygote.gradient(x -> sum((identity∘tanh).(x)), $xs); # force generic version
116.577 μs (372 allocations: 10.52 KiB)
117.023 μs (371 allocations: 10.52 KiB) # this PR
Also, the generic is as fast as the specific. Remind me why we don't do this on the CPU, and have this 300x slower generic thing?
I know nothing about how Zygote works with the GPU, but taking a guess: because moving memory between GPU/VRAM and CPU/RAM is even more expensive?
I don't quite know why the generic broadcast isn't used for the GPU, perhaps something would break. It uses dual numbers instead, and perhaps this makes purity assumptions about the function involved?
If we very simply call that on the CPU, we get times like this:
julia> @btime Zygote.gradient(x -> sum(abs.(x)), $xs); # earlier 478.458 μs -> 28.625 μs
15.209 μs (20 allocations: 313.16 KiB)
julia> @btime Zygote.gradient(x -> sum((identity∘tanh).(x)), $xs); # earlier 29.680 ms
125.166 μs (20 allocations: 313.16 KiB)
This is with 9227168, which might be too crude... will think a bit & see what CI says.
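For reference, a minimal sketch of the dual-number trick (my own illustration with made-up names, not Zygote's actual implementation):

using ForwardDiff: Dual, value, partials

# Broadcasting f over Dual(x, 1) computes the value and the elementwise derivative
# in one fused pass, with no per-element pullback closures to store.
function dual_broadcast(f, xs::AbstractArray{<:Real})
    ys = f.(Dual.(xs, one.(xs)))
    map(value, ys), map(y -> partials(y, 1), ys)
end

dual_broadcast(abs, [-1.0, 2.0, -3.0])  # ([1.0, 2.0, 3.0], [-1.0, 1.0, -1.0])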
Here's what it gets wrong:
julia> gradient((x,y) -> sum((z->z^2+y[1]).(x)), [1,2,3], [4,5])
([2, 4, 6], nothing) # this commit
([2, 4, 6], [3, 0]) # tagged Zygote
(now caught & sent the slow path)
I don't quite know why the generic broadcast isn't used for the GPU, perhaps something would break. It uses dual numbers instead, and perhaps this makes purity assumptions about the function involved
I think the reason is that you can't send IdDicts to the GPU. Which means you can't send Zygote.Context to the GPU. Which means you can't run Zygote._pullback, which needs the Context to track some things -- I am not sure what exactly, but I think things like tasks and globals. But the things that the Context is meant to track don't need to be tracked in forward mode.
With the caveat that I don't really understand what's going on, has this subsumed #807?
I was unaware of #807. This does not try to make broadcasting better able to handle state; it does try to make it ignore state whenever possible, in the interest of not being slow. The last few commits add a similar test to this:

julia> s = 0;

julia> function f(x)
         global s += x
       end;

julia> gradient(x -> sum(map(f, x)), 1:10)
([1, 2, 3, 4, 5, 6, 7, 8, 9, 10],)  # this PR
([10, 9, 8, 7, 6, 5, 4, 3, 2, 1],)  # tagged version, test is == (10:-1:1,)

julia> fieldnames(typeof(f))
()

julia> Base.issingletontype(typeof(f))
true

Not sure what to make of that; global scope is weird? Maybe the example is too artificial, as this is not actually a sensible gradient calculation. The examples I invented (and added as tests here) do detect when a variable is closed over:

julia> gradient((x,y) -> sum((z->z^2+y[1]).(x)), [1,2,3], [4,5])
([2, 4, 6], [3, 0])

julia> fieldnames(typeof((z->z^2+y[1])))  # and Base.issingletontype is true
()

julia> let y = [1,2]
         fieldnames(typeof((z->z^2+y[1])))  # and Base.issingletontype is false
       end
(:y,)
Δf_and_args = unzip(_tryreverse($mapfunc, Δf_and_args_zipped))
Δf = reduce(accum, Δf_and_args[1])
(Δf, Δf_and_args[2:end]...)
if Base.issingletontype(F) && length(args) == 1
Variants of the weird map example from existing tests. Tests pass, but at the REPL it would fail with this PR; some other variants are fine:
julia> s = 0;
julia> function f(x)
         global s += x
       end;
julia> gradient(x -> sum(map(f, x)), 1:10)
([1, 2, 3, 4, 5, 6, 7, 8, 9, 10],) # this PR
([10, 9, 8, 7, 6, 5, 4, 3, 2, 1],) # tagged version, test is == (10:-1:1,)
julia> fieldnames(typeof(f))
()
julia> Base.issingletontype(typeof(f))
true
julia> S = Ref(0) # mutable, global scope
Base.RefValue{Int64}(0)
julia> function F(x)
         S[] += x
       end;
julia> fieldnames(typeof(F))
()
julia> gradient(x -> sum(map(F, x)), 1:10)
([1, 2, 3, 4, 5, 6, 7, 8, 9, 10],)
julia> let
         S2 = Ref(0) # mutable, local scope
         function F2(x)
             S2[] += x
         end
         gradient(x -> sum(map(F2, x)), 1:10)
       end
([10, 9, 8, 7, 6, 5, 4, 3, 2, 1],)
julia> let
         s2 = 0 # immutable, local scope
         function f2(x)
             s2 += x
         end
         gradient(x -> sum(map(f2, x)), 1:10)
       end
([10, 9, 8, 7, 6, 5, 4, 3, 2, 1],)
My inclination is to say that none of this matters -- the motivation for reverse in stateful map is RNNs, where f is a whole network. That will definitely contain parameters, and hence take the slow path.
Yeah, we could probably afford a regression in those corner cases. I don't know why things in the Main module act so differently. Maybe we could exclude methods defined in Main from the fast path? (I don't like this idea too much, though.)
Yeah, there might be a way to test for this, but parentmodule(f2) is also Main, so I don't think that's quite right.
But the example seems pretty contrived. We're concerned about getting gradients right, which sometimes depends on evaluation order. This is a naked test of evaluation order, as a proxy for the real concern. Perhaps it's not a close enough one, and we should test slightly more realistic things.
More generally, I'm not sure Julia makes guarantees about the order of evaluation of map or broadcast, especially broadcast. I'm not too sure what that means for AD yet.
ok, then I think I'm fine with merging this PR as it is
@mcabbott please address the concerns brought up and please use bors to merge.
What concerns?
My comment above would be a start.
Yes, that's why I asked; I have no idea what that comment was meant to say. Apparently the people tagged also don't know. If you have actual technical concerns and can explain them, or at least provide an MWE, then I'm happy to have a look.
How would this support immutable structures like static arrays without redoing the work of writing kernels for every kind of array type that comes along?
Julia just specialises as usual. It's not mutating anything. What makes you think this is less likely to work on immutable arrays than the previous code? There are right now zero tests for StaticArrays, but when I try things, it seems pretty easy to find examples where broadcasting fails on the tagged version.
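For instance, the sort of broadcast one might add as a test (illustrative only; no claim here about which versions handle it):

using Zygote, StaticArrays

# A simple broadcast gradient over an SVector, the kind of case such tests could cover.
gradient(x -> sum(exp.(x)), SVector(1.0, 2.0, 3.0))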
@mcabbott, how did you (meaning, how can I :-) ) select forward mode in Zygote - was this done using
The merged version will automatically use forward mode whenever possible. The times quoted were from different variants tried while writing this PR -- the first step was Applying
Thanks, @mcabbott!
Oh, sure - I meant for the function broadcasted over.
Ah OK. I never thought to try that. In principle that ought to compile down to roughly the same thing, right? But in practice:
Oh, interesting! The Zygote compiler now outperforms
I guess
I hadn't seen this package, thanks. Will have to understand what it does, to save memory. Xref also JuliaDiff/ChainRules.jl#531, where there are some ideas on how to use
For now Zygote uses dual numbers a lot... well, the list of conditions is fairly long: the arguments and output must be real numbers, you shouldn't be inside a second derivative, and the function must not be a closure, for which
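Schematically, the sort of check that list amounts to might look like this (my paraphrase with made-up names, not Zygote's actual predicate):

# Rough paraphrase of the conditions above (illustrative only): the function must
# carry no captured state, and the arguments must be real numbers or arrays of them,
# for the dual-number forward path to be safe.
dual_safe(f, args...) =
    Base.issingletontype(typeof(f)) &&
    all(a -> a isa Real || a isa AbstractArray{<:Real}, args)

dual_safe(abs, [1.0, 2.0])                     # true
dual_safe(let y = 2; x -> x + y; end, [1.0])   # false: the closure captures y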
Do you mean fusion, or just a complicated f? Zygote remains always un-fused, i.e.
Looks like for the

julia> using Zygote, ForwardDiffPullbacks, StaticArrays, BenchmarkTools

julia> f = (xs...) -> (sum(map(x -> sum(map(x -> x^2, x)), xs)))

julia> xs = (2, (3, 4), SVector(5, 6, 7));

julia> Xs = map(x -> fill(x, 100), xs);

julia> @benchmark Zygote.gradient(($Xs...) -> sum(f.($Xs...)), $Xs...)
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  403.472 μs … 5.890 ms   ┊ GC (min … max): 0.00% … 91.96%
 Time  (median):     452.634 μs              ┊ GC (median):    0.00%
 Time  (mean ± σ):   474.041 μs ± 294.882 μs ┊ GC (mean ± σ):  3.51% ± 5.18%
 Memory estimate: 143.06 KiB, allocs estimate: 3229.

julia> @benchmark Zygote.gradient(($Xs...) -> sum(fwddiff($f).($Xs...)), $Xs...)
BenchmarkTools.Trial: 10000 samples with 7 evaluations.
 Range (min … max):  4.543 μs … 393.863 μs  ┊ GC (min … max): 0.00% … 95.79%
 Time  (median):     5.010 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   5.637 μs ± 12.513 μs   ┊ GC (mean ± σ):  8.13% ± 3.62%
 Memory estimate: 9.11 KiB, allocs estimate: 64.

Whew, not all my work was in vain - for now .... ;-)
Just an f that has many steps, loops, etc., stuff that will result in a long tape if Zygote reverse-mode diffs it.
ForwardDiffPullbacks.jl basically "just" provides an
The main idea behind ForwardDiffPullbacks.jl is to allow users to force forward-mode AD, compatible with anything that respects
I see, so I should think of this as being a bit like
Well, I certainly have no plans to make Zygote's broadcasting subsume this -- doing stuff beyond arrays of numbers, I mean. But I would like to figure out how
Xref also JuliaDiff/ChainRulesDeclarationHelpers.jl#2, not the same but perhaps of interest.
OK now I see -- it looks like it evaluates twice, which avoids having to save the partials between the forward pass and the backward.

julia> using Zygote, ForwardDiffPullbacks
julia> pr(x) = @show x;
julia> gradient(fwddiff(pr), 2pi)
x = 6.283185307179586
x = Dual{ForwardDiff.Tag{Tuple{typeof(pr), Val{1}}, Float64}}(6.283185307179586,1.0)
(1.0,)
julia> gradient(x -> Zygote.forwarddiff(pr,x), 2pi)
x = Dual{Nothing}(6.283185307179586,1.0)
(1.0,)
Yes - it's a bit wasteful, but the forward pass is without dual numbers and the reverse pass uses separate duals for each thunk. I should optimize that, at least for the case where there's only one argument.
I'm very open to changes/suggestions/PRs regarding ForwardDiffPullbacks - I'd be glad to move it to the JuliaDiff GitHub org once it's considered solid enough.
For broadcasting it seems tricky to guess how to make this trade-off, in general. Some functions are more expensive than others. But for an explicit wrapper like this, you know what you're getting... I suppose it could in principle offer an option to specify whether to re-compute or not. For numbers, as you say, a chunked calculation seems like it should not have downsides.
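To illustrate that trade-off in the simplest one-argument scalar case, a sketch of my own (made-up names, not the code of either package):

using ForwardDiff: Dual, value, partials

# Variant A: run f on a Dual once and keep the result, so the partial is already
# stored when the pullback is called.
function pullback_store(f, x::Real)
    d = f(Dual(x, one(x)))
    value(d), dy -> dy * partials(d, 1)
end

# Variant B: plain forward pass, recompute with a Dual inside the pullback;
# twice the work, but nothing extra kept alive between the passes.
function pullback_recompute(f, x::Real)
    f(x), dy -> dy * partials(f(Dual(x, one(x))), 1)
end

y, back = pullback_recompute(sin, 0.3); back(1.0)  # ≈ cos(0.3)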
Zygote uses ForwardDiff for differentiating broadcasting operations; see FluxML/Zygote.jl#1001. However, ForwardDiff cannot differentiate non-generic code that accepts only Float64 (it needs to pass its Dual type through). In addition, ForwardDiff defines rules via DiffRules; it doesn't understand ChainRules. This means we need to define some rules here to make things work. See this discussion on how to add new rules for ForwardDiff: https://discourse.julialang.org/t/issue-with-forwarddiff-custom-ad-rule/72886/6
If JuliaDiff/DiffRules.jl#74 gets merged, I should not need this anymore.
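A minimal illustration of the non-generic-code limitation described there (the function names are made up):

using ForwardDiff

f_generic(x::Real)   = 2x + 1
f_strict(x::Float64) = 2x + 1   # over-typed: only accepts Float64

ForwardDiff.derivative(f_generic, 3.0)    # 2.0
# ForwardDiff.derivative(f_strict, 3.0)   # MethodError: no method matching f_strict(::ForwardDiff.Dual{...})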
During JuliaDiff/ChainRules.jl#441 we realised there's probably a factor 10 left on the table, for no good reason. Not always a factor 10, but still:
Edit -- the latest version does what's suggested in #336, roughly, marked "forward" above. That is, using dual numbers (via the existing internal function broadcast_forward) instead of generic reverse mode when possible. The examples linked there (discourse & #592) both get about 5x faster.

The lines marked "reverse" above are the result of inference improvements to the generic reverse mode. These were the first step of this PR, and will still be used when dual numbers cannot be, e.g. for complex numbers, or when broadcasting closures.
Edit' -- perhaps worth timing map on the same problems. Not certain we should do this here, but the same purity check used for "forward" should also make it safe to skip reverse stuff in map, which gives modest speedups & some memory savings: