Skip to content

Commit ab4b3f9

Browse files
committed
bang
1 parent a2cd256 commit ab4b3f9

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

src/rulesets/Base/mapreduce.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,11 +162,16 @@ function frule((_, xdot), ::typeof(cumsum), x::AbstractArray; dims::Integer)
162162
end
163163
frule(tang, ::typeof(cumsum), x::AbstractVector) = frule(tang, cumsum, x; dims=1)
164164

165+
function frule((_, ydot, xdot), ::typeof(cumsum!), y::AbstractArray, x::AbstractArray; dims::Integer)
166+
return cumsum!(y, x; dims=dims), cumsum!(ydot, xdot; dims=dims)
167+
end
168+
frule(t, ::typeof(cumsum!), y::AbstractVector, x::AbstractVector) = frule(t, cumsum!, y, x; dims=1)
169+
165170
function rrule(::typeof(cumsum), x::AbstractArray; dims::Integer)
166171
project = ProjectTo(x)
167172
function cumsum_pullback(dy)
168173
step1 = reverse(unthunk(dy); dims=dims)
169-
if ChainRulesCore.is_inplaceable_destination(step1)
174+
if ChainRulesCore.is_inplaceable_destination(step1) && VERSION >= v"1.6"
170175
step2 = cumsum!(step1, step1; dims=dims)
171176
step3 = reverse!(step2; dims=dims)
172177
else

test/rulesets/Base/mapreduce.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,11 +240,13 @@ end
240240
@testset "Accumulations" begin
241241
@testset "cumsum" begin
242242
v = round.(10 .* randn(9), sigdigits=3)
243-
m = round.(10 .* randn(4,5), sigdigits=3)
243+
m = round.(10 .* randn(4, 5), sigdigits=3)
244244

245245
# Forward
246246
test_frule(cumsum, v)
247247
test_frule(cumsum, m; fkwargs=(;dims=1))
248+
test_frule(cumsum!, rand(9), v)
249+
test_frule(cumsum!, rand(4, 5), m; fkwargs=(;dims=1))
248250

249251
# Reverse
250252
test_rrule(cumsum, v)

0 commit comments

Comments
 (0)