Skip to content

Commit b36e66c

Browse files
authored
Merge pull request #717 from JuliaDiff/ox/oneelement
Bring over OneElement for scalar getindex
2 parents 83592fe + 488dca6 commit b36e66c

File tree

4 files changed

+46
-13
lines changed

4 files changed

+46
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ChainRules"
22
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
3-
version = "1.50.0"
3+
version = "1.51.0"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/rulesets/Base/indexing.jl

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,16 +81,48 @@ For the `rrule` of `y = x[inds...]`, this function is roughly
8181
`setindex(zero(x), dy, inds...)`, returning the array `dx`.
8282
Differentiable. Includes `ProjectTo(x)(dx)`.
8383
"""
84-
function ∇getindex(x::AbstractArray, dy, inds...)
84+
function ∇getindex(x::AbstractArray{T,N}, dy, inds...) where {T,N}
8585
# `to_indices` removes any logical indexing, colons, CartesianIndex etc,
8686
# leaving just Int / AbstractVector of Int
8787
plain_inds = Base.to_indices(x, inds)
88-
dx = _setindex_zero(x, dy, plain_inds...)
89-
∇getindex!(dx, dy, plain_inds...)
88+
dx = if plain_inds isa NTuple{N, Int} && T<:Number
89+
# scalar indexing
90+
OneElement(dy, plain_inds, axes(x))
91+
else # some from slicing (potentially noncontigous)
92+
dx = _setindex_zero(x, dy, plain_inds...)
93+
∇getindex!(dx, dy, plain_inds...)
94+
end
9095
return ProjectTo(x)(dx) # since we have x, may as well do this inside, not in rules
9196
end
9297
∇getindex(x::AbstractArray, z::AbstractZero, inds...) = z
9398

99+
"""
100+
OneElement(val, ind, axes) <: AbstractArray
101+
102+
Extremely simple `struct` used for the gradient of scalar `getindex`.
103+
"""
104+
struct OneElement{T,N,I,A} <: AbstractArray{T,N}
105+
val::T
106+
ind::I
107+
axes::A
108+
OneElement(val::T, ind::I, axes::A) where {T<:Number, I<:NTuple{N,Int}, A<:NTuple{N,AbstractUnitRange}} where {N} = new{T,N,I,A}(val, ind, axes)
109+
end
110+
Base.size(A::OneElement) = map(length, A.axes)
111+
Base.axes(A::OneElement) = A.axes
112+
Base.getindex(A::OneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==A.ind, A.val, zero(T))
113+
114+
function ChainRulesCore.add!!(xs::AbstractArray{<:Any,N}, oe::OneElement{<:Any,N}) where {N}
115+
if !ChainRulesCore.is_inplaceable_destination(xs)
116+
xs = collect(xs)
117+
end
118+
xs[oe.ind...] += oe.val
119+
return xs
120+
end
121+
122+
Base.:(+)(xs::AbstractArray, oe::OneElement) = add!!(copy(xs), oe)
123+
Base.:(+)(oe::OneElement, xs::AbstractArray) = +(xs, oe)
124+
Base.:(+)(oe1::OneElement, oe2::OneElement) = +(collect(oe1), oe2)
125+
94126
"""
95127
_setindex_zero(x, dy, inds...)
96128

test/rulesets/Base/array.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -358,14 +358,15 @@ end
358358
@test_skip test_frule(findmin, rand(3,4), output_tangent = (rand(), NoTangent()))
359359
@test_skip test_frule(findmin, rand(3,4), fkwargs=(dims=1,))
360360
# These skipped tests might be fixed by https://github.com/JuliaDiff/FiniteDifferences.jl/issues/188
361+
# or by https://github.com/JuliaLang/julia/pull/48404
361362

362363
# Reverse
363364
test_rrule(findmin, rand(10), output_tangent = (rand(), false))
364365
test_rrule(findmax, rand(10), output_tangent = (rand(), false))
365-
test_rrule(findmin, rand(5,3))
366-
test_rrule(findmax, rand(5,3))
367-
@test [0 0; 0 5] == @inferred unthunk(rrule(findmax, [1 2; 3 4])[2]((5.0, nothing))[2])
368-
@test [0 0; 0 5] == @inferred unthunk(rrule(findmax, [1 2; 3 4])[2]((5.0, NoTangent()))[2])
366+
test_rrule(findmin, rand(5,3); check_inferred=false)
367+
test_rrule(findmax, rand(5,3); check_inferred=false)
368+
@test [0 0; 0 5] == unthunk(rrule(findmax, [1 2; 3 4])[2]((5.0, nothing))[2])
369+
@test [0 0; 0 5] == unthunk(rrule(findmax, [1 2; 3 4])[2]((5.0, NoTangent()))[2])
369370

370371
# Reverse with dims:
371372
@test [0 0; 5 6] == @inferred unthunk(rrule(findmax, [1 2; 3 4], dims=1)[2](([5 6], nothing))[2])
@@ -385,7 +386,7 @@ end
385386

386387
# Reverse
387388
test_rrule(imum, rand(10))
388-
test_rrule(imum, rand(3,4))
389+
test_rrule(imum, rand(3,4); check_inferred=false)
389390
@gpu test_rrule(imum, rand(3,4), fkwargs=(dims=1,))
390391
test_rrule(imum, rand(3,4,5), fkwargs=(dims=(1,3),))
391392

test/rulesets/Base/indexing.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@
3434

3535
@testset "single element" begin
3636
test_rrule(getindex, x, 2)
37-
test_rrule(getindex, x, 2, 1)
38-
test_rrule(getindex, x, 2, 2)
37+
test_rrule(getindex, x, 2, 1; check_inferred=false)
38+
test_rrule(getindex, x, 2, 2; check_inferred=false)
3939

40-
test_rrule(getindex, x, CartesianIndex(2, 3))
40+
test_rrule(getindex, x, CartesianIndex(2, 3); check_inferred=false)
4141
end
4242

4343
@testset "slice/index positions" begin
@@ -87,7 +87,7 @@
8787
dgrad = rrule(getindex, Diagonal(rand(3)), 2, :)[2]([1,2,3])[2]
8888
@test unthunk(dgrad) Diagonal([0, 2, 0])
8989

90-
test_rrule(getindex, Symmetric(rand(3, 3)), 2, 2)
90+
test_rrule(getindex, Symmetric(rand(3, 3)), 2, 2; check_inferred=false) # Infers to Any
9191
sgrad = rrule(getindex, Symmetric(rand(3, 3)), 2, 3)[2](1.0)[2]
9292
@test unthunk(sgrad) [0 0 0; 0 0 1/2; 0 1/2 0]
9393
end

0 commit comments

Comments
 (0)