Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gradients for prod, cumsum, cumprod #524

Closed
wants to merge 1 commit into from
Closed

Conversation

mcabbott
Copy link
Member

This is a new version of #334, aiming to solve these problems:

  • prod(xs) has a gradient defined prod(xs) ./ xs .* Δ which fails when some x are zero.
  • prod(x,dims=1), cumprod(x) and cumsum(x) are all Array{<:TrackedReal}.
  • prod(x,1), in Julia 0.6 notation, has a custom gradient with various problems.

This custom gradient defined using circshift is I believe correct for the case prod(x). In #334 I wrote a variant for the case prod(x,dims=1) which depends on dims. However both are very slow, and crash julia if called on arrays with 1000s of elements.

Compared to previous attempt:

  • The function ∇prod treats correctly cases where the product is zero due to rounding, but no x[i] was zero. It is much faster on the case where exactly one x[i] is zero, with thanks to @Drvi for an idea.
  • This is now applied using the built-in mapslices, which is a bit slow but will eventually improve.
  • ∇cumprod needs to be applied with variadic mapslices, here implemented with some ntuple-ing. I fixed stupid a bug in this, so now all test pass, and made it more compact.
  • For more generic arrays, the fall-back is to call ForwardDiff.gradient instead of circshift funcitons.

ForwardDiff does work on CuArrays, but slowly via CPU I think, for now. However this is only called by ∇prod when exactly one entry is zero, so other cases should be fast. (Ideally ∇prod_one and ∇cumprod could be written with CUDAnative, perahps, but I'm not sure where they would live.) The gradient for cumsum also goes via the CPU I think, and because reverse(cu(rand(2,2)), dims=1) is an error; I wrote a mapslices thing for that case.

Variadic mapslices JuliaLang/julia#23306 would replace my ∇cumprod_d function. You could also write this using Julia 1.1's eachslice() JuliaLang/julia#29749 as cumprod only works for dims::Int. But prod accepts dim::Tuple thus no help.

@MikeInnes
Copy link
Member

Thanks a lot for the patch here. Some of these seem a bit strange to me though, e.g. taking a jacobian through forward diff seems like it'd be slower than just tracing operations with reverse mode (and a lot of this definitely won't be GPU compatible, which isn't a deal breaker but worth considering).

It might be worth splitting some of this out into smaller pieces and doing one thing at a time; this would make it easier to review the pieces.

@mcabbott
Copy link
Member Author

mcabbott commented Feb 6, 2019

Thanks for taking a look. Let's just think about prod to start -- I'm sorry about the mess, it would indeed be nice to simplify.

When nothing is zero, then Δ .* prod(xs, dims=1) ./ xs would be ideal, fast and generic. If there are many zeros (along the product direction) then there is little point in attempting gradient descent. But, in between, I have cases where each x ∈ [0,1] and some columns will contain a zero. I wrote a function for the case of exactly one zero, which mapslices applies to each column, if necessary.

My computer with a GPU isn't answering the phone right now, but I think my idea was to handle the case of no zeros at least, and if CuArrays.allowscalar(true) have a fall-back in case you randomly hit one.

Would "tracing operations with reverse mode" mean using TrackedReal? Perhaps this could be used as a fall-back option (i.e. when you hit a zero), if it will work on the GPU. I'm not sure I tried that.

Right now, using TrackedReal for everything seems very slow:

julia> using Flux, ForwardDiff
julia> r = rand(10,100);

julia> prod(param(r), dims=1)
1×100 Array{Flux.Tracker.TrackedReal{Float64},2}:
 6.49254e-7  3.12325e-5  4.41833e-9  0.000188122    6.07351e-6  1.42578e-5  0.00236577

julia> @btime Flux.gradient(x -> sum(sin, prod(x, dims=1)), $r)
  1.413 ms (15013 allocations: 8.16 MiB)

and with PR:

julia> @btime Flux.gradient(x -> sum(sin, prod(x, dims=1)), $r)
  8.160 μs (51 allocations: 29.41 KiB)

The other possible generic fallback is the circshift-based thing. There is one in the current source which I believe is correct for prod(x) no dims. I had problems with it crashing above some size, about length = 400. And it's slow -- for prod(rand(50)), circshift takes 2ms, TrackedReal 12 μs, ForwardDiff 1.5 μs, and simple division 80ns.

Finally, I haven't thought about second derivatives, perhaps I should.

@MikeInnes
Copy link
Member

No problem, looking forward to getting some of this stuff in.

I suggest just doing the simplest gradient for now even if it's not so good with zeros. Then we can post test cases for anything that still isn't ideal, and discuss what would make a good solution for that case.

@mcabbott
Copy link
Member Author

mcabbott commented Feb 8, 2019

OK, if you like I'll make a few-lines PR which just uses Δ .* prod ./ xs for everything, as a start.

I haven't made any progress, but have a question: How do I explicitly tell Flux to use TrackedReal for something? (To try using this within the function which gets mapslices-ed.)

@MikeInnes
Copy link
Member

You shouldn't generally need to, if there's no gradient it'll just fall back to the scalar version. But you can also use map(identity, xs) to get an array of tracked reals.

@mcabbott
Copy link
Member Author

I took another look at this story, and collected the steps (and benchmarks) here:

https://gist.github.com/mcabbott/ecb9a7756c0530e8fae0ef444761ffcd

I would quite like prod(xs; dims=2) to handle xs containing zeros correctly, but if we want this, then I still don't see a simpler approach than this mapslices(∇prod,...) thing. This case is necessarily slower, but by checking for zeros first, the case of all-nonzero xs can be almost as fast as the naiive Δ .* prod ./ xs gradient.

I don't however see an elegant way to do that for CuArrays. But what occurs to me today is that perhaps it would be OK to give up on handling zeros correctly there, and just dispatch to Δ .* prod ./ xs. This won't silently give you wrong answers, you should get NaN to warn you. Might that be acceptable?

mcabbott added a commit to mcabbott/Tracker.jl that referenced this pull request Mar 10, 2019
Will not treat zeros correctly, see FluxML/Flux.jl#524
@mcabbott
Copy link
Member Author

I thought of a more generic way to compute the prod gradient, allowing zeros. Instead of calling circshift, you can create something similar by reshaping x .* ones' to have length(x)-1 rows... the simplest version looks like this:

function ∇prod_one(x, Δ)
  n = length(x) - 1
  m = reshape(vec(x) .* trues(n)' .* Δ, (n,:))
  v = reverse(vec(prod(m, dims=1)))
  reshape(v, size(x))
end

This is only 10x slower than directly indexing, instead of 200x. I've added this to the end of the above-linked gist. It's not done yet, partly because reverse doesn't seem to exist for CuArrays. But that's the news.

@jburroni
Copy link

jburroni commented Mar 19, 2019

@mcabbott Is it true that p::Real in ∇prod? If that is the case, you could write this:

function ∇prod(x, p::Real=prod(x), Δ=1)
  if !iszero(p)
    ∇ = p ./ x .* Δ
  elseif count(iszero, x) > 1= zero(x)
  else= ∇prod_one(x, Δ)
  end
end

@mcabbott
Copy link
Member Author

I ran into a subtle bug with that (with !iszero(p) as the first test): if x contains several very small numbers, than the product can be zero without any individual zeros, due to floating-point rounding. And then findfirst in the PR's ∇prod_one returns nothing and it fails.

The current version returns zero(x) in this case. Maybe ideally one could treat it better. The new idea for ∇prod_one with reshape(..., (length(x)-1,:)) may be better, in fact.

@jburroni
Copy link

good catch! (the floating point rounding)
I do still believe that trying to short-circuit the --presumible-- most common case of a product different than zero is important.

function ∇prod(x, p::Real=prod(x), Δ=1)
  !iszero(p) && return= p ./ x .* Δ
  numzero = count(iszero, x) 
  if numzero == 0= p ./ x .* Δ
  elseif numzero > 1= zero(x)
  else= ∇prod_one(x, Δ)
  end
end

@mcabbott
Copy link
Member Author

mcabbott commented Mar 19, 2019

I don't have all the numbers in my head, but when I timed things I think count(iszero, x) turned out not to matter much, a few percent of the quickest gradient. But I could be wrong. Note that you could combine the second case and third cases here to numzero != 1, as p==0 implies that p ./ x .* Δ == zero(x).

This whole PR seems to be about trade-offs between complication and speed -- the fastest variant involved writing my own mapslices (worth a factor of 2) but that started to sound like too much complication.

Another concern worth some thought is whether this is can be made correct for second derivatives. Right now the PR has a nobacksies which explicitly prevents this, and the logic of the 3-option ∇prod is I think only for first derivatives. But p ./ x .* Δ should be correct (when there are no zeros) and this new idea reshape(..., (length(x)-1,:)) is perhaps also correct?

Also, somehow I must never have checked this, but mapslices also turns out not to exist for CuArrays:

julia> CuArrays.allowscalar(false)
julia> mapslices(sum, rand(2,3) |> cu, dims=1)
ERROR: scalar setindex! is disallowed

So I don't really see a way to do the right thing or CuArrays. eachslice works but not for dims=(2,3). And since CuArrays might not be loaded, you can't even directly dispatch on its type. Perhaps just accept that the case of a CuArray with containing zeros is going to use scalar indexing, and will be very slow, until CuArrays learns to understand mapslices and reverse.

@MikeInnes
Copy link
Member

Can this be closed in favour of FluxML/Tracker.jl#1? Seems like that doesn't have everything here.

@mcabbott
Copy link
Member Author

mcabbott commented Apr 5, 2019

I guess this is a discussion thread now, not aiming to be merged.

The tracker PR was the simplest prod case as suggested, I can pull out cumsum equally simply.

For prod the tl;dr version is that I still think we ought to treat zero entries correctly. I don't see a great way to do this that includes CuArrays; on CPU there are several options (trading speed vs complication) which I can tidy up if that would help.

bors bot added a commit to FluxML/Zygote.jl that referenced this pull request Feb 26, 2020
112: Simplest prod(x; dims) gradient r=dhairyagandhi96 a=mcabbott

The current gradient for `prod(x; dims)` gives incorrect results, this PR fixes it (parallel to  FluxML/Tracker.jl#1 ):
```
julia> using Zygote, ForwardDiff

julia> r = rand(2,3,2);

julia> ForwardDiff.gradient(w->sum(prod(w, dims=(2,3))), r)
2×3×2 Array{Float64,3}:
[:, :, 1] =
 0.00131643  0.000954347  0.0051387 
 0.0177437   0.0354628    0.00934587

[:, :, 2] =
 0.00434307  0.0140455   0.00152818
 0.0151417   0.00464615  0.00451601

julia> Zygote.gradient(w->sum(prod(w, dims=(2,3))), r)[1] # wrong answer!
2×3×2 Array{Float64,3}:
[:, :, 1] =
 5.93867e-6  4.30525e-6  2.31817e-5
 1.60301e-5  3.2038e-5   8.44331e-6

[:, :, 2] =
 1.95925e-5  6.33622e-5  6.89391e-6
 1.36795e-5  4.19746e-6  4.07989e-6

julia> Zygote.@adjoint function prod(xs; dims = :) # as in this PR
         p = prod(xs; dims = dims)
         p, Δ -> (p ./ xs .* Δ,)
       end

julia> Zygote.refresh()

julia> Zygote.gradient(w->sum(prod(w, dims=(2,3))), r)[1] # now matches ForwardDiff
2×3×2 Array{Float64,3}:
[:, :, 1] =
 0.00131643  0.000954347  0.0051387 
 0.0177437   0.0354628    0.00934587

[:, :, 2] =
 0.00434307  0.0140455   0.00152818
 0.0151417   0.00464615  0.00451601
```
This does not handle zeros in the array correctly -- see FluxML/Flux.jl#524 for attempts to do that. The `circshift(...` operation deleted here was a correct (but slow) gradient for `prod(x)`, but is clearly independent of `dims`. 

The example above is almost the same as the one in the tests, which strangely passes, without this PR. Perhaps something is wrong with `gradtest`?
```
julia> @test gradtest(x -> prod(x, dims = (2, 3)), (3,4,5))
Test Passed

julia> @test gradtest(x -> prod(x), (3,4,5))
Test Passed
```

Co-authored-by: Michael Abbott <me@pseudomac>
mcabbott added a commit to mcabbott/Tracker.jl that referenced this pull request Aug 10, 2020
Will not treat zeros correctly, see FluxML/Flux.jl#524
@CarloLucibello
Copy link
Member

If there are any missing gradients, they should be added to ChainRules

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants