ensure `sum(f, x)` works on GPU #1004
Conversation
```julia
# Not the ChainRules.rrule, which would use the Zygote.Context and thus not be GPU compatible
@adjoint function sum(f, xs::CUDA.CuArray; kws...)
    @assert !haskey(kws, :init) # TODO add init support (julia 1.6)
    return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
end
```
You probably don't need the context manager here
This is exactly the old code.
I don't really know what it is doing passing the context.
I think keeping it is better.
It should end up hitting the same code either way, but I can imagine that if some specializations are added, the errors we would get would be more confusing if we don't include the `__context__`.
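(A small aside, my example rather than anything in the PR: calling the public `Zygote.pullback` directly shows what the inner call returns, with `nothing` as the tangent for the fieldless function argument.)

```julia
using Zygote

# Primal value plus a closure mapping output cotangents to input cotangents:
y, back = Zygote.pullback((f, xs) -> sum(f.(xs)), abs, [-2.0, 3.0])
back(1.0)  # (nothing, [-1.0, 1.0])
```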
Might be slightly tidier to call `broadcast_forward` directly, instead of going back to `pullback`. But this is a quick fix to repair an accidental break, so the most direct restoration is fine too.
```julia
using Zygote, CUDA, Test

a = Float32.([-1.5, -9.0, 2.4, -1.3, 0.01])
a_gpu = a |> cu
f(x) = sum(abs, x)

g = gradient(f, a)[1]
g_gpu = gradient(f, a_gpu)[1]

@test g_gpu isa CuArray
@test g_gpu |> collect ≈ g
```

Does recreate the bug on release.
Instead of going via broadcasting, I think the ideal thing here would be something like this:

```julia
using Zygote
using ForwardDiff: Dual, partials

_sum(f, A; kw...) = sum(f, A; kw...)

Zygote.@adjoint function _sum(f, A::AbstractArray; kw...)
    sum(f, A; kw...), dY -> (nothing, broadcast(dY, A) do dy, a
        dy * partials(f(Dual(a, true)), 1)
    end)
end
```

I haven't tested this beyond this example, but it does seem to work, and cut memory use by a factor of 4. In fact, we should do this on the CPU too, for arrays of real numbers, once it has been tested a bit more. But this PR shouldn't wait for such things.
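(As a quick illustration of the trick above, not from the PR itself: seeding a `Dual` with a unit partial extracts the derivative of a scalar `f` in a single forward evaluation.)

```julia
using ForwardDiff: Dual, partials

f(x) = abs(x) + x^2
a = -1.3f0

# One forward-mode pass through f, seeded with a unit perturbation:
df = partials(f(Dual(a, true)), 1)
df ≈ sign(a) + 2a  # true: matches the hand-computed derivative
```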
bors r+

Build succeeded.

Yes.
Oh nice, I didn't know about ForwardDiffPullbacks.jl; that looks like the high-tech version. Is this going to be hooked up soon? I was picturing a 2-line Zygote-style solution, possibly with or after #1001. (In fact, what #1001 does should perhaps ultimately be routed to more high-tech stuff too...) For my …

Not by me.
Seems like there are some pains from #990 re: GPU.
In particular, we broke DiffEqSensitivity:
https://buildkite.com/julialang/diffeqsensitivity-dot-jl/builds/169#d254017e-e824-4d9c-854d-f3b348395599/411-877
@ChrisRackauckas's "M"WE is:
I am hoping we can get it to fail with just `sum(f, xs)` (which I have added to the tests). I can't run a GPU locally, which makes testing this hard.
If I have to, I will spin up an EC2 instance, but I would really rather not.
From looking at the logs, I think what is going on is that the error happens during the forward pass, in particular here:
https://github.com/JuliaDiff/ChainRules.jl/blob/52a0eeadf8d19bff491f224517b7b064ce1ba378/src/rulesets/Base/mapreduce.jl#L46
I think this is why Zygote implemented the pullback of `sum(f, x)` as `sum(f.(x))` (which is slower and more allocation-heavy than our newer version): so that it could hit the code Zygote has specially for CUDA, which does forward-mode.
(Which means it doesn't need the Context object containing the IdDict.)
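(A hedged sketch of that distinction, not code from the PR: the two spellings below compute the same value, but `sum(f.(xs))` is differentiated by Zygote's broadcast machinery, which has a ForwardDiff-based CUDA path that never touches the Context.)

```julia
using Zygote, CUDA

xs = cu(Float32[-1.5, -9.0, 2.4])

# Differentiated via the generic mapreduce rule (context-dependent),
# the form that fails on the broken release:
# gradient(x -> sum(abs, x), xs)

# Differentiated via the broadcast rules' forward-mode CUDA path:
g = gradient(x -> sum(abs.(x)), xs)[1]
```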
So I think the short-term solution is probably to add the old rule for `sum` back in (but for `CuArray` only) here:

Zygote.jl/src/lib/broadcast.jl
Line 244 in 531da8b

In the longer term, we will probably default to doing the `f` from `sum(f, xs)` in forward-mode anyway, so Zygote's rule config can be updated to say that it uses ForwardDiff.jl for its `frule_via_ad`.
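(For illustration only, a minimal sketch of what that longer-term idea could look like; `ForwardDiffConfig` and this unary-scalar method are hypothetical, not Zygote's actual config.)

```julia
using ChainRulesCore
using ForwardDiff: Dual, value, partials

# Hypothetical config advertising forward-mode capability.
struct ForwardDiffConfig <: RuleConfig{Union{HasForwardsMode}} end

# frule_via_ad for a unary scalar call, implemented with Dual numbers:
# push the tangent ẋ through f in one forward pass.
function ChainRulesCore.frule_via_ad(::ForwardDiffConfig, (ḟ, ẋ), f, x::Real)
    d = f(Dual(x, ẋ))
    return value(d), partials(d, 1)
end
```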