RFC: broadcasting #68

Closed · wants to merge 9 commits

Conversation

@mcabbott (Member) commented Jan 9, 2022

This adds support for ∂⃖{1} broadcasting. WIP to discuss the approach.

The fast path uses derivatives_given_output from JuliaDiff/ChainRulesCore.jl#453, which is defined by @scalar_rule. This allows the forward pass to be simply z = f.(x, y). It's labelled experimental but this is exactly what it's useful for.
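The derivatives-from-output idea can be illustrated in plain Julia. This is a conceptual stand-in only: `deriv_given_output` is a hypothetical name, not the actual `ChainRulesCore.derivatives_given_output` API, but it shows why the forward pass needs to save nothing beyond `z` itself.

```julia
# Hypothetical stand-in for ChainRulesCore.derivatives_given_output: for many
# scalar rules the derivative is computable from the output alone, so the fast
# path's forward pass can be just `z = f.(x)` with nothing extra saved.
deriv_given_output(z, ::typeof(exp), x)  = z          # d/dx exp(x) = exp(x) = z
deriv_given_output(z, ::typeof(sqrt), x) = 1 / (2z)   # d/dx sqrt(x) = 1/(2√x) = 1/(2z)

xs_demo = [0.0, 1.0, 2.0]
zs_demo = exp.(xs_demo)                                # plain forward pass
dzs = map((z, x) -> deriv_given_output(z, exp, x), zs_demo, xs_demo)
# dzs == zs_demo here, since exp is its own derivative
```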

  1. The reverse pass uses StructArrays so that rather than allocating an array of tuples and unzipping, it makes arrays dx, dy immediately. StructArrays is already a dependency. This won't work on a GPU right now. Perhaps it should live elsewhere (ChainRules?) and have a GPU fallback.
  2. If the arguments are not the same size (e.g. atan.(rand(2,3), rand(2))) then there is a separate reduction step, unbroadcast, to shrink the gradient back to each argument's size (e.g. dy = vec(sum(dy, dims=2))). That seems wasteful, but it's not so obvious how to do better.
  3. It is not fused: f.(g.(x)) is split into two steps. This is because the rule for g may need the output, and the rule for f may need the input, which would not be saved if the forward pass was fused.
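The unbroadcast reduction in point 2 can be sketched in a few lines. This is illustrative, not the PR's code: `unbroadcast_sketch` is a hypothetical name, and it handles only the array case, summing over every dimension along which the argument was virtually expanded.

```julia
# Hedged sketch of the `unbroadcast` reduction: sum the gradient over every
# dimension along which `x` was broadcast, then reshape back to `size(x)`.
function unbroadcast_sketch(x::AbstractArray, dx::AbstractArray)
    size(x) == size(dx) && return dx
    # dimensions where x is absent or has length 1 were broadcast:
    red = Tuple(d for d in 1:ndims(dx) if d > ndims(x) || size(x, d) == 1)
    reshape(sum(dx; dims=red), size(x))
end

unbroadcast_sketch(rand(2), ones(2, 3))  # == [3.0, 3.0], i.e. vec(sum(dy, dims=2))
```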

If there is no scalar rule for f, then the slow path is to save an array of pullbacks on the forward pass. The same StructArrays idea lets us save this and z separately, rather than unzipping an array of tuples.

Typically these pullbacks will each close over one or two numbers, so this array is one or two times the size of z. This slow path will be used for (f∘g).(x), and could probably be used to fuse f.(g.(x)). But the benefit of avoiding an intermediate array on the forward pass must be weighed against storing more numbers in the pullbacks. The fused path is not guaranteed to save any memory.
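The slow path can be sketched in pure Base Julia. Here `rrule_like` and `broadcast_slowpath` are hypothetical stand-ins: the PR uses StructArrays to hold the outputs and pullbacks as two plain arrays rather than one array of tuples, but the storage pattern is the same.

```julia
# The forward pass stores, per element, both the output and a pullback closure.
# `rrule_like` is a hypothetical stand-in for an elementwise rrule for abs2;
# its closure captures one number, x, which is why the pullback array costs
# roughly one extra copy of z in memory.
rrule_like(::typeof(abs2), x) = (abs2(x), dy -> 2x * dy)

function broadcast_slowpath(f, xs)
    pairs = map(x -> rrule_like(f, x), xs)        # forward pass
    ys    = map(first, pairs)
    backs = map(last, pairs)
    ys, dys -> map((b, dy) -> b(dy), backs, dys)  # reverse pass, elementwise
end

ys, back = broadcast_slowpath(abs2, [1.0, 2.0, 3.0])
# ys == [1.0, 4.0, 9.0]; back(ones(3)) == [2.0, 4.0, 6.0]
```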

There are special rules for +,-,*. These do fuse, so that z = f.(x .+ g.(y)) stores gy = g.(y) and z, but re-computes x .+ gy element by element if the rule for f needs it. Perhaps there should be a few more like this. These rules also try to avoid unbroadcast when possible.
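The lazy re-computation trick behind the +,-,* rules rests on Base's un-materialised broadcast objects, which can be indexed element by element without ever allocating the full array:

```julia
# Keep the un-materialised broadcast; elements are computed only on demand,
# so a reverse pass can re-derive x .+ gy one element at a time instead of
# storing the whole intermediate array.
x, gy = [1.0, 2.0], [10.0, 20.0]
lazy = Broadcast.instantiate(Broadcast.broadcasted(+, x, gy))
lazy[1]      # computes just this element: 11.0
copy(lazy)   # materialises everything only if actually needed: [11.0, 22.0]
```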

All rules are attached to ∂⃖{1} not to rrule, so that they will not affect other packages using ChainRules. Whether that's ultimately a good idea perhaps depends on whether we think one approach is optimal, or whether different packages ought to make different trade-offs. Right now I'm sure this would break things which work elsewhere. Zygote's approach is completely un-fused, and the fastest path at present uses dual numbers, which requires more memory than the fast path here.

Higher derivatives don't really work here; I have not given them any thought. There is an existing rule for map(f,x) which should support them, although it seems tricky to invent cases which actually work. For one first-derivative (f∘g).(x) example below, it was faster than the slow path of this PR.

Some benchmarks are below.

# add https://github.com/mcabbott/Diffractor.jl#splitcast

julia> using Diffractor, Zygote, Tracker, BenchmarkTools
julia> dgrad(x...) = Diffractor.unthunk.(Diffractor.gradient(x...));
julia> zgrad(x...) = Zygote.gradient(x...);
julia> tgrad(x...) = Tracker.gradient(x...);

julia> xs = randn(10_000);
julia> @btime copy($xs);  # times showing mean are January, times without were from September
  min 1.262 μs, mean 2.397 μs (2 allocations, 78.17 KiB)

# Time nothing first, these are offsets unrelated to broadcasting

julia> @btime tgrad(x -> sum(abs2, x), $xs);
  min 19.417 μs, mean 37.570 μs (52 allocations, 548.49 KiB)
julia> @btime zgrad(x -> sum(abs2, x), $xs);
  min 3.281 μs, mean 4.909 μs (2 allocations, 78.17 KiB)
julia> @btime dgrad(x -> sum(abs2, x), $xs);
  min 3.827 μs, mean 5.099 μs (12 allocations, 78.56 KiB)


# Simple function

julia> @btime tgrad(x -> sum(abs2, exp.(x)), $xs);
  min 96.250 μs, mean 110.523 μs (119 allocations, 708.23 KiB) # 2 copies: (708.23 - 548.49)/78.17 ≈ 2

julia> @btime zgrad(x -> sum(abs2, exp.(x)), $xs);
  min 50.375 μs, mean 63.084 μs (31 allocations, 391.50 KiB)  # with dual numbers -- like 4 copies
  min 49.708 μs, mean 88.805 μs (38 allocations, 469.75 KiB)  # without, thus Zygote.pullback

julia> @btime dgrad(x -> sum(abs2, exp.(x)), $xs);
  min 49.333 μs, mean 70.513 μs (39 allocations, 235.52 KiB)  # fast path -- one copy forward, one back
  min 49.958 μs, mean 68.664 μs (35 allocations, 313.53 KiB)  # slow path -- 3 copies, extra is the closures?
  min 61.167 μs, mean 73.827 μs (27 allocations, 703.94 KiB)  # with `map` rule as before -- worse


# Composed function, Zygote struggles without dual numbers

julia> @btime tgrad(x -> sum(abs2, (identity∘cbrt).(x)), $xs);
  min 132.583 μs, mean 185.010 μs (119 allocations, 708.23 KiB)

julia> @btime zgrad(x -> sum(abs2, (identity∘cbrt).(x)), $xs);
  min 70.375 μs, mean 102.529 μs (31 allocations, 391.64 KiB)   # with dual numbers
  min 39.823 ms, mean 41.298 ms (539576 allocations, 13.73 MiB) # without, thus Zygote.pullback

julia> @btime dgrad(x -> sum(abs2, (identity∘cbrt).(x)), $xs);  # now fails, see https://github.com/JuliaDiff/Diffractor.jl/issues/67
  55.290 ms (830060 allocations: 49.75 MiB)  # slow path
  14.747 ms (240043 allocations: 7.25 MiB)   # with `map` rule as before -- better!


# Compare same problem unfused

julia> @btime tgrad(x -> sum(abs2, identity.(cbrt.(x))), $xs);
  min 132.542 μs, mean 182.825 μs (119 allocations, 708.23 KiB)

julia> @btime zgrad(x -> sum(abs2, identity.(cbrt.(x))), $xs);
  min 69.250 μs, mean 101.403 μs (31 allocations, 391.50 KiB)

julia> @btime dgrad(x -> sum(abs2, identity.(cbrt.(x))), $xs);
  min 70.875 μs, mean 106.793 μs (52 allocations, 392.12 KiB)  # fast path -- two copies forward, two back
  min 78.209 μs, mean 114.250 μs (48 allocations, 470.14 KiB)  # slow path -- 5 copies
  min 131.708 μs, mean 228.664 μs (39 allocations, 1.30 MiB)   # with `map` rule as before -- worse


# Lazy +,-,* for partial fusing

julia> @btime zgrad(x -> sum(abs2, exp.(2 .* x .- 100)), $xs);
  min 61.709 μs, mean 109.499 μs (21 allocations, 625.47 KiB)  # special rules, 4 more copies than best

julia> @btime dgrad(x -> sum(abs2, exp.(2 .* x .- 100)), $xs);
  min 95.916 μs, mean 162.643 μs (100 allocations, 862.73 KiB) # without special rules
  min 58.583 μs, mean 85.664 μs (49 allocations, 314.11 KiB)   # with lazy rules
  min 72.917 μs, mean 155.654 μs (39 allocations, 1016.78 KiB) # with `map` rule as before


julia> @btime tgrad((x,y) -> sum(abs2, exp.(2 .* x .+ y)), $xs, $(rand(10)'));
  min 1.926 ms, mean 2.286 ms (193 allocations, 6.42 MiB)

julia> @btime zgrad((x,y) -> sum(abs2, exp.(2 .* x .+ y)), $xs, $(rand(10)'));
  min 523.792 μs, mean 963.075 μs (55 allocations, 4.81 MiB) # special rules, fully un-fused

julia> @btime dgrad((x,y) -> sum(abs2, exp.(2 .* x .+ y)), $xs, $(rand(10)'));
  min 856.708 μs, mean 1.222 ms (136 allocations, 5.04 MiB)  # without special rules
  min 657.750 μs, mean 860.824 μs (91 allocations, 2.44 MiB) # with lazy rules
  min 921.750 μs, mean 1.381 ms (69 allocations, 7.94 MiB)   # with `map` rule as before


# Simple closure, in order of speed:

julia> @btime ForwardDiff.derivative(x -> sum((y -> y/x).($xs)), pi/2)
  min 6.025 μs, mean 18.328 μs (2 allocations, 156.30 KiB)

julia> @btime zgrad(x -> sum((y -> y/x).($xs)), pi/2)
  min 37.625 μs, mean 88.717 μs (54 allocations, 861.20 KiB)

julia> @btime tgrad(x -> sum((y -> y/x).($xs)), pi/2)
  min 2.912 ms, mean 3.448 ms (200004 allocations, 5.26 MiB)

julia> @btime dgrad(x -> sum((y -> y/x).($xs)), pi/2)  # previously AssertionError
  min 7.655 ms, mean 8.476 ms (270061 allocations, 8.17 MiB)


# Won't work on GPU

julia> using StructArrays, CUDA
julia> splitcast(f, args...) = StructArrays.components(StructArray(Broadcast.instantiate(Broadcast.broadcasted(f, args...))));
julia> splitmap(f, args...) = StructArrays.components(StructArray(Iterators.map(f, args...)));

julia> splitcast((x,y) -> (x/y, x+y), cu([1,2,3]), cu([4,5,6]))
┌ Warning: Performing scalar indexing on task Task (runnable) @0x00007f09cf37c010.
...
([0.25, 0.4, 0.5], [5, 7, 9])

julia> ans[1]
3-element Vector{Float64}:
 0.25
 0.4
 0.5

julia> splitmap((x,y) -> (x/y, x+y), cu([1,2,3]), cu([4,5,6]))[1]
3-element Vector{Float64}:
 0.25
 0.4
 0.5

@oxinabox (Member)

> All rules are attached to ∂⃖{1} not to rrule, so that they will not affect other packages using ChainRules. Whether that's ultimately a good idea perhaps depends on whether we think one approach is optimal, or whether different packages ought to make different trade-offs. Right now I'm sure this would break things which work elsewhere. Zygote's approach is completely un-fused, and the fastest path at present uses dual numbers, which requires more memory than the fast path here.

Possibly it is worth writing these with a configured rrule set to apply only to the DiffractorADConfig.
That way we can later port them over easily if we want.

@oxinabox (Member) left a comment

very cool

if T == Bool
# Trivial case
TΔ = Core.Compiler._return_type(derivatives_given_output, Tuple{T, F, map(eltype, args)...})
if eltype(T) == Bool
@oxinabox (Member):
Do we need this to help with inference?

Suggested change:
- if eltype(T) == Bool
+ if eltype(T) === Bool

also maybe want to get a few other common cases?
maybe want a function for it

@mcabbott (Member, Author):
Further down my hack was ProjectTo(x) isa ProjectTo{<:AbstractZero}. Maybe that should be formalised, separated / wrapped within CRC?

@oxinabox (Member):
that's a good hack

@mcabbott (Member, Author):
@@ -78,10 +88,69 @@ function split_bc_rule(f::F, args...) where {F}
end
end

# This uses "mulltimap"-like constructs:
@oxinabox (Member):
Suggested change:
- # This uses "mulltimap"-like constructs:
+ # This uses "multimap"-like constructs:

We should steal these into ChainRules.jl

@mcabbott (Member, Author):

Happy to move them, if you are happy to depend on StructArrays and hence on a few other things:
https://github.com/JuliaArrays/StructArrays.jl/blob/master/Project.toml#L6-L9

Would be worth figuring out a GPU story. Can this be made to work, or do we need an array-of-tuples fallback?

@oxinabox (Member):

ChainRules.jl is allowed to have dependencies; ChainRulesCore isn't.

So I think it is fine.
(if we later decide we want them gone we can keep the function name and reimplement how it works)

(NoTangent(), NoTangent(), dargs...)
end
return ys, back_2
return ys, length(args)==1 ? back_2_one : back_2_many
@oxinabox (Member):
Again do we need this to help with inference?

Suggested change:
- return ys, length(args)==1 ? back_2_one : back_2_many
+ return ys, length(args)===1 ? back_2_one : back_2_many

@oxinabox (Member):

This is type unstable because it returns a different function depending on the branch.
We can make it type stable by moving the branch inside the back function.
OTOH there are so many other returns that fixing this one wouldn't help.
I wonder if we can think up a design pattern for this kind of case.

@mcabbott (Member, Author):

Indeed, good point. The length can be made a type parameter via Vararg{N}, though, so it ought to be known in advance? The other choices depend on inference; I'm less sure about those.
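The Vararg{N} point can be sketched like this (hypothetical names, not the PR's code): making the argument count a type parameter lets dispatch resolve the one-vs-many branch at compile time instead of checking length(args) at runtime.

```julia
# Fixed-length Vararg methods are more specific, so the one-argument call
# is routed by dispatch, not by a runtime `length(args) == 1` check.
pick_back(args::Vararg{Any,1}) = :back_2_one
pick_back(args::Vararg{Any,N}) where {N} = :back_2_many

pick_back([1.0])           # :back_2_one
pick_back([1.0], [2.0])    # :back_2_many
```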

@mcabbott (Member, Author) commented Jan 20, 2022:
It's not type stable, although it seems to be the type of the primal which is Any, not the function. I think from the 3rd case, where it's the first part of ∂⃖{1}().(f, args...), instead of simply broadcasting. That's also the case which prevents moving the branches inside the pullback.

Now seems to be type-stable, in simple cases I tried. Earlier Any was just from having two variables called y on different paths!

@codecov-commenter commented Jan 22, 2022

Codecov Report

Merging #68 (06dd56a) into main (be4eeb5) will increase coverage by 1.54%.
The diff coverage is 75.24%.

@@            Coverage Diff             @@
##             main      #68      +/-   ##
==========================================
+ Coverage   57.04%   58.59%   +1.54%     
==========================================
  Files          21       21              
  Lines        2158     2246      +88     
==========================================
+ Hits         1231     1316      +85     
- Misses        927      930       +3     
Impacted Files Coverage Δ
src/stage1/generated.jl 75.56% <0.00%> (-1.15%) ⬇️
src/extra_rules.jl 58.18% <25.00%> (+10.95%) ⬆️
src/stage1/broadcast.jl 75.00% <80.43%> (+20.71%) ⬆️
src/tangent.jl 60.20% <0.00%> (-5.11%) ⬇️
src/stage1/forward.jl 86.98% <0.00%> (-4.11%) ⬇️
src/stage1/recurse_fwd.jl 94.28% <0.00%> (-2.89%) ⬇️
src/stage1/recurse.jl 96.52% <0.00%> (+2.95%) ⬆️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
