Skip to content

Commit 86be30c

Browse files
committed
correct and test getindex rrule
1 parent 85dd9d9 commit 86be30c

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

src/rulesets/Base/array.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,16 +105,16 @@ end
105105
##### getindex
106106
#####
107107

108-
function rrule(::typeof(getindex), x::Array{<:Number}, inds::Union{Int, Vararg{Int}})
108+
function rrule(::typeof(getindex), x::Array{<:Number}, inds::Vararg{Int})
109109
y = getindex(x, inds...)
110110
function getindex_pullback(ȳ)
111111
function getindex_add!(Δ)
112-
Δ[inds...] .+=;
112+
Δ[inds...] = Δ[inds...] .+
113113
return Δ
114114
end
115115

116116
= InplaceableThunk(
117-
@thunk(getindex_add!(zeros(x))),
117+
@thunk(getindex_add!(zero(x))),
118118
getindex_add!
119119
)
120120
return (NO_FIELDS, x̄, (DoesNotExist() for _ in inds)...)

test/rulesets/Base/array.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ end
7878
(ds, dv, dd) = pullback(ones(4))
7979
@test ds === NO_FIELDS
8080
@test dd isa DoesNotExist
81-
@test extern(dv) == 4
81+
@test extern(dv) == 4
8282

8383
y, pullback = rrule(fill, 2.0, (3, 3, 3))
8484
@test y == fill(2.0, (3, 3, 3))
@@ -87,3 +87,17 @@ end
8787
@test dd isa DoesNotExist
8888
@test dv 27.0
8989
end
90+
91+
@testset "getindex" begin
92+
x = [1.0 2.0 3.0; 10.0 20.0 30.0]
93+
ind = [2,3]
94+
= 7.2
95+
x̄_fd, = j′vp(ChainRulesTestUtils._fdm, a->getindex(a, ind...), ȳ, x)
96+
y, pullback = rrule(getindex, x, ind...)
97+
_, x̄_ad, = pullback(ȳ)
98+
99+
@test unthunk(x̄_ad) x̄_fd
100+
101+
x_like = x .+ 1.0
102+
@test x̄_ad.add!(copy(x_like)) x_like + x̄_fd
103+
end

0 commit comments

Comments
 (0)