
ensure sum(f,x) works on GPU #1004

Merged: 5 commits from ox/gpupain into master, Jun 21, 2021

Conversation

oxinabox (Member):

Seems like there is some pain from #990 re: GPU.
In particular, we broke DiffEqSensitivity:
https://buildkite.com/julialang/diffeqsensitivity-dot-jl/builds/169#d254017e-e824-4d9c-854d-f3b348395599/411-877

@ChrisRackauckas's "M"WE is:

using DiffEqFlux, OrdinaryDiffEq, DiffEqSensitivity
using CUDA, Test, Zygote
CUDA.allowscalar(false)
H = CuArray(rand(Float32, 2, 2))
ann = FastChain(FastDense(1, 4, tanh))
p = initial_params(ann)
function func(x, p, t)
    ann([t],p)[1]*H*x
end
x0 = CuArray(rand(Float32, 2))
x1 = CuArray(rand(Float32, 2))
prob = ODEProblem(func, x0, (0.0f0, 1.0f0))
function evolve(p)
    solve(prob, Tsit5(), p=p, save_start=false,
          save_everystep=false, abstol=1e-4, reltol=1e-4,
          sensealg=QuadratureAdjoint(autojacvec=ZygoteVJP())).u[1]
end
function cost(p)
    x = evolve(p)
    c = sum(abs,x - x1)
    #println(c)
    c
end
grad = Zygote.gradient(cost,p)[1]
@test !iszero(grad[1])
@test iszero(grad[2:4])
@test !iszero(grad[5])
@test iszero(grad[6:end])

I am hoping we can get it to fail with just sum(f, xs) (which I have added to the tests).
I can't run GPU locally which makes testing this hard.
If I have to I will spin up an EC2 instance, but I would really rather not.

From looking at the logs, I think what is going on is:

The error happens during the forward pass, in particular here:
https://github.com/JuliaDiff/ChainRules.jl/blob/52a0eeadf8d19bff491f224517b7b064ce1ba378/src/rulesets/Base/mapreduce.jl#L46

I think this is why Zygote implemented the pullback of sum(f, x) as sum(f.(x)) (which is slower and allocates more than our newer version): so that it could hit the code Zygote has specifically for CUDA, which does forward mode (and therefore doesn't need the Context object containing the IdDict).
So I think the short-term solution is probably to add the old rule for sum back in (but for CuArray only), next to Zygote's existing CUDA broadcast adjoint:

@eval @adjoint function broadcasted(::CuArrayStyle, f, args...)

# Make sure sum(f, ::CuArray) uses forward mode broadcast AD defined above
# Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU safe
@adjoint function sum(f, xs::CuArray; kws...)
  @assert !haskey(kws, :init) # TODO add init support (julia 1.6)
  return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
end

In the longer term, we will probably default to running the f from sum(f, xs) in forward mode anyway, so Zygote's rule config can be updated to say that it uses ForwardDiff.jl for its frule_via_ad.
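To make that concrete, here is a rough sketch (not code from this PR or from Zygote) of what "teaching a rule config to use ForwardDiff.jl for frule_via_ad" could look like, for scalar-input, scalar-output f. The ForwardViaForwardDiff name is made up for illustration; in practice it would be Zygote's own rule config that declares forward-mode support.

using ChainRulesCore
using ForwardDiff: Dual, value, partials

# Illustrative config that advertises forward-mode capability to ChainRules.
struct ForwardViaForwardDiff <: RuleConfig{Union{HasForwardsMode, NoReverseMode}} end

# Fulfil frule_via_ad with a Dual-number push-forward: seed x with the tangent ẋ,
# call f once, and read the primal and the tangent off the resulting Dual.
function ChainRulesCore.frule_via_ad(::ForwardViaForwardDiff, (ḟ, ẋ), f, x::Real)
    d = f(Dual(x, ẋ))
    return value(d), partials(d, 1)
end

# frule_via_ad(ForwardViaForwardDiff(), (NoTangent(), 1.0), sin, 0.3)  # gives (sin(0.3), cos(0.3))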

@oxinabox changed the title from "ensure sum(f,x) is broken on GPU" to "ensure sum(f,x) is works on GPU" on Jun 20, 2021
Resolved (outdated) review threads on src/lib/broadcast.jl and test/cuda.jl. The src/lib/broadcast.jl thread concerns this code:
# Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU compatible
@adjoint function sum(f, xs::CUDA.CuArray; kws...)
  @assert !haskey(kws, :init) # TODO add init support (julia 1.6)
  return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
end
Reviewer (Member):
You probably don't need the context manager here

@oxinabox (Member Author), Jun 21, 2021:

This is exactly the old code.
I don't really know what it is doing returning the context.

@oxinabox (Member Author):

I think keeping it is better.
It should end up hitting the same code either way, but I can imagine that if some specializations are added, the errors we would get would be more confusing if we don't include the __context__.

Reviewer (Member):

Might be slightly tidier to call broadcast_forward directly, instead of going back to pullback. But this is a quick fix to repair an accidental break, so the most direct restoration is fine too.

@ChrisRackauckas (Member):

using Zygote, CUDA
a = Float32.([-1.5, -9.0, 2.4, -1.3, 0.01])
a_gpu = a |> cu

f(x) = sum(abs, x)
g = gradient(f, a)[1]
g_gpu = gradient(f, a_gpu)[1]
@test g_gpu isa CuArray
@test g_gpu |> collect ≈ g

Does recreate the bug on release.

@mcabbott (Member) commented Jun 21, 2021:

Instead of going via broadcasting, I think the ideal thing here would be something like this:

using Zygote
using ForwardDiff: Dual, partials
_sum(f,A;kw...) = sum(f,A;kw...)

Zygote.@adjoint function _sum(f, A::AbstractArray; kw...)
    sum(f, A; kw...), dY -> (nothing, broadcast(dY,A) do dy, a
        dY * partials(f(Dual(a,true)),1)
    end)
end

I haven't tested this beyond the examples below, but it does seem to work, and it cuts the memory used by a factor of 4:

julia> gradient(x -> sum(sqrt, x), [1,2,3])
([0.5, 0.35355339059327373, 0.2886751345948129],)

julia> gradient(x -> _sum(sqrt, x), [1,2,3])
([0.5, 0.35355339059327373, 0.2886751345948129],)

In fact, we should do this on the CPU too, for arrays of real numbers. Once we have tested that f is sufficiently pure?

If f closes over other parameters, then this won't work. Ideally the CUDA case would throw an error then, instead of ignoring the problem. How often the CPU case makes sense, given that sum tries to do pairwise tricks and so definitely doesn't claim to execute f(a) in any particular order, I don't know.
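As a rough illustration of the error-instead-of-ignore idea (not part of this PR): closures carry their captured variables as struct fields, so a zero-field callable cannot be hiding parameters. The _gpusum and has_no_captures names below are made up, and the fieldcount heuristic is imperfect (it also rejects harmless wrappers like inv∘sqrt):

using Zygote, CUDA
using ForwardDiff: Dual, partials

# Crude purity check: a callable with no fields has captured nothing.
# (False positives exist: ComposedFunction has fields but captures no parameters.)
has_no_captures(f) = fieldcount(typeof(f)) == 0

_gpusum(f, A; kw...) = sum(f, A; kw...)

Zygote.@adjoint function _gpusum(f, A::CuArray; kw...)
    has_no_captures(f) || error(
        "_gpusum(f, ::CuArray): f appears to capture variables; their gradient would be silently dropped")
    sum(f, A; kw...), dY -> (nothing, broadcast(A) do a
        dY * partials(f(Dual(a, true)), 1)
    end)
end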

But this PR shouldn't wait for such things.

Edit: some timings:

julia> @btime gradient(x -> sum(sqrt, x), $(rand(400,400)));
  469.167 μs (32 allocations: 8.55 MiB)     # Zygote v0.6.13
  7.703 ms (480099 allocations: 14.65 MiB)  # with ChainRules v0.8.11 i.e. before https://github.com/JuliaDiff/ChainRules.jl/pull/441

julia> @btime gradient(x -> _sum(sqrt, x), $(rand(400,400)));
  324.500 μs (2 allocations: 1.22 MiB)

julia> @btime gradient(x -> sum(log, x), $(rand(400,400)));
  1.288 ms (32 allocations: 8.55 MiB)       # Zygote v0.6.13
  8.312 ms (480099 allocations: 14.65 MiB)  # with ChainRules v0.8.11

julia> @btime gradient(x -> _sum(log, x), $(rand(400,400)));  # evaluates log twice, not optimised out?
  1.719 ms (2 allocations: 1.22 MiB)

julia> @btime gradient(x -> sum(inv∘sqrt, x), $(rand(400,400)));  # harder to infer
  509.521 ms (5760041 allocations: 137.94 MiB)   # Zygote v0.6.13
  510.684 ms (5440152 allocations: 114.75 MiB)  # with ChainRules v0.8.11

julia> @btime gradient(x -> _sum(inv∘sqrt, x), $(rand(400,400)));
  398.084 μs (2 allocations: 1.22 MiB)

julia> @btime copy($(rand(400,400)));  # to count copies above
  39.958 μs (2 allocations: 1.22 MiB)

@mcabbott (Member) left a review:

Perhaps merge #1003 before tagging, too? Done.

I restarted the tests, but it seems they are run on the branch, not on top of #1003, so they still fail on Julia nightly as before.

@mcabbott changed the title from "ensure sum(f,x) is works on GPU" to "ensure sum(f,x) works on GPU" on Jun 21, 2021
@DhairyaLGandhi (Member):

bors r+

bors bot (Contributor) commented Jun 21, 2021:

Build succeeded.

bors bot merged commit 18a6f2a into master on Jun 21, 2021
bors bot deleted the ox/gpupain branch on Jun 21, 2021 at 18:47
@oxinabox (Member Author) commented Jun 21, 2021:

(quoting @mcabbott) "In fact, we should do this on the CPU too, for arrays of real numbers. Once we have tested that f is sufficiently pure?"

Yes.
This is a poster child for writing a rule against RuleConfig{>:HasForwardsMode}, and for teaching Zygote that its frule_via_ad is to use ForwardDiff.
A lot of the code needed for that is in https://github.com/oschulz/ForwardDiffPullbacks.jl, I think.
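For illustration only, a sketch (not this PR's code, and not how ForwardDiffPullbacks.jl does it) of what such a RuleConfig-gated rule for sum(f, xs) might look like on the CPU. It assumes the config fulfils frule_via_ad, as in the earlier sketch, and gives f itself no gradient:

using ChainRulesCore

# Sketch: the pullback of sum(f, xs) gets the elementwise derivative of f from the
# config's forward mode (frule_via_ad) instead of a reverse-mode closure over a Context.
function ChainRulesCore.rrule(config::RuleConfig{>:HasForwardsMode}, ::typeof(sum),
                              f, xs::AbstractArray{<:Real})
    y = sum(f, xs)
    function sum_f_pullback(ȳ)
        x̄ = map(xs) do x
            _, ẏ = frule_via_ad(config, (NoTangent(), one(x)), f, x)  # forward-mode df/dx
            ȳ * ẏ
        end
        return (NoTangent(), NoTangent(), x̄)  # no gradient for anything f might capture
    end
    return y, sum_f_pullback
end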

@mcabbott (Member) commented Jun 21, 2021:

Oh nice, I didn't know about ForwardDiffPullbacks.jl, that looks like the high-tech version. Is this going to be hooked up soon?

I was picturing a 2-line Zygote-style solution. Possibly with or after #1001.

(In fact what #1001 does should perhaps ultimately be routed to more high-tech stuff too...)

For my log example above, I think all of these would rely on the compiler not evaluating log twice, since that part of the Dual output is not used (and log is much slower than inv). An ideal forward-mode story here would, I presume, build this in, and evaluate only the tangent.
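One hypothetical way to sidestep the double evaluation is a variant of the _sum sketch above that keeps the Duals, so each element's value and derivative come from a single call to f, at the cost of holding the Duals (roughly twice the memory of the input). The _sum1 name is made up; no dims/init support; untested here:

using Zygote
using ForwardDiff: Dual, value, partials

_sum1(f, A) = sum(f, A)   # hypothetical name, to avoid piracy on Base.sum

Zygote.@adjoint function _sum1(f, A::AbstractArray{<:Real})
    duals = map(a -> f(Dual(a, true)), A)              # one call to f per element
    back(dY) = (nothing, map(d -> dY * partials(d, 1), duals))
    return sum(value, duals), back
end

# gradient(x -> _sum1(log, x), [1.0, 2.0, 4.0])  # ([1.0, 0.5, 0.25],)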

@oxinabox (Member Author):

(quoting @mcabbott) "Oh nice, I didn't know about ForwardDiffPullbacks.jl, that looks like the high-tech version. Is this going to be hooked up soon?"

Not by me.
Not soon at least.
My goal is to make sure everything is possible.
Making things happen, beyond what is needed to prove that they are indeed possible, will come later.
(All my plans have a horizon of JuliaCon / ChainRulesCore 1.0. I will make new plans after that, and take a holiday.)

Labels: CUDA (All things GPU)
4 participants