diff --git a/Project.toml b/Project.toml index 82f4d6616..77f514046 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.58.0" +version = "1.58.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/rulesets/Base/indexing.jl b/src/rulesets/Base/indexing.jl index 2f5e6cf79..830571ecd 100644 --- a/src/rulesets/Base/indexing.jl +++ b/src/rulesets/Base/indexing.jl @@ -144,7 +144,7 @@ end ChainRules.@non_differentiable _setindex_zero(x::AbstractArray, dy::Any, inds::Any...) function ∇getindex!(dx::AbstractArray, dy, inds::Integer...) - view(dx, inds...) .+= Ref(dy) + @views dx[inds...] += dy return dx end function ∇getindex!(dx::AbstractArray, dy, inds...) diff --git a/test/rulesets/Base/indexing.jl b/test/rulesets/Base/indexing.jl index a677df3b9..e878dd061 100644 --- a/test/rulesets/Base/indexing.jl +++ b/test/rulesets/Base/indexing.jl @@ -177,6 +177,14 @@ end @test Array(y3) == Array(x_23_gpu)[1, [1,1,2]] @test unthunk(bk3(jl(ones(3)))[2]) == jl([2 1 0; 0 0 0]) end + + @testset "getindex(::Array{<:AbstractGPUArray})" begin + x_gpu = jl(rand(1)) + y, back = rrule(getindex, [x_gpu], 1) + @test y === x_gpu + dxs_gpu = unthunk(back(jl([1.0]))[2]) + @test dxs_gpu == [jl([1.0])] + end end # first & tail handled by getfield rules