Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
112: Simplest prod(x; dims) gradient r=dhairyagandhi96 a=mcabbott The current gradient for `prod(x; dims)` gives incorrect results, this PR fixes it (parallel to FluxML/Tracker.jl#1 ): ``` julia> using Zygote, ForwardDiff julia> r = rand(2,3,2); julia> ForwardDiff.gradient(w->sum(prod(w, dims=(2,3))), r) 2×3×2 Array{Float64,3}: [:, :, 1] = 0.00131643 0.000954347 0.0051387 0.0177437 0.0354628 0.00934587 [:, :, 2] = 0.00434307 0.0140455 0.00152818 0.0151417 0.00464615 0.00451601 julia> Zygote.gradient(w->sum(prod(w, dims=(2,3))), r)[1] # wrong answer! 2×3×2 Array{Float64,3}: [:, :, 1] = 5.93867e-6 4.30525e-6 2.31817e-5 1.60301e-5 3.2038e-5 8.44331e-6 [:, :, 2] = 1.95925e-5 6.33622e-5 6.89391e-6 1.36795e-5 4.19746e-6 4.07989e-6 julia> Zygote.@adjoint function prod(xs; dims = :) # as in this PR p = prod(xs; dims = dims) p, Δ -> (p ./ xs .* Δ,) end julia> Zygote.refresh() julia> Zygote.gradient(w->sum(prod(w, dims=(2,3))), r)[1] # now matches ForwardDiff 2×3×2 Array{Float64,3}: [:, :, 1] = 0.00131643 0.000954347 0.0051387 0.0177437 0.0354628 0.00934587 [:, :, 2] = 0.00434307 0.0140455 0.00152818 0.0151417 0.00464615 0.00451601 ``` This does not handle zeros in the array correctly -- see FluxML/Flux.jl#524 for attempts to do that. The `circshift(...` operation deleted here was a correct (but slow) gradient for `prod(x)`, but is clearly independent of `dims`. The example above is almost the same as the one in the tests, which strangely passes, without this PR. Perhaps something is wrong with `gradtest`? ``` julia> @test gradtest(x -> prod(x, dims = (2, 3)), (3,4,5)) Test Passed julia> @test gradtest(x -> prod(x), (3,4,5)) Test Passed ``` Co-authored-by: Michael Abbott <me@pseudomac>
- Loading branch information