-
Notifications
You must be signed in to change notification settings - Fork 32
RFC: broadcasting #68
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Possibly it is worth writing these with a configured |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
very cool
src/stage1/broadcast.jl
Outdated
if T == Bool | ||
# Trivial case | ||
TΔ = Core.Compiler._return_type(derivatives_given_output, Tuple{T, F, map(eltype, args)...}) | ||
if eltype(T) == Bool |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need this to help with inference?
if eltype(T) == Bool | |
if eltype(T) === Bool |
also maybe want to get a few other common cases?
maybe want a function for it
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Further down my hack was ProjectTo(x) isa ProjectTo{<:AbstractZero}
. Maybe that should be formalised, separated / wrapped within CRC?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that's a good hack
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now wrapped up as JuliaDiff/ChainRulesCore.jl#528
src/stage1/broadcast.jl
Outdated
@@ -78,10 +88,69 @@ function split_bc_rule(f::F, args...) where {F} | |||
end | |||
end | |||
|
|||
# This uses "mulltimap"-like constructs: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# This uses "mulltimap"-like constructs: | |
# This uses "multimap"-like constructs: |
We should steal these into ChainRules.jl
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Happy to move them, if you are happy to depend on StructArrays and hence on a few other things:
https://github.com/JuliaArrays/StructArrays.jl/blob/master/Project.toml#L6-L9
Would be worth figuring out a GPU story. Can this be made to work, or do we need an array-of-tuples fallhack?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ChainRules.jl is allowed dependencies.
ChainRulesCore isn't.
So I think it is fine.
(if we later decide we want them gone we can keep the function name and reimplement how it works)
src/stage1/broadcast.jl
Outdated
(NoTangent(), NoTangent(), dargs...) | ||
end | ||
return ys, back_2 | ||
return ys, length(args)==1 ? back_2_one : back_2_many |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Again do we need this to help with inference?
return ys, length(args)==1 ? back_2_one : back_2_many | |
return ys, length(args)===1 ? back_2_one : back_2_many |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is type unstable because it is returning a different function.
can make type stable by moving the if inside the back function
OTOH there are so many other returns that fixing this one wouldn't help.
I wonder if we can think up a design pattern for this kind of case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Indeed, good point. The length can be made a type parameter Vararg{N}
though so it ought to be known in advance? The other choices depend on inference, less sure.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's not type stable, although it seems to be the type of the primal which is Any, not the function. I think from the 3rd case, where it's the first part of ∂⃖{1}().(f, args...)
, instead of simply broadcasting. That's also the case which prevents moving the branches inside the pullback.
Now seems to be type-stable, in simple cases I tried. Earlier Any was just from having two variables called y
on different paths!
d7a155a
to
06dd56a
Compare
Codecov Report
@@ Coverage Diff @@
## main #68 +/- ##
==========================================
+ Coverage 57.04% 58.59% +1.54%
==========================================
Files 21 21
Lines 2158 2246 +88
==========================================
+ Hits 1231 1316 +85
- Misses 927 930 +3
Continue to review full report at Codecov.
|
This adds support for
∂⃖{1}
broadcasting. WIP to discuss the approach.The fast path uses
derivatives_given_output
from JuliaDiff/ChainRulesCore.jl#453, which is defined by@scalar_rule
. This allows the forward pass to be simplyz = f.(x, y)
. It's labelled experimental but this is exactly what it's useful for.dx, dy
immediately. StructArrays is already a dependency. This won't work on a GPU right now. Perhaps it should live elsewhere (ChainRules?) and have a GPU fallback.atan.(rand(2,3), rand(2))
) then there is a separate reduction stepunbroadcast
to reduce (e.g.dy = vec(sum(dy, dims=2))
.) That seems a waste but it's not so obvious how to do better.f.(g.(x))
is split into two steps. This is because the rule forg
may need the output, and the rule forf
may need the input, which would not be saved if the forward pass was fused.If there is no scalar rule for
f
, then the slow path is to save an array of pullbacks on the forward pass. The same StructArrays idea lets us save this andz
separately, rather than unzipping an array of tuples.Typically these pullbacks will each close over one or two numbers, so this array is one or two times the size of
z
. This slow path will be used for(f∘g).(x)
, and could probably be used to fusef.(g.(x))
. But the benefit of avoiding an intermediate array on the forward pass must be weighed against storing more numbers in the pullbacks. The fused path is not guaranteed to save any memory.There are special rules for
+,-,*
. These ones do fuse, so thatz = f.(x .+ g.(y))
storesgy = g.(y)
andz
, but re-computesx .+ gy
if the rule forf
needs it, element by element. Perhaps there should be a few more like this. These rules also try to avoidunbroadcast
when possible.All rules are attached to
∂⃖{1}
not torrule
, so that they will not affect other packages using ChainRules. Whether that's ultimately a good idea perhaps depends on whether we think one approach is optimal, or whether different packages ought to make different trade-offs. Right now I'm sure this would break things which work elsewhere. Zygote's approach is completely un-fused, and the fastest path at present uses dual numbers, which requires more memory than the fast path here.Higher derivatives don't really work here, I have not given them any thought. There is an existing rule for
map(f,x)
which should support them, although it seems tricky to invent cases which actually work. For one first-derivative(f∘g).(x)
example below, it was faster than the slow path of this PR.Some benchmarks are below.