Gradients Through Broadcasts Slow Due to Type Inference Failing? #885
Thanks for bringing this up. I am taking a look now, but it seems that the code is also making a lot of use of globals; maybe that also adds to the slowness. The benchmark does a single evaluation, so I am wondering if it's counting the compilation cost. Either way, I am going to try it out and figure out what is happening here.
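(A rough sketch of how to keep both of those concerns, non-const globals and compilation cost, out of a measurement; `logl` and the data here are stand-ins of my own, not the attached code:)

```julia
using BenchmarkTools, Zygote

const xs = 10.0 .* rand(128, 1024)   # `const` (or interpolation with $) avoids untyped-global overhead
const ys = rand(128, 1024)

# Stand-in for the issue's log-likelihood; the real one is in the attachment.
logl(a, b, c) = sum(a .* log.(xs) .+ b .* xs .+ c .* log.(ys))
grad_logl(a, b, c) = gradient(logl, a, b, c)

# @btime runs the call many times after a warm-up, so compilation cost
# is excluded from the reported time.
@btime grad_logl(0.0, 0.0, 0.0)
```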
Found it, seems like all the time is spent in
I didn't look at the code, but it is a general thing that broadcasting is slow in Zygote. You can try to speed it up by defining a diff rule yourself for the broadcasted function, similarly to what is done for NNlib's activations at the end of this file: https://github.com/FluxML/NNlib.jl/blob/master/src/activations.jl
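For reference, such a rule targets the `broadcasted` call itself. A minimal sketch of the pattern, using a made-up scalar function `f` with a hand-written derivative `df` (this is not the actual NNlib code):

```julia
using Zygote
using Base.Broadcast: broadcasted

f(x) = x * tanh(x)                     # made-up scalar function
df(x) = tanh(x) + x * (1 - tanh(x)^2)  # its hand-written derivative

# Custom adjoint for the broadcasted call, so Zygote applies one array rule
# instead of differentiating the scalar kernel element by element.
Zygote.@adjoint function broadcasted(::typeof(f), x::AbstractArray)
    f.(x), Δ -> (nothing, Δ .* df.(x))
end
```

With that in place, something like `gradient(x -> sum(f.(x)), randn(100))` goes through the array rule rather than the generic broadcast machinery.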
But not to this extent; this is effectively scalar code that is being produced.
Yeah, I know that much of the time is spent there.

Re: @CarloLucibello's suggestion: I suppose I could, but I would have to first define the combined function that is getting broadcasted and give it a name, right? That is, I can't write

```julia
log_dN = @. (log_dNdx(xs, a, b) + log_dNdy(ys, c))
```

but instead have to define something like

```julia
function log_dN_full(xs, ys, a, b, c)
    log_dNdx(xs, a, b) + log_dNdy(ys, c)
end

log_dN = log_dN_full.(xs, ys, a, b, c)
```

and then write for myself

```julia
@adjoint broadcasted(::typeof(log_dN_full), xs, ys, a, b, c) =
    log_dN_full.(xs, ys, a, b, c), delta -> (..., ..., ...)
```

right? That pretty much defeats the entire point of AD for my functions (which are considerably more complex than what I'm giving in this simple example); if I wanted to write gradients of my log-likelihoods by hand, I would....

It still feels weird to me that some of the calls inside the adjoint that should easily be inferred to be R->R are instead getting generic dispatch....
As I alluded to before, the issue has little to do with broadcasting being slower than we'd like, and more to do with how broadcasting, when applied with fusion, leads to scalar code being generated. We still need to do better, in a much more robust way, for the case where scalar code has to be used, but Zygote.Forward has most of the rules you'd need for this.

As an experiment, I took the liberty of writing some non-scalar methods, and as expected the performance was recouped when running the example. Notice I have commented out the adjoints; those add some nice memory characteristics more than anything, but aren't necessary for the program to run at all.

```julia
using StatsFuns, Random, BenchmarkTools
using Zygote
using Base.Broadcast
using Base.Broadcast: broadcasted
# Zygote.@adjoint function broadcasted(::typeof(exp), x)
# y = exp.(x)
# y, Δ -> (nothing, Δ .* y,)
# end
#
# Zygote.@adjoint function broadcasted(::typeof(log1p), xs)
# log1p.(xs), Δ -> (nothing, Δ ./ (1 .+ xs), )
# end

"""The log of a sigmoid function that goes from `-1` at `x = -Inf` to `+1`
at `x = Inf` with unit slope at `x = 0`."""
function log_logistic_unit(x)
if x > zero(x)
-log1p(exp(-4*x))
else
4*x - log1p(exp(4*x))
end
end
function log_logistic_unit(x::AbstractArray)
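# Evaluate both branches over the whole array and blend them with the mask:
# the method stays a handful of whole-array operations (which reverse mode
# handles well) instead of an elementwise branch.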
mask = x .> zero(x)
-log1p.(exp.(-4 .* x)) .* mask .+ .!(mask) .* (4 .* x - log1p.(exp.(4 .* x)))
end

"""The log of a broken power law, with (logarithmic) slope `-a` for `x << 1`
and `-b` for `x >> 1`. Normalized to `log(1)` when `x = 1`."""
function log_dNdx(x, a, b)
xbreak = one(x)
breakscale = xbreak / 10
y = (x - xbreak) / breakscale
return logaddexp(log_logistic_unit(-y) - a*log(x), log_logistic_unit(y) - b*log(x))
end
function log_dNdx(x::AbstractArray, a, b)
xbreak = one.(x)
breakscale = xbreak ./ 10
y = (x .- xbreak) ./ breakscale
return logaddexp.(log_logistic_unit(-y) .- a * log.(x), log_logistic_unit(y) .- b * log.(x))
end
function log_dNdy(y, c)
c .* log.(y)
end
Random.seed!(969886409174182839)
# try
# global xs = 10.0.*rand(128, 1024)
# global ys = rand(128, 1024)
# finally
# Random.seed!()
# end
function make_loglikelihood(xs, ys)
function logl(a, b, c)
log_dN = log_dNdx(xs, a, b) .+ log_dNdy(ys, c)
sum(logsumexp(log_dN, dims=2))
end
logl
end
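# Note that make_loglikelihood closes over xs and ys, so the data never lives
# in non-const globals, which also addresses the globals concern raised above.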
function main()
xs = 10.0 .* rand(128, 1024)
ys = rand(128, 1024)
logl = make_loglikelihood(xs, ys)
grad_logl(a,b,c) = gradient(logl, a, b, c)
@btime $grad_logl(0.0, 0.0, 0.0)
end
```
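For anyone adapting this, a quick sanity check that the array methods agree with broadcasting the original scalar methods, and that the gradient runs through the array path, could look like the following (a sketch that assumes the definitions above are already loaded; the small sizes are arbitrary):

```julia
xs_small = 10.0 .* rand(8, 16)

# Array methods should match elementwise broadcasts of the scalar methods.
@assert log_logistic_unit(xs_small) ≈ log_logistic_unit.(xs_small)
@assert log_dNdx(xs_small, 0.5, 1.5) ≈ log_dNdx.(xs_small, 0.5, 1.5)

# No custom adjoints are needed for the gradient to go through the array path.
gradient(a -> sum(log_dNdx(xs_small, a, 1.5)), 0.5)
```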
This is somewhat worrying; how are you doing the profiling? It may be that we want to interleave some inference passes, but it's best to avoid that by producing the correct code while fusing in the first place.
Wow! That's amazing, and the performance looks great. Thank you!

Just to make sure I understand: you defined array versions of the scalar functions, so the work happens as whole-array operations? I agree that perhaps this points to a problem with fusion for the adjoints.

I was inferring that the calls were generic dispatch by using the Juno profiler, but I think I made a mistake; I'm sorta new to Juno. In my original code, the call to `Juno.@profiler grad_logl(0.0, 0.0, 0.0)` generates something like the attached screenshot. The width of the highlighting of the code is proportional to the runtime with that line active on the stack (also the width of the waterfall bars in the plot on the right); the color indicates whether the call is on the stack via a concrete (is that the word?; in blue) or generic dispatch (yellow), and also whether there was a garbage collection on the line (red). I think I was mis-reading the profile output (apologies!): it looks like there were lots of GCs deep inside my code, but not generic dispatch (except to a couple of high-level functions associated with setting up the adjoints, which presumably account for a very tiny fraction of the total runtime). Seems like a really neat tool; see http://docs.junolab.org/stable/man/juno_frontend/#Profiler-1 .
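(For what it's worth, a similar flat view of where the time goes is available outside Juno from the stdlib profiler; a small sketch with a stand-in function rather than the real gradient call:)

```julia
using Profile, Zygote

f(x) = sum(log1p.(exp.(x)))   # stand-in for the real log-likelihood
g(x) = gradient(f, x)

g(randn(128, 1024))           # warm up first so compilation isn't profiled
Profile.clear()
@profile g(randn(128, 1024))
Profile.print(format = :flat, sortedby = :count)
```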
So we aren't breaking up broadcasting; rather, the rules are defined as array operations instead of scalar operations (notice the functions all broadcast internally as well, as one would expect). Zygote is pretty efficient on array code, and Zygote.Forward on scalar code. You could optimise the code I wrote by avoiding the copies made by multiplying by 4, etc., and notice that no adjoints are necessary: once you have array math in there, reverse-mode AD is efficient.
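As an illustration of the copy-avoidance point (a sketch only, with a name of my own choosing, and not something measured here): the array method above materializes several intermediates; reusing one temporary and folding the rest into a single fused broadcast cuts the allocations.

```julia
function log_logistic_unit_fused(x::AbstractArray)
    x4 = 4 .* x   # one temporary, reused by both branches
    # One fused broadcast; `ifelse` keeps the branch elementwise without
    # materializing the mask or the per-branch arrays.
    @. ifelse(x > 0, -log1p(exp(-x4)), x4 - log1p(exp(x4)))
end
```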
I am interested in these results. Have they been incorporated into Zygote?
What exactly were you hoping to be incorporated? There have been a number of changes to broadcasting since this issue was posted, e.g. #1001. Running the original example on Julia 1.8.2:
Oh, these results look very nice. I didn't have specific requirements; I'm currently profiling some code and thought broadcasting might play a role, but since this has been fixed, it is not related to my slow code. Thanks anyway for posting the update!
I know this issue has been brought up before, but Zygote is really slow to take gradients through broadcasting, and I have attached a simple example that suggests (at least to my un-expert eye) that the issue may be due to type inference failing in the adjoint for broadcasting. Why it is failing, I have no idea, but perhaps with this simple test case and info an expert has an idea?
Basically, the attached code is a pared-down bit of a model I was trying to fit in Turing.jl. There are two data variables, x and y, which are being fit to a broken power law (for x) and a power law (for y). There are three parameters (two slopes for the broken power law, one for the plain power law). I don't think there is anything special about the functional forms here, just that they are complex enough that Zygote doesn't have any special case code for handling the broadcasted forms when they are applied across arrays. The product of the (log density of the) broken power law and the power law is evaluated on many samples stored in a 2D array and averaged along one axis and then summed over the other. This is a common pattern in "meta analyses" when samples drawn from distributions produced by several analyses of independent data sets are first averaged (approximating an integral over the uncertainty within each analysis) and then the average probabilities multiplied (accumulating the information from the multiple independent data sets).
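In log space that average-then-multiply pattern amounts to a logsumexp per analysis; a small sketch with made-up numbers (the log of a mean is a logsumexp minus the log of the sample count, and multiplying the averaged probabilities is a sum of their logs):

```julia
using StatsFuns

log_dN = randn(128, 1024)   # log_dN[i, j]: log density at sample j of analysis i

# Average over samples within each analysis, then "multiply" across analyses.
loglike = sum(logsumexp(log_dN, dims = 2) .- log(size(log_dN, 2)))
```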
The upshot is that the log-probability function on floats is simple enough that there is no problem for Julia to infer Float64 through the broadcast, and even with hundreds of thousands of x-y samples it completes in milliseconds:
But the gradient is 500x slower:
The full code, including environment, etc., is attached. Looking at the gradient computation using the Juno profiler, I see lots of dynamic dispatch calls to the inner computations in the log-likelihood, which is reflected in the huge amount of allocation. But it should be obvious to the compiler that the adjoint of these R->R functions is also R->R in the "inner loop" of the broadcast; somehow this is getting obscured by the generic adjoint code in Zygote.
Hopefully this can help narrow down the poor performance on broadcast, and some wizard watching has a good idea for how to get around it. I do a lot of few-tens-to-100-parameter model fits in Stan right now using this sort of code that I would love to move to Turing.jl, but it's too expensive to ForwardDiff them. Zygote should be the solution, but not until I can broadcast without a three-order-of-magnitude slowdown....
BroadcastExample.tar.gz