Skip to content

Commit

Permalink
Merge #284
Browse files Browse the repository at this point in the history
284: adjoint for cumsum r=CarloLucibello a=mcabbott

The easy half of #282

Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Co-authored-by: Michael Abbott <me@pseudomac>
  • Loading branch information
3 people authored Feb 27, 2020
2 parents 70fb6b2 + 607322e commit 3b015db
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,16 @@ _backvar(xs, Δ, N::Int, mean) = (convert(eltype(xs), 2/N) .* Δ .* (xs .- mean)
return s, Δ -> _backvar(xs, Δ ./ (2 .* s), corrected, mean, dims)
end

@adjoint function cumsum(xs::AbstractVector; dims::Integer = 1)
dims == 1 || return copy(xs), Δ -> (Δ,)
cumsum(xs), Δ -> (reverse(cumsum(reverse(Δ))),)
end
@adjoint function cumsum(xs::AbstractArray; dims::Integer)
dims <= ndims(xs) || return copy(xs), Δ -> (Δ,)
cumsum(xs; dims=dims), Δ -> begin
(reverse(cumsum(reverse(Δ, dims=dims), dims=dims), dims=dims),)
end
end

# LinAlg
# ======
Expand Down
6 changes: 6 additions & 0 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@ end
@test gradtest(x -> prod(x), (3,4))
@test gradient(x -> prod(x), (1,2,3))[1] == (6,3,2)

@test gradtest(x -> cumsum(x, dims=2), (3,4,5))
@test gradtest(x -> cumsum(x, dims=1), (3,))
@test gradtest(x -> cumsum(x), (4,))
@test gradtest(x -> cumsum(x, dims=3), (5,)) # trivial
@test gradtest(x -> cumsum(x, dims=3), (3,4)) # trivial

@test gradtest(x -> softmax(x).*(1:3), 3)
@test gradtest(x -> softmax(x).*(1:3), (3,5))
@test gradtest(x -> softmax(x, dims=2).*(1:3), (3,5))
Expand Down

0 comments on commit 3b015db

Please sign in to comment.