Gradients for prod, cumsum, cumprod #524
Conversation
Thanks a lot for the patch here. Some of these seem a bit strange to me though, e.g. taking a jacobian through forward diff seems like it'd be slower than just tracing operations with reverse mode (and a lot of this definitely won't be GPU compatible, which isn't a deal breaker but worth considering). It might be worth splitting some of this out into smaller pieces and doing one thing at a time; this would make it easier to review the pieces.
Thanks for taking a look. Let's just think about `prod` first. When nothing is zero, then `p ./ x .* Δ` is the correct gradient. My computer with a GPU isn't answering the phone right now, but I think my idea was to handle the case of no zeros at least, and if there are zeros, fall back to something slower. Would "tracing operations with reverse mode" mean using arrays of `TrackedReal`? Right now, `prod(param(r), dims=1)` gives an array of `TrackedReal`s, and the gradient is slow:

```julia
julia> using Flux, ForwardDiff

julia> r = rand(10,100);

julia> prod(param(r), dims=1)
1×100 Array{Flux.Tracker.TrackedReal{Float64},2}:
 6.49254e-7  3.12325e-5  4.41833e-9  0.000188122  …  6.07351e-6  1.42578e-5  0.00236577

julia> @btime Flux.gradient(x -> sum(sin, prod(x, dims=1)), $r)
  1.413 ms (15013 allocations: 8.16 MiB)
```

and with this PR:

```julia
julia> @btime Flux.gradient(x -> sum(sin, prod(x, dims=1)), $r)
  8.160 μs (51 allocations: 29.41 KiB)
```

The other possible generic fallback is the `ForwardDiff` jacobian. Finally, I haven't thought about second derivatives, perhaps I should.
No problem, looking forward to getting some of this stuff in. I suggest just doing the simplest gradient for now even if it's not so good with zeros. Then we can post test cases for anything that still isn't ideal, and discuss what would make a good solution for that case.
OK, if you like I'll make a few-lines PR which just uses `p ./ xs .* Δ`.

I haven't made any progress, but have a question: how do I explicitly tell Flux to use the scalar `TrackedReal` fallback?
You shouldn't generally need to; if there's no gradient it'll just fall back to the scalar version. But you can also use […].
I took another look at this story, and collected the steps (and benchmarks) here: https://gist.github.com/mcabbott/ecb9a7756c0530e8fae0ef444761ffcd

I would quite like […]; I don't however see an elegant way to do that for […].
Will not treat zeros correctly, see FluxML/Flux.jl#524
I thought of a more generic way to compute the gradient. This is only 10x slower than directly indexing, instead of 200x. I've added this to the end of the above-linked gist. It's not done yet, partly because […].
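For concreteness, a minimal sketch of the fully generic fallback route via `ForwardDiff.gradient` (which the PR description below says is only reached when a zero is present); the name `∇prod_forwarddiff` is made up here, and this is not necessarily what the gist does:

```julia
using ForwardDiff

# Generic fallback: differentiate prod itself with forward-mode duals.
# Correct for any number of zeros in x, but does a full dual-number pass
# over x, hence slower than indexing out the single zero directly.
∇prod_forwarddiff(x, Δ=1) = ForwardDiff.gradient(prod, x) .* Δ
```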
@mcabbott Is it true that `∇prod` could be simplified to something like this?

```julia
function ∇prod(x, p::Real=prod(x), Δ=1)
    if !iszero(p)
        ∇ = p ./ x .* Δ
    elseif count(iszero, x) > 1
        ∇ = zero(x)
    else
        ∇ = ∇prod_one(x, Δ)
    end
end
```
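For reference, `∇prod_one` is never shown in this thread; a minimal sketch of what it might look like, assuming it handles the single-zero case by direct indexing (names and details are illustrative, not necessarily the PR's code):

```julia
# Sketch: gradient of prod(x) when exactly one entry of x is zero.
# Only that position gets a nonzero gradient, equal to the product of
# all the other entries, times the incoming Δ.
function ∇prod_one(x::AbstractVector, Δ=1)
    i = findfirst(iszero, x)     # index of the single zero entry
    ∇ = zero(x)
    ∇[i] = prod(x[1:i-1]) * prod(x[i+1:end]) * Δ
    return ∇
end
```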
I ran into a subtle bug with that (with the product rounding to zero even though no `x[i]` is zero). The current version returns `p ./ x .* Δ` there.
good catch! (the floating point rounding)

```julia
function ∇prod(x, p::Real=prod(x), Δ=1)
    !iszero(p) && return ∇ = p ./ x .* Δ
    numzero = count(iszero, x)
    if numzero == 0
        ∇ = p ./ x .* Δ
    elseif numzero > 1
        ∇ = zero(x)
    else
        ∇ = ∇prod_one(x, Δ)
    end
end
```
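To illustrate why the `numzero == 0` branch matters (a standalone example, not from the PR):

```julia
# The product can round to zero even though no individual entry is zero:
x = fill(1e-200, 2)

prod(x)            # 0.0  -- 1e-400 underflows in Float64
count(iszero, x)   # 0    -- yet no entry of x is zero

# So iszero(prod(x)) alone cannot distinguish a genuine zero in x from
# underflow; counting the zeros separates the two cases.
```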
I don't have all the numbers in my head, but when I timed things I think […]. This whole PR seems to be about trade-offs between complication and speed -- the fastest variant involved writing my own […].

Another concern worth some thought is whether this can be made correct for second derivatives. Right now the PR has a […].

Also, somehow I must never have checked this, but […]. So I don't really see a way to do the right thing for CuArrays.
Can this be closed in favour of FluxML/Tracker.jl#1?

Seems like that doesn't have everything here.
I guess this is a discussion thread now, not aiming to be merged. The Tracker PR was the simplest `prod` gradient. For […].
112: Simplest prod(x; dims) gradient
r=dhairyagandhi96 a=mcabbott

The current gradient for `prod(x; dims)` gives incorrect results, this PR fixes it (parallel to FluxML/Tracker.jl#1):

```
julia> using Zygote, ForwardDiff

julia> r = rand(2,3,2);

julia> ForwardDiff.gradient(w->sum(prod(w, dims=(2,3))), r)
2×3×2 Array{Float64,3}:
[:, :, 1] =
 0.00131643  0.000954347  0.0051387
 0.0177437   0.0354628    0.00934587

[:, :, 2] =
 0.00434307  0.0140455   0.00152818
 0.0151417   0.00464615  0.00451601

julia> Zygote.gradient(w->sum(prod(w, dims=(2,3))), r)[1]  # wrong answer!
2×3×2 Array{Float64,3}:
[:, :, 1] =
 5.93867e-6  4.30525e-6  2.31817e-5
 1.60301e-5  3.2038e-5   8.44331e-6

[:, :, 2] =
 1.95925e-5  6.33622e-5  6.89391e-6
 1.36795e-5  4.19746e-6  4.07989e-6

julia> Zygote.@adjoint function prod(xs; dims = :)  # as in this PR
         p = prod(xs; dims = dims)
         p, Δ -> (p ./ xs .* Δ,)
       end

julia> Zygote.refresh()

julia> Zygote.gradient(w->sum(prod(w, dims=(2,3))), r)[1]  # now matches ForwardDiff
2×3×2 Array{Float64,3}:
[:, :, 1] =
 0.00131643  0.000954347  0.0051387
 0.0177437   0.0354628    0.00934587

[:, :, 2] =
 0.00434307  0.0140455   0.00152818
 0.0151417   0.00464615  0.00451601
```

This does not handle zeros in the array correctly -- see FluxML/Flux.jl#524 for attempts to do that.

The `circshift(...` operation deleted here was a correct (but slow) gradient for `prod(x)`, but is clearly independent of `dims`.

The example above is almost the same as the one in the tests, which strangely passes without this PR. Perhaps something is wrong with `gradtest`?

```
julia> @test gradtest(x -> prod(x, dims = (2, 3)), (3,4,5))
Test Passed

julia> @test gradtest(x -> prod(x), (3,4,5))
Test Passed
```

Co-authored-by: Michael Abbott <me@pseudomac>
Will not treat zeros correctly, see FluxML/Flux.jl#524
If there are any missing gradients, they should be added to ChainRules.
This is a new version of #334, aiming to solve these problems:

- `prod(xs)` has a gradient defined, `prod(xs) ./ xs .* Δ`, which fails when some `x` are zero.
- `prod(x, dims=1)`, `cumprod(x)` and `cumsum(x)` are all `Array{<:TrackedReal}`.
- `prod(x, 1)`, in Julia 0.6 notation, has a custom gradient with various problems.

This custom gradient defined using `circshift` is I believe correct for the case `prod(x)`. In #334 I wrote a variant for the case `prod(x, dims=1)` which depends on `dims`. However both are very slow, and crash Julia if called on arrays with 1000s of elements.

Compared to the previous attempt:

- `∇prod` treats correctly cases where the product is zero due to rounding, but no `x[i]` was zero. It is much faster on the case where exactly one `x[i]` is zero, with thanks to @Drvi for an idea.
- Uses `mapslices` in places, which is a bit slow but will eventually improve.
- `∇cumprod` needs to be applied with variadic mapslices, here implemented with some ntuple-ing (a rough sketch of such a column-wise application is below). I fixed a stupid bug in this, so now all tests pass, and made it more compact.
- Uses `ForwardDiff.gradient` instead of circshift functions.

ForwardDiff does work on CuArrays, but slowly via the CPU I think, for now. However this is only called by `∇prod` when exactly one entry is zero, so other cases should be fast. (Ideally `∇prod_one` and `∇cumprod` could be written with CUDAnative, perhaps, but I'm not sure where they would live.) The gradient for `cumsum` also goes via the CPU I think, because `reverse(cu(rand(2,2)), dims=1)` is an error; I wrote a mapslices thing for that case.

Variadic mapslices JuliaLang/julia#23306 would replace my `∇cumprod_d` function. You could also write this using Julia 1.1's `eachslice()` JuliaLang/julia#29749, as `cumprod` only works for `dims::Int`. But `prod` accepts `dim::Tuple`, thus no help.

There are a few benchmarks etc. in this gist.
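For reference, a rough sketch of the column-wise application mentioned above, assuming the vector case can be handled through a `ForwardDiff` Jacobian; `∇cumprod_vec` and `∇cumprod_dims1` are illustrative names, not the PR's actual code:

```julia
using ForwardDiff

# Reverse-mode gradient of cumprod on one vector: the vector-Jacobian product,
# here taken from the full ForwardDiff Jacobian. Correct even when x contains
# zeros, but costs O(n^2) per slice.
∇cumprod_vec(x::AbstractVector, Δ::AbstractVector) =
    ForwardDiff.jacobian(cumprod, x)' * Δ

# Applying it along dims=1 of a matrix, one column at a time -- roughly the job
# that variadic mapslices (or eachslice in Julia 1.1) would tidy up.
function ∇cumprod_dims1(x::AbstractMatrix, Δ::AbstractMatrix)
    ∇ = similar(x)                 # assumes a floating-point eltype
    for j in axes(x, 2)
        ∇[:, j] = ∇cumprod_vec(x[:, j], Δ[:, j])
    end
    return ∇
end
```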
Flux's tests fail locally for me, but in `curnn.jl`, unrelated to this.
cumsum
, however I believe it is missing areverse
.