diff --git a/src/lib/array.jl b/src/lib/array.jl index e09e99bc..7b7456d7 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -300,15 +300,13 @@ Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs)) @grad sum(xs; dims = :) = sum(data(xs), dims = dims), Δ -> (zero(xs) .+ Δ, ) -Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim) -Base.prod(xs::TrackedArray) = track(prod, xs) +Base.prod(xs::TrackedArray; dims=:) = track(prod, xs; dims=dims) Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs)) -@grad prod(xs) = prod(data(xs)), Δ -> (prod(xs) ./ xs .* Δ,) -@grad prod(xs, dim) = prod(data(xs), dims = dim), - Δ -> (nobacksies(:sum, - reshape(.*(circshift.([reshape(data(xs), length(xs))], 1:length(xs)-1)...), size(xs)) .* Δ), - nothing) +@grad function prod(xs; dims=:) + p = prod(data(xs); dims=dims) + p, Δ -> (p ./ xs .* Δ,) +end Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)