Gradients Through Broadcasts Slow Due to Type Inference Failing? #885

Open
farr opened this issue Jan 18, 2021 · 12 comments
Comments

@farr

farr commented Jan 18, 2021

I know this issue has been brought up before, but Zygote is really slow to take gradients through broadcasting, and I have attached a simple example that suggests (at least to my non-expert eye) that the issue may be type inference failing in the adjoint for broadcasting. Why it is failing, I have no idea, but perhaps with this simple test case and info an expert will have an idea?

Basically, the attached code is a pared-down bit of a model I was trying to fit in Turing.jl. There are two data variables, x and y, which are being fit to a broken power law (for x) and a power law (for y). There are three parameters (two slopes for the broken power law, one for the plain power law). I don't think there is anything special about the functional forms here, just that they are complex enough that Zygote doesn't have any special-case code for handling the broadcasted forms when they are applied across arrays. The product of the (log density of the) broken power law and the power law is evaluated on many samples stored in a 2D array, averaged along one axis, and then summed over the other. This is a common pattern in "meta-analyses": samples drawn from distributions produced by several analyses of independent data sets are first averaged (approximating an integral over the uncertainty within each analysis), and then the average probabilities are multiplied (accumulating the information from the multiple independent data sets).
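In condensed form (a paraphrase of the attached code, using the log_dNdx and log_dNdy names defined there, with xs and ys the 2D sample arrays), the likelihood looks like this:

# xs, ys: 2D arrays of samples; a, b, c: the three scalar parameters
function logl(a, b, c)
    log_dN = log_dNdx.(xs, a, b) .+ log_dNdy.(ys, c)  # broadcast the scalar log-densities
    sum(logsumexp(log_dN, dims=2))                    # reduce over samples, then over data sets
end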

The upshot is that the log-probability function on floats is simple enough that Julia has no problem inferring Float64 through the broadcast, and even with hundreds of thousands of x-y samples it completes in milliseconds:

@benchmark logl(0.0, 0.0, 0.0)
# BenchmarkTools.Trial:
#   memory estimate:  1.00 MiB
#   allocs estimate:  9
#   --------------
#   minimum time:     10.252 ms (0.00% GC)
#   median time:      11.140 ms (0.00% GC)
#   mean time:        11.765 ms (1.21% GC)
#   maximum time:     69.750 ms (0.00% GC)
#   --------------
#   samples:          425
#   evals/sample:     1

But the gradient is roughly 400x slower:

grad_logl = (a,b,c) -> gradient(logl, a, b, c)
@benchmark grad_logl(0.0, 0.0, 0.0)
# BenchmarkTools.Trial:
#   memory estimate:  800.01 MiB
#   allocs estimate:  23068864
#   --------------
#   minimum time:     3.803 s (15.01% GC)
#   median time:      4.062 s (20.10% GC)
#   mean time:        4.062 s (20.10% GC)
#   maximum time:     4.322 s (24.59% GC)
#   --------------
#   samples:          2
#   evals/sample:     1

The full code, including the environment, is attached. Looking at the gradient computation using the Juno profiler, I see lots of dynamic dispatch calls to the inner computations in the log-likelihood---which is reflected in the huge amount of allocation. But it should be obvious to the compiler that the adjoint of these R->R functions is also R->R in the "inner loop" of the broadcast; somehow this is getting obscured by the generic adjoint code in Zygote.

Hopefully this can help narrow down the poor performance on broadcast, and some wizard watching has a good idea for how to get around it. I currently do a lot of few-tens-to-100-parameter model fits in Stan using this sort of code that I would love to move to Turing.jl, but it's too expensive to ForwardDiff them---Zygote should be the solution, but not until I can broadcast without a three-order-of-magnitude slowdown....
BroadcastExample.tar.gz

@DhairyaLGandhi
Member

Thanks for bringing this up; I am taking a look now. It seems the code also makes heavy use of globals, which may add to the slowness. The benchmark does a single evaluation, so I am wondering whether it is counting compilation cost. Either way, I am going to try it out and figure out what is happening here.

@DhairyaLGandhi
Member

Found it; it seems like all the time is spent in log_logistic_unit.

@CarloLucibello
Member

I didn't look at the code, but broadcasting being slow is a general issue in Zygote. You can try to speed it up by defining a diff rule yourself for the broadcasted function, similarly to what is done for NNlib's activations at the end of this file: https://github.com/FluxML/NNlib.jl/blob/master/src/activations.jl
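For illustration only (not from the original comment), a minimal sketch of such a rule for a stand-in sigmoid function f, following the same pattern as the NNlib activation rules:

using Zygote
using Base.Broadcast: broadcasted

f(x) = 1 / (1 + exp(-x))   # stand-in elemental function (a logistic sigmoid)

# Teach Zygote to differentiate the broadcasted call as one array operation;
# the leading `nothing` is the cotangent for the function argument itself.
Zygote.@adjoint function broadcasted(::typeof(f), x::AbstractArray)
    y = f.(x)
    y, Δ -> (nothing, Δ .* y .* (1 .- y))   # d/dx f(x) = f(x) * (1 - f(x))
end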

@DhairyaLGandhi
Member

But not to this extent; this is effectively scalar code that is being produced.

@farr
Author

farr commented Jan 18, 2021

Yeah, I know that much of the time is spent in log_logistic_unit---but why? In my profiling, it seems that the comparison (x > zero(x)) and the math in log_dNdx are being called via generic dispatch when taking gradients; in the evaluation of the primal, the compiler can specialize everything to floats, but somehow in the evaluation of the gradient some pieces of the code become generic in a surprising way.

Re: @CarloLucibello's suggestion---I suppose I could, but I would have to first define the combined function that is getting broadcasted, and give it a name, right? That is, I can't write

log_dN = @. (log_dNdx(xs, a, b) + log_dNdy(ys, c))

but instead have to define something like

function log_dN_full(xs, ys, a, b, c)
    log_dNdx(xs, a, b) + log_dNdy(ys, c)
end

log_dN = log_dN_full.(xs, ys, a, b, c)

and then write for myself

@adjoint broadcasted(::typeof(log_dN_full), xs, ys, a, b, c) = log_dN_full.(xs, ys, a, b, c), delta -> (nothing, ..., ..., ..., ..., ...)

right? That pretty much defeats the entire point of AD for my functions (which are considerably more complex than what I'm giving in this simple example); if I wanted to write gradients of my log-likelihoods by hand, I would....

It still feels weird to me that some of the calls inside the adjoint that should easily be inferred to be R->R are instead getting generic dispatch....

@DhairyaLGandhi
Member

As I alluded to before, the issue has little to do with broadcasting being slower than we'd like, and more to do with how broadcasting with fusion leads to scalar code being generated. We still need to handle the case where scalar code is used in a much more robust way, but Zygote.Forward has most of the rules you'd need for this.

As an experiment, I took the liberty of writing some non-scalar methods, and as expected, the performance was recouped, running grad_logl in about 100ms on a pretty old CPU.

Notice I have commented out the adjoints; those add some nice memory characteristics more than anything, but aren't necessary for the program to run at all.

using StatsFuns, Random, BenchmarkTools
using Zygote

using Base.Broadcast
using Base.Broadcast: broadcasted

# Zygote.@adjoint function broadcasted(::typeof(exp), x)
#   y = exp.(x)
#   y, Δ -> (nothing, Δ .* y,)
# end
#
# Zygote.@adjoint function broadcasted(::typeof(log1p), xs)
#   log1p.(xs), Δ -> (nothing, Δ ./ (1 .+ xs), )
# end

function log_logistic_unit(x)
    """The log of a sigmoid function that goes from `-1` at `x = -Inf` to `+1`
    at `x = Inf` with unit slope at `x = 0`. """
    if x > zero(x)
        -log1p(exp(-4*x))
    else
        4*x - log1p(exp(4*x))
    end
end

# Array method: expressed as whole-array broadcasts of elemental functions,
# so Zygote can use its existing array rules instead of the generic scalar path.
function log_logistic_unit(x::AbstractArray)
  mask = x .> zero(x)
  -log1p.(exp.(-4 .* x)) .* mask .+ .!(mask) .* (4 .* x - log1p.(exp.(4 .* x)))
end

function log_dNdx(x, a, b)
    """The log of a broken power law, with (logarithmic) slope `-a` for `x << 1`
    and `-b` for `x >> 1`.  Normalized to `log(1)` when `x=1`"""
    xbreak = one(x)
    breakscale = xbreak / 10
    y = (x - xbreak) / breakscale

    return logaddexp(log_logistic_unit(-y) - a*log(x), log_logistic_unit(y) - b*log(x))
end

function log_dNdx(x::AbstractArray, a, b)
    xbreak = one.(x)
    breakscale = xbreak ./ 10
    y = (x .- xbreak) ./ breakscale

    return logaddexp.(log_logistic_unit(-y) .- a * log.(x), log_logistic_unit(y) .- b * log.(x))
end

function log_dNdy(y, c)
    c .* log.(y)
end

Random.seed!(969886409174182839)
# try
#     global xs = 10.0.*rand(128, 1024)
#     global ys = rand(128, 1024)
# finally
#     Random.seed!()
# end

function make_loglikelihood(xs, ys)
    function logl(a, b, c)
        log_dN = log_dNdx(xs, a, b) .+ log_dNdy(ys, c)
        sum(logsumexp(log_dN, dims=2))
    end
    logl
end

function main()
  xs = 10.0 .* rand(128, 1024)
  ys = rand(128, 1024)
  logl = make_loglikelihood(xs, ys)
  grad_logl(a,b,c) = gradient(logl, a, b, c)
  @btime $grad_logl(0.0, 0.0, 0.0)
end
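
For reference, a usage sketch (assuming the definitions above; bench is just an illustrative name) that times the primal alongside the gradient:

function bench()
    xs = 10.0 .* rand(128, 1024)
    ys = rand(128, 1024)
    logl = make_loglikelihood(xs, ys)
    @btime $logl(0.0, 0.0, 0.0)            # primal evaluation
    @btime gradient($logl, 0.0, 0.0, 0.0)  # reverse-mode gradient
end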

@DhairyaLGandhi
Member

DhairyaLGandhi commented Jan 19, 2021

it seems that the comparison (x > zero(x)) and the math in log_dNdx are being called via generic dispatch when taking gradients

This is somewhat worrying; how are you doing the profiling? It may be that we want to interleave some inference passes, but it's best to avoid that by producing the correct code while fusing in the first place.

@farr
Author

farr commented Jan 19, 2021

Wow! That's amazing---and the performance looks great. Thank you!

Just to make sure I understand: you defined array versions of log_logistic_unit and log_dNdx to sort of "break up" the broadcasting with "intermediate" arrays. These "split" the completely-broadcasted version that I originally wrote, (I guess?) help with the fusion, and (I guess?) provide some intermediate array storage for the backpropagation. The only calls to broadcasted(::typeof(some_function), ...) are then to "elemental" functions for which Zygote already knows the adjoints. Do I have it right?

I agree that perhaps this points to a problem with fusion for the adjoints.

I was inferring that the calls were generic dispatch by using the Juno profiler, but I think I made a mistake; I'm sorta new to Juno. In my original code the call to

Juno.@profiler grad_logl(0.0, 0.0, 0.0)

generates something like the attached screenshot. The width of the highlighting of a line of code is proportional to the runtime with that line active on the stack (as is the width of the waterfall bars in the plot on the right); the color indicates whether the call on the stack used concrete (is that the word?; blue) or generic dispatch (yellow), and whether there was a garbage collection on the line (red). I think I was mis-reading the profile output (apologies!)---it looks like there were lots of GCs deep inside my code, but not generic dispatch (except for a couple of high-level functions associated with setting up the adjoints, which presumably account for a very tiny fraction of the total runtime). It seems like a really neat tool; see http://docs.junolab.org/stable/man/juno_frontend/#Profiler-1 .

[Screenshot: Juno profiler output for grad_logl]

@DhairyaLGandhi
Member

So we aren't breaking up broadcasting; rather, we're defining the rules as array operations instead of scalar operations (notice the functions all broadcast internally as well, as one would expect).

Zygote is pretty efficient on array code, and Zygote.Forward on scalar code.

You could optimise the code I wrote by avoiding the copies made when multiplying by 4, etc. Notice that no adjoints are necessary, because once you have array math in there, reverse-mode AD is efficient.

@renatobellotti

I am interested in these results. Have they been incorporated into Zygote?

@ToucheSir
Member

What exactly were you hoping to be incorporated? There have been a number of changes to broadcasting since this issue was posted, e.g. #1001. Running the original example on Julia 1.8.2:

julia> @benchmark grad_logl(0.0, 0.0, 0.0)
BenchmarkTools.Trial: 282 samples with 1 evaluation.
 Range (min … max):  17.064 ms …  20.028 ms  ┊ GC (min … max): 0.00% … 4.49%
 Time  (median):     17.236 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   17.772 ms ± 950.150 μs  ┊ GC (mean ± σ):  1.92% ± 2.67%

  ██▂ ▃▅                   ▁▅▁                             ▅▃   
  ███▄██▇▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▇███▆▄▁▄▄▄▁▁▄▆▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▁▄▁▆██▆ ▆
  17.1 ms       Histogram: log(frequency) by time      19.8 ms <

 Memory estimate: 16.01 MiB, allocs estimate: 88.

@renatobellotti

Oh, these results look very nice. I didn't have specific requirements; I'm currently profiling some code and thought broadcasting might play a role, but since this has been fixed, it is not related to my slow code. Thanks anyway for posting the update!
