Document Loop Optimisation Opportunities #156

Open
willtebbutt opened this issue May 15, 2024 · 4 comments
Labels: documentation, enhancement (performance)

Comments

@willtebbutt (Member)

No description provided.

willtebbutt added the documentation and enhancement labels May 15, 2024
willtebbutt added the enhancement (performance) label and removed the enhancement label May 28, 2024
@willtebbutt (Member Author) commented May 30, 2024

I still need to add more concrete info about which loop optimisations are possible, but here's a summary of the current state of affairs:

  • map, broadcast, mapreduce, and any other higher-order functions I’ve forgotten about all lower to loops in the CFG. Tapir.jl doesn’t have rules for them, so Tapir.jl sees these loops.
  • Loops in Tapir.jl are reasonably performant: they are completely type-stable and (usually) allocation-free. So if you’re doing a decent amount of work at each iteration (e.g. sin(cos(exp(x[n])))), the time spent on looping “overhead” (e.g. logging values on the forwards-pass at each iteration which you need on the reverse-pass) is small compared to the time spent doing the work that you care about (e.g. computing sin(cos(exp(x[n]))) and running AD on each operation in it).
  • If you’re doing a very small amount of work at each iteration of a loop, then your computation is (currently) dominated by “overhead”. sum is an extreme case of this, because adding two Float64s together at each iteration is about the cheapest differentiable operation imaginable. Moreover, the current way that we handle looping in Tapir.jl “gets in the way” of vectorisation (on both the forwards-pass and reverse-pass). The sketch after this list contrasts the two regimes.
  • The loop optimisations that I will discuss in this issue largely target this overhead. They will therefore improve the performance of every single example in this table, but the largest improvements will be seen for kron and sum. I imagine they’ll be especially helpful for sum, as they should be able to “get out of the way” of vectorisation (i.e. things should vectorise nicely) in many cases.
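
To make the two regimes concrete, here is a rough sketch of the comparison (the function names and benchmarking setup are mine, for illustration only):

using BenchmarkTools

x = randn(4096)

# Work-dominated: plenty of floating-point work per element, so
# per-iteration AD book-keeping is amortised away.
work_heavy(x) = sum(n -> sin(cos(exp(x[n]))), eachindex(x))

# Overhead-dominated: one Float64 addition per element, roughly the
# cheapest differentiable operation available.
work_light(x) = sum(x)

@btime work_heavy($x);
@btime work_light($x);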

A great advantage of this approach is that we just rely on everything boiling down to the same kind of looping structure in the CFG -- basically everything CPU-based that’s performant gets reduced to a loop in the CFG (specifically, a thing called a “natural loop” in compiler-optimisation terminology). There are well-established optimisation strategies for loops, so we don’t need to implement separate rules for all the different higher-order functions to get good performance, nor do we need to tell people to steer clear of writing for or while loops.
Rather, we just optimise these so-called “natural loop” structures which appear in the CFG, and then everything (or, rather, most things) should be fast.

(The situation in which this strategy breaks down is when people use @goto to produce certain kinds of “weird” looping structures. Such structures will only ever be as performant as they are currently, which, frankly, isn’t bad, but we should probably discourage people from using @goto, which is definitely something that I can live with. A contrived example of such a structure is sketched below.)
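
For concreteness, here is a contrived sketch (my own illustration, not code from Tapir.jl) of @goto control flow that does not form a natural loop: the loop body can be entered at two different points, which makes it irreducible.

function odd_entry_sum(x::Vector{Float64}, skip_first::Bool)
    s = 0.0
    i = 1
    skip_first && @goto advance    # second entry point into the loop
    @label accumulate
    s += x[i]
    @label advance
    i += 1
    i <= length(x) && @goto accumulate
    return s
end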

@willtebbutt (Member Author) commented Jul 1, 2024

Tapir.jl does not perform as well as it could on functions like the following:

function foo!(y::Vector{Float64}, x::Vector{Float64})
    @inbounds @simd for n in eachindex(x)
        y[n] = y[n] + x[n]
    end
    return y
end

For example, on my computer:

y = randn(4096)
x = randn(4096)

julia> @benchmark foo!($y, $x)
BenchmarkTools.Trial: 10000 samples with 173 evaluations.
 Range (min … max):  547.150 ns … 3.138 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     646.633 ns               ┊ GC (median):    0.00%
 Time  (mean ± σ):   682.488 ns ± 116.548 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

       ▄██▂                                                      
  ▁▁▂▄▇████▇▇▇▆▆▅▅▅▅▄▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▂
  547 ns           Histogram: frequency by time         1.18 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.

rule = Tapir.build_rrule(foo!, y, x);
foo!_d = zero_fcodual(foo!)
y_d = zero_fcodual(y)
x_d = zero_fcodual(x)
out, pb!! = rule(foo!_d, y_d, x_d);

julia> @benchmark ($rule)($foo!_d, $y_d, $x_d)[2](NoRData())
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  64.042 μs … 202.237 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     78.675 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   75.763 μs ±  10.175 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%

  ▇ ▇ ▇ ▄▂       ▅ ▃  ▆ ▅▆ █▄ ▄  ▁▂         ▂       ▁        ▂ ▃
  █▃█▃█▄██▁▄▁▆▄▁▁█▄█▆██████████▇▄██▃▅█▃▃█▆▆▇█▆▆▆▇▆▄▁█▃▃▃▅█▅▃▁█ █
  64 μs         Histogram: log(frequency) by time       108 μs <

 Memory estimate: 0 bytes, allocs estimate: 0.

So the performance ratio is roughly 64 μs / 0.5 μs ≈ 128.

Note that this is not due to type instabilities. One way to convince yourself of this: no allocations are required to run AD, which would most certainly not be the case were there type instabilities.
Rather, the problem is the overhead associated with our implementation of reverse-mode AD.

To see this, take a look at the optimised IR for foo!:

2 1 ── %1  = Base.arraysize(_3, 1)::Int64
  │    %2  = Base.slt_int(%1, 0)::Bool
  │    %3  = Core.ifelse(%2, 0, %1)::Int64
  │    %4  = %new(Base.OneTo{Int64}, %3)::Base.OneTo{Int64}
  └───       goto #14 if not true
  2 ── %6  = Base.slt_int(0, %3)::Bool
  └───       goto #12 if not %6
  3 ──       nothing::Nothing
  4 ┄─ %9  = φ (#3 => 0, #11 => %27)::Int64
  │    %10 = Base.slt_int(%9, %3)::Bool
  └───       goto #12 if not %10
  5 ── %12 = Base.add_int(%9, 1)::Int64
  └───       goto #9 if not false
  6 ── %14 = Base.slt_int(0, %12)::Bool
  │    %15 = Base.sle_int(%12, %3)::Bool
  │    %16 = Base.and_int(%14, %15)::Bool
  └───       goto #8 if not %16
  7 ──       goto #9
  8 ──       invoke Base.throw_boundserror(%4::Base.OneTo{Int64}, %12::Int64)::Union{}
  └───       unreachable
  9 ┄─       goto #10
  10 ─       goto #11
  11 ─ %23 = Base.arrayref(false, _2, %12)::Float64
  │    %24 = Base.arrayref(false, _3, %12)::Float64
  │    %25 = Base.add_float(%23, %24)::Float64
  │          Base.arrayset(false, _2, %25, %12)::Vector{Float64}
  │    %27 = Base.add_int(%9, 1)::Int64
  │          $(Expr(:loopinfo, Symbol("julia.simdloop"), nothing))::Nothing
  └───       goto #4
  12 ┄       goto #14 if not false
  13 ─       nothing::Nothing
5 14 ┄       return _2

The performance-critical chunk of the loop happens between %23 and %27. Tapir.jl does basically the same kind of thing for each of these lines, so we just look at %23:

%23_ = rrule!!(zero_fcodual(Base.arrayref), zero_fcodual(false), _2, %12)
%23 = %23_[1]
push!(%23_pb_stack, %23_[2])

In short, we run the rule, pull out the first element of the result, and push the pullback to the stack for use on the reverse-pass.

So there is at least one large, obvious source of overhead here: pushing to / popping from the stacks. If you take a look at the pullbacks for the arrayref calls, you'll see that they contain:

  1. (a reference to) the shadow of the array being referenced, and
  2. a copy of the index at which the forwards-pass references the array.

This information is necessary for AD, but

  1. the array being referenced and its shadow are loop invariants -- their value does not change at each iteration of the loop -- meaning that we're just pushing 4096 references to the same array to a stack and popping them, which is wasteful, and
  2. the index is an induction variable -- its value changes by a fixed known amount at each loop iteration, meaning that (in principle) we can just recompute it on the reverse-pass rather than storing it.

What's not obvious here, but is also important, is that the call to push! tends to get inlined and contains a branch. This prevents LLVM from vectorising the loop, thus prohibiting quite a lot of optimisation.

Now, Tapir.jl is implemented in such a way that, if the pullback for a particular function is a singleton / doesn't carry around any information, the associated pullback stack is eliminated entirely. Moreover, just reducing the amount of memory stored at each iteration should reduce memory pressure. Consequently, a good strategy for making progress is to figure out how to reduce the amount of stuff which gets stored in the pullback stacks. The two points noted above provide obvious starting points.
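
As a small illustration of why singletons matter here (a generic Julia fact, not Tapir.jl-specific code): a container whose element type is a singleton carries no per-element storage at all, so a stack of such pullbacks is essentially free.

struct SingletonPullback end    # a pullback type with no captured state

pb_stack = Vector{SingletonPullback}(undef, 4096)
sizeof(pb_stack)    # 0: there are no bytes of element data to push or pop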

Making use of loop invariants

In short: amend the rule interface such that the arguments to the forwards pass are also made available on the reverse pass.

For example, the arrayref rule is presently something along the lines of

function rrule!!(::CoDual{typeof(arrayref)}, inbounds::CoDual{Bool}, x::CoDual{Vector{Float64}}, ind::CoDual{Int})
    _ind = primal(ind)
    dx = tangent(x)
    function arrayref_pullback(dy)
        dx[_ind] += dy
        return NoRData(), NoRData(), NoRData(), NoRData()
    end
    return CoDual(primal(x)[_ind], tangent(x)[_ind]), arrayref_pullback
end

This skips some details, but the important point is that _ind and dx are closed over, and are therefore stored in arrayref_pullback.

Under the new interface, this would look something like

function rrule!!(::CoDual{typeof(arrayref)}, inbounds::CoDual{Bool}, x::CoDual{Vector{Float64}}, ind::CoDual{Int})
    function arrayref_pullback(dy, ::CoDual{typeof(arrayref)}, ::CoDual{Bool}, x::CoDual{Vector{Float64}}, ind::CoDual{Int})
        _ind = primal(ind)
        dx = tangent(x)
        dx[_ind] += dy
        return NoRData(), NoRData(), NoRData(), NoRData()
    end
    _ind = primal(ind)
    return CoDual(primal(x)[_ind], tangent(x)[_ind]), arrayref_pullback
end

In this version of the rule, arrayref_pullback is a singleton because it does not close over any data from the enclosing rrule!!.
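
This property can be checked directly. A generic illustration (not the actual Tapir.jl types): a closure that captures nothing compiles to a singleton type, whereas closing over data rules that out.

no_captures = dy -> nothing
Base.issingletontype(typeof(no_captures))    # true: nothing stored per pullback

with_captures = let captured = Ref(0.0)
    dy -> captured[]    # closes over `captured`, so the closure type has a field
end
Base.issingletontype(typeof(with_captures))    # false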

So this interface change frees up Tapir.jl to provide the arguments on the reverse-pass in whichever way it pleases. In this particular example, both x and y are arguments to foo!, so applying this new interface recursively would give us direct access to them on the reverse pass by construction.
A similar strategy could be employed for variables which aren't arguments by putting them in the storage shared by the forwards and reverse passes.

It's impossible to know for sure how much of an effect this would have, but doing this alone would more than halve the memory requirement for arrayref: a Vector{Float64} knows its address in memory and its length, which requires 16B, whereas the index is just an Int, which takes 8B. The saving for arrayset would be even larger, since its pullback holds references to both the primal array and its shadow. And since the pullback for + is already a singleton in both the Float64 and Int cases, this change would more than halve the memory footprint of the loop.
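
The back-of-envelope figures can be sanity-checked on a 64-bit system (modelling an array reference as a pointer plus a length, matching the 16B figure above):

sizeof(Int)                          # 8 bytes: the index stored by the pullback
sizeof(Ptr{Float64}) + sizeof(Int)   # 16 bytes: pointer + length in an array reference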

Induction Variable Analysis

I won't address how we could make use of induction variable analysis here, because I'm still trying to get my head around the easiest way to go about it.
Rather, just note that the above interface change is necessary in order to make use of its results: the purpose of induction variable analysis would be to avoid storing the index on each iteration of the loop, and instead to re-compute it on the reverse pass and hand it to the pullbacks. The above change to the interface would permit this; a rough sketch of the idea follows.
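
As a rough illustration of the intended effect (hypothetical code, not Tapir.jl's actual transform): in foo! the loop index is an induction variable running over eachindex(x), so the reverse pass can regenerate the indices by iterating in reverse, rather than popping 4096 stored Ints off a stack.

function foo!_reverse_sweep!(dy::Vector{Float64}, dx::Vector{Float64})
    # Forwards-pass indices were 1, 2, ..., length(dx); recompute them in
    # reverse order instead of loading them from a pullback stack.
    for n in reverse(eachindex(dx))
        dx[n] += dy[n]    # what the stored arrayref/arrayset pullbacks would have done
    end
    return nothing
end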

@willtebbutt (Member Author) commented

Another obvious optimisation is to analyse the trip count, and pre-allocate the (necessary) pullback stacks in order to avoid branching during execution (i.e. checking that they're long enough to store the next pullback, and allocating more memory if not).

This is related to induction variable analysis, so we'd probably want to do that first.

Doing this kind of optimisation would enable vectorisation to happen more effectively in AD, as we could completely eliminate branching from a number of tight loops. A sketch of the difference follows.
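
A hypothetical sketch (the function names are mine): with a known trip count, each stack write becomes a branch-free store.

# Growing stack: every push! checks capacity and may reallocate,
# placing a branch inside the hot loop.
function record_growing!(stack::Vector{Int}, N::Int)
    for n in 1:N
        push!(stack, n)
    end
    return stack
end

# Pre-sized stack: the trip count N is known up front, so each
# iteration is a plain store that LLVM can vectorise around.
function record_preallocated!(stack::Vector{Int}, N::Int)
    resize!(stack, N)
    @inbounds for n in 1:N
        stack[n] = n
    end
    return stack
end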

@yebai (Contributor) commented Jul 1, 2024

Good investigations; it's probably okay to keep this issue open instead of transferring discussions here into docs.
