Slow backward pass when the forward pass touches a large array #323
Comments
I started to wonder if a problem like this is unavoidable by design. For example, consider the adjoint of `*` defined in Lines 207 to 209 in d74f3cf.
Wouldn't it be wasteful if (say) only one of the two derivatives is actually needed? Maybe it is better to use struct-of-pullbacks rather than pullback-of-struct? In the above case, it would be something like

@adjoint function(A::AbstractMatrix * B::AbstractMatrix)
    return A * B, (Δ -> Δ * B', Δ -> A' * Δ)
end

instead. It'd be a bit hard when it turns out that you need to compute the derivatives w.r.t. all arguments in the end, and if some of the computation for them can be shared (see also JuliaDiff/ChainRulesCore.jl#8). But I think that can be solved relatively easily by sharing state between the pullbacks. Does my argument make sense? Or is there already a similar facility for not computing irrelevant intermediate derivatives?
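As a concrete illustration of the struct-of-pullbacks idea, here is a minimal sketch with made-up names (not an actual Zygote or ChainRules API): the primal is returned together with one pullback per argument, and no adjoint work happens until a particular pullback is called.

```julia
function mul_pullbacks(A::AbstractMatrix, B::AbstractMatrix)
    y = A * B
    pb_A = Δ -> Δ * B'   # derivative w.r.t. A, only computed on demand
    pb_B = Δ -> A' * Δ   # derivative w.r.t. B, only computed on demand
    return y, (pb_A, pb_B)
end

A, B = rand(3, 4), rand(4, 5)
y, (pb_A, pb_B) = mul_pullbacks(A, B)
Δ = ones(size(y))
dA = pb_A(Δ)   # pb_B is never called, so A' * Δ is never evaluated
```

Shared intermediate state could live in the enclosing closure, so that work common to several pullbacks is done at most once.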
I think the issue is to do with not representing the sparsity in the adjoint properly. We could use a 1-hot-like vector (or a variant thereof where the non-zero element is allowed to have any value) for the adjoint of `getindex`.

Separately, we have thunks to avoid actually performing certain bits of computation that we don't want, e.g. in the matrix multiplication example we'll have a pullback that, when invoked, returns two thunks.
Yeah, so if you look at the adjoint for `getindex`, you'll see that we're literally returning a dense array with only one non-zero element in the case that the forwards pass pulls out a single element of the array. This is a prime target for optimisation, and it would be great to see a PR to sort this out, at least in the single-element-index case to start with.
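A rough sketch of the kind of 1-hot-like adjoint type being discussed; the name and details here are assumptions for illustration, not Zygote's implementation:

```julia
# Lazy "one value at one index, zero elsewhere" vector: O(1) to construct,
# instead of materialising zeros(length(x)) with a single entry set.
struct OneHotLike{T} <: AbstractVector{T}
    len::Int
    ind::Int
    val::T
end

Base.size(v::OneHotLike) = (v.len,)
Base.getindex(v::OneHotLike, i::Int) = i == v.ind ? v.val : zero(v.val)

# The pullback of x[i] for a length-n array could then return
# OneHotLike(n, i, Δ) instead of an n-element dense array.
g = OneHotLike(10^5, 1, 2.5)
g[1], g[2]   # (2.5, 0.0)
```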
@willtebbutt Thanks a lot for clarifying the current situation around these problems (esp. what is planned around ChainRules and thunks). But it's still not clear to me whether everything in the OP can be explained by the `getindex` adjoint. What I forgot to emphasize was that the slowdown also occurs when no `getindex` is involved at all (with `using BenchmarkTools` and `using Zygote: forward`):

julia> function bench(n)
c = ones(n)
proj = zeros(n)
proj[1] = 1
_, back = forward(p -> (c'proj) * p, 1)
@benchmark $back(1)
end
bench (generic function with 1 method)
julia> bench(10 ^ 3) |> display
bench(10 ^ 4) |> display
bench(10 ^ 5) |> display
BenchmarkTools.Trial:
memory estimate: 15.88 KiB
allocs estimate: 2
--------------
minimum time: 1.045 μs (0.00% GC)
median time: 1.712 μs (0.00% GC)
mean time: 2.718 μs (17.00% GC)
maximum time: 286.861 μs (98.12% GC)
--------------
samples: 10000
evals/sample: 10
BenchmarkTools.Trial:
memory estimate: 156.41 KiB
allocs estimate: 4
--------------
minimum time: 10.380 μs (0.00% GC)
median time: 12.935 μs (0.00% GC)
mean time: 16.034 μs (12.74% GC)
maximum time: 1.029 ms (93.46% GC)
--------------
samples: 10000
evals/sample: 1
BenchmarkTools.Trial:
memory estimate: 1.53 MiB
allocs estimate: 4
--------------
minimum time: 110.699 μs (0.00% GC)
median time: 129.600 μs (0.00% GC)
mean time: 162.228 μs (14.29% GC)
maximum time: 1.541 ms (55.01% GC)
--------------
samples: 10000
evals/sample: 1
Ah, I see what you're saying. My immediate response is: do we care about optimising for this kind of situation? This seems like kind of a hard problem to optimise away, as, from the perspective of Zygote (which only really gets to reason about what goes on inside the closure it is given), `c` looks just like any other variable that might need a gradient. That said, were you to move the computation of `c'proj` outside of the closure, the backward pass would no longer need to touch the large arrays at all.

I guess my opinion on the matter is that the asymptotic complexity of the reverse pass should generally be the same as that of the forwards pass, and this follows immediately from each pullback doing an amount of work comparable to the forward operation it corresponds to.
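For reference, a minimal sketch of the hoisting idea under the same `bench` setup as above (the names are carried over from that snippet): if the constant scalar `c'proj` is computed once outside the closure handed to `forward`, the backward pass never touches the length-`n` arrays.

```julia
using BenchmarkTools
using Zygote: forward   # the (old) Zygote API used elsewhere in this thread

function bench_hoisted(n)
    c = ones(n)
    proj = zeros(n); proj[1] = 1
    s = c'proj                        # constant w.r.t. p, computed up front
    _, back = forward(p -> s * p, 1)  # the closure only captures a scalar
    @benchmark $back(1)               # timing is now independent of n
end
```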
I think there are other important situations where the interactions with constant variables are not sparse. For example, the input/output data for neural nets and the random variables for dropout and GANs are all constants from the point of view of auto-differentiation. They interact densely with the variables that depend on the "input" with respect to which derivatives are calculated. Furthermore, the sizes of those constants are comparable to those of the variables that depend on the "inputs", especially, I think, in the context of "scientific AI", where the models are not as big as those huge deep neural nets. See also @antoine-levitt's real-application example where this matters: JuliaNLSolvers/NLsolve.jl#205 (comment).

It looks to me that those cases need a struct-of-pullbacks approach (or, equivalently, a thunk-based approach?) to defer the computation until it is needed. (Alternatively, I suppose the AD engine could somehow "mark" the values that depend on the "input" during the forward pass so that the adjoint function can know with respect to which arguments it has to take derivatives. But this sounds more complicated than struct-of-pullbacks to me.)

Of course, the examples I noted are very basic in a machine-learning setup, so it is very possible that I am just ignorant about existing solutions to them. But I thought there is some non-zero chance that the design of Zygote has not been fully reviewed in this respect because it is not yet the main auto-differentiation engine of Flux.
Now that I mention GANs, another example is taking derivatives with respect to the generator only. It is wasteful to take derivatives w.r.t. the discriminator parameters in this case.
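A toy sketch of that situation; the tiny "networks" and all names here are made up purely for illustration. Only the generator's parameters are requested, yet the eager pullback of `*` still computes an adjoint for the discriminator's weight matrix.

```julia
using Zygote
using Zygote: Params

# Toy "generator" and "discriminator" parameters.
W_g, b_g = randn(4, 2), randn(4)
W_d, b_d = randn(1, 4), randn(1)

gen(z)  = tanh.(W_g * z .+ b_g)
disc(x) = sum(W_d * x .+ b_d)

z  = randn(2)
gs = gradient(Params([W_g, b_g])) do
    disc(gen(z))            # only generator gradients are requested...
end
gs[W_g]                     # ...but Δ * x' for W_d is still evaluated eagerly
                            # inside the pullback of * on the reverse pass.
```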
Hmm, I see your point. In a tape-based system you get to know which variables are involved in AD, and which aren't, for free. In Zygote's world you don't, though.
Does Zygote know internally which arguments of a function need to be tracked back to the input?
@willtebbutt In JuliaDiff/ChainRulesCore.jl#30, @oxinabox explained to me that that's what thunks are for and how the rules are implemented using them. You've already mentioned thunks, but I guess I didn't fully get it while writing the last comment. Now I'm convinced that thunks solve the issue I brought up.

@mcabbott I was thinking of a similar solution too. But I started to think that using thunks can solve most of the problems.

Closing, since #291 will take care of this issue.
I see, thanks for the link. Maybe I finally understand what a thunk is; it's simpler than I imagined.
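In case it helps other readers, a minimal sketch of the idea; this is a toy definition, not ChainRulesCore's actual `Thunk` type. A thunk simply wraps a computation so that it only runs if somebody asks for the result.

```julia
struct Thunk{F}
    f::F
end
unthunk(t::Thunk) = t.f()   # force the deferred computation

# A pullback for A * B can hand back two thunks; only the one that is
# actually unthunk'ed pays for its matrix multiplication.
mul_pullback(A, B) = Δ -> (Thunk(() -> Δ * B'), Thunk(() -> A' * Δ))

A, B = rand(3, 4), rand(4, 5)
dA_thunk, dB_thunk = mul_pullback(A, B)(ones(3, 5))
dA = unthunk(dA_thunk)      # Δ * B' runs here; A' * Δ never runs
```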
As per my previous comment, I'm not sure that they really do unless they are used in conjunction with more general knowledge about whether or not parents (in the computational-graph sense) need to compute their own adjoints.
So my understanding is that code like this

A = rand(10, 10)
x = randn(10)
y, back = forward(x -> A * x.^2, x)
back(ones(10))

is equivalent to

x1 = x.^2
back11 = Δ -> 2 .* x .* Δ
back12 = Δ -> x1 .* log.(x) .* Δ
back = back11 # derivative is taken w.r.t x
y = A * x1
back21 = Δ -> Δ * x1'
back22 = Δ -> A' * Δ
back = back ∘ back22 # derivative is taken w.r.t x1
back(ones(10))

when the thunks are used (i.e., `back12` and `back21` are never invoked). Since the final `back` only composes `back11` and `back22`, the derivatives with respect to `A` and the exponent are never computed. Admittedly, I've never checked what the closures Zygote actually generates look like.
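A quick check of that mental model (a sketch using the old `forward` API from above): composing only `back11` and `back22` reproduces Zygote's derivative with respect to `x`.

```julia
using Zygote: forward

A  = rand(10, 10)
x  = rand(10)
x1 = x .^ 2

back11 = Δ -> 2 .* x .* Δ       # pullback of x.^2 w.r.t. x
back22 = Δ -> A' * Δ            # pullback of A * x1 w.r.t. x1
manual = (back11 ∘ back22)(ones(10))

_, back = forward(x -> A * x .^ 2, x)
zygote = back(ones(10))[1]

manual ≈ zygote                 # true
```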
FYI, I needed an immediate solution, so I created https://github.com/tkf/ChainCutters.jl, which implements the constant annotations (`cut`/`uncut`) used below.
A bit more explanation: there are two ingredients for making the adjoint of broadcasting fast. The first ingredient is a ForwardDiff-based adjoint of broadcasting, based on the code in Zygote used for CuArrays. However, since this effectively makes differentiation "eager" for all arguments, it is not an ideal approach if the arity of the broadcasted function is large compared to the number of variables being differentiated. So the second ingredient is to do forward differentiation only with respect to the non-constant variables. This is currently only implemented for broadcasting.

To check the efficiency of the approach I took, I ran a benchmark with the three equivalent functions below, each with a different implementation.

f_man(p, x) = function(c)
@unpack c0, c1, c3, c4, c5, c6, c7, c8, c9 = p
c2 = c
y = @. c0 + c1 * x +
c2 * x^2 +
c3 * x^3 +
c4 * x^4 +
c5 * x^5 +
c6 * x^6 +
c7 * x^7 +
c8 * x^8 +
c9 * x^9
return sum(y)
end
f_nocut(p, x) = function(c)
q = @set p.c2 = c
q :: Poly9 # FYI
sum(q.(x))
end
f_cut(p, x) = function(c)
q = cut(@set p.c2 = uncut(c))
sum(q.(cut(x)))
end
xs = rand(1000)
p = Poly9(rand(10)...)
suite["f_cut"] = @benchmarkable Zygote.gradient($(f_cut(p, xs)), 1.0)
suite["f_nocut"] = @benchmarkable Zygote.gradient($(f_nocut(p, xs)), 1.0)
suite["f_man"] = @benchmarkable Zygote.gradient($(f_man(p, xs)), 1.0) Here Here is the result: 3-element BenchmarkTools.BenchmarkGroup:
tags: []
"f_man" => BenchmarkTools.Trial:
memory estimate: 1.44 MiB
allocs estimate: 41179
--------------
minimum time: 1.466 ms (0.00% GC)
median time: 1.627 ms (0.00% GC)
mean time: 2.025 ms (10.43% GC)
maximum time: 456.652 ms (1.78% GC)
--------------
samples: 2466
evals/sample: 1
"f_nocut" => BenchmarkTools.Trial:
memory estimate: 64.36 MiB
allocs estimate: 1356032
--------------
minimum time: 343.052 ms (13.19% GC)
median time: 346.456 ms (13.60% GC)
mean time: 750.353 ms (10.50% GC)
maximum time: 3.167 s (8.84% GC)
--------------
samples: 7
evals/sample: 1
"f_cut" => BenchmarkTools.Trial:
memory estimate: 37.23 KiB
allocs estimate: 176
--------------
minimum time: 114.540 μs (0.00% GC)
median time: 121.496 μs (0.00% GC)
mean time: 152.005 μs (4.88% GC)
maximum time: 197.495 ms (4.08% GC)
--------------
samples: 10000
evals/sample: 1

As you can see, constant annotation can make differentiation 3000x faster than the code without annotation, and even 12x faster than the "manually expanded" code. I'd imagine Zygote.jl could automatically insert some kind of constant annotation during the forward pass (something equivalent to `cut`).
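For completeness, a hedged sketch of the setup the benchmark above assumes; `Poly9` and `suite` are not shown in the thread, so this is one plausible reconstruction rather than the actual code. `@unpack` would come from UnPack.jl/Parameters.jl, `@set` from Setfield.jl, and `cut`/`uncut` from ChainCutters.jl.

```julia
using BenchmarkTools

# A callable 9th-degree polynomial with fields c0, ..., c9, so that q.(x)
# evaluates it elementwise, matching the usage in f_nocut and f_cut.
struct Poly9{T}
    c0::T; c1::T; c2::T; c3::T; c4::T
    c5::T; c6::T; c7::T; c8::T; c9::T
end

(p::Poly9)(x) = p.c0 + p.c1*x + p.c2*x^2 + p.c3*x^3 + p.c4*x^4 +
                p.c5*x^5 + p.c6*x^6 + p.c7*x^7 + p.c8*x^8 + p.c9*x^9

suite = BenchmarkGroup()   # the three @benchmarkable entries above go in here
```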
Here is a MWE:
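A reconstruction based on the benchmarks discussed in this thread rather than the original snippet, so treat the exact code as an assumption; the key point is that the forward pass reads a single element of a large constant array `c`:

```julia
using BenchmarkTools
using Zygote: forward

c = ones(10^5)                        # large array; a constant w.r.t. p
_, back = forward(p -> c[1] * p, 1)   # the forward pass only reads c[1]
@benchmark $back(1)                   # yet back's cost scales with length(c)
```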
I see a similar effect with an alternative implementation:
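Again a reconstruction rather than the original code; this matches the `c'proj` form benchmarked in `bench(n)` earlier in the thread:

```julia
using BenchmarkTools
using Zygote: forward

c = ones(10^5)
proj = zeros(10^5); proj[1] = 1            # one-hot "projection" vector
_, back2 = forward(p -> (c'proj) * p, 1)
@benchmark $back2(1)                       # still scales with length(c)
```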
As you can see, the computation time of `back` grows as `length(c)` grows, even though the majority of `c` does not participate in the computation. Is it possible to avoid this problem?