Skip to content

Commit

Permalink
Fix getindex (#202)
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored Jun 5, 2022
1 parent 553ea01 commit 4a50321
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 63 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ReverseDiff"
uuid = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
version = "1.14.0"
version = "1.14.1"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
29 changes: 14 additions & 15 deletions src/tracked.jl
Original file line number Diff line number Diff line change
Expand Up @@ -280,8 +280,6 @@ Base.promote_rule(::Type{TrackedReal{V1,D1,O1}}, ::Type{TrackedReal{V2,D2,O2}})
# AbstractArray Interface #
###########################

Base.getindex(t::TrackedArray, i::Int) = TrackedReal(value(t)[i], deriv(t)[i], tape(t), i, t)

colon2range(s, i) = i
colon2range(s, ::Colon) = s

Expand All @@ -296,10 +294,10 @@ function index_iterable(shape::NTuple{N,Any}, i::NTuple{M,Any}) where {N,M}
end

for T in (:AbstractRange, :Colon, :(Union{Colon,AbstractRange}))
@eval function Base.getindex(t::TrackedArray, i::$(T)...)
@eval Base.@propagate_inbounds function Base.getindex(t::TrackedArray, i1::$(T), is::$(T)...)
tp = tape(t)
out = TrackedArray(value(t)[i...], deriv(t)[i...], tp)
idx = index_iterable(axes(t), i)
out = TrackedArray(value(t)[i1, is...], deriv(t)[i1, is...], tp)
idx = index_iterable(axes(t), (i1, is...))
record!(tp, SpecialInstruction, getindex, (t, idx), out)
return out
end
Expand Down Expand Up @@ -329,24 +327,25 @@ end
return nothing
end

function Base.getindex(t::TrackedArray, inds::AbstractArray{<:CartesianIndex})
Base.@propagate_inbounds function Base.getindex(t::TrackedArray, inds::AbstractArray{<:CartesianIndex})
tp = tape(t)
out = TrackedArray(value(t)[inds], deriv(t)[inds], tp)
record!(tp, SpecialInstruction, getindex, (t, inds), out)
return out
end
function Base.getindex(t::TrackedArray, i::Int...)
ind = LinearIndices(t)[i...]
return TrackedReal(value(t)[i...], deriv(t)[i...], tape(t), ind, t)
Base.@propagate_inbounds function Base.getindex(t::TrackedArray, i1::Integer, is::Integer...)
ind = LinearIndices(t)[i1, is...]
return TrackedReal(value(t)[i1, is...], deriv(t)[i1, is...], tape(t), ind, t)
end
function Base.getindex(t::TrackedArray, _inds::Union{Integer, Colon, AbstractArray{<:Integer}}...)
inds = ntuple(Val(length(_inds))) do i
_inds[i] isa Colon && return firstindex(t,i):lastindex(t,i)
return _inds[i]
Base.@propagate_inbounds function Base.getindex(t::TrackedArray, _inds1::Union{Integer, Colon, AbstractArray{<:Integer}}, _inds2::Union{Integer, Colon, AbstractArray{<:Integer}}...)
inds1 = _inds1 isa Colon ? axes(t, 1) : _inds1
inds2 = ntuple(Val(length(_inds2))) do i
_inds2[i] isa Colon && return axes(t, i+1)
return _inds2[i]
end
tp = tape(t)
out = TrackedArray(value(t)[inds...], deriv(t)[inds...], tp)
record!(tp, SpecialInstruction, (getindex, Val(:generic)), (t, inds), out)
out = TrackedArray(value(t)[inds1, inds2...], deriv(t)[inds1, inds2...], tp)
record!(tp, SpecialInstruction, (getindex, Val(:generic)), (t, (inds1, inds2...)), out)
return out
end
@noinline function special_reverse_exec!(instruction::SpecialInstruction{<:Tuple{typeof(getindex), Val{:generic}}})
Expand Down
104 changes: 57 additions & 47 deletions test/TrackedTests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,12 @@ ta = TrackedArray(varr, darr, tp)

@test isa(similar(ta), Matrix{eltype(ta)})

@test samefields(ta[2], TrackedReal(varr[2], darr[2], tp, 2, ta))
for T in (UInt, Int)
@test samefields(ta[T(2)], TrackedReal(varr[2], darr[2], tp, 2, ta))
@test samefields(ta[T(2), T(3)], TrackedReal(varr[2, 3], darr[2, 3], tp, 8, ta))
S = T === UInt ? Int : UInt
@test samefields(ta[S(2), T(3)], TrackedReal(varr[2, 3], darr[2, 3], tp, 8, ta))
end

ta_sub = ta[:,:]
idx = ReverseDiff.index_iterable(axes(ta), (:, :))
Expand All @@ -630,53 +635,58 @@ instr = tp[1]
@test instr.cache === nothing
empty!(tp)

ta_sub = ta[:,1:2]
idx = ReverseDiff.index_iterable(axes(ta), (:, 1:2))
@test collect(idx) == [(i, j) for i in 1:3, j in 1:2]
@test samefields(ta_sub, TrackedArray(varr[:,1:2], darr[:,1:2], tp))
@test length(tp) == 1
instr = tp[1]
@test instr.func === getindex
@test instr.input === (ta, idx)
@test samefields(instr.output, TrackedArray(varr[:,1:2], darr[:,1:2], tp))
@test instr.cache === nothing
empty!(tp)

ta_sub = ta[2:3,:]
idx = ReverseDiff.index_iterable(axes(ta), (2:3, :))
@test collect(idx) == [(i, j) for i in 2:3, j in 1:3]
@test samefields(ta_sub, TrackedArray(varr[2:3,:], darr[2:3,:], tp))
@test length(tp) == 1
instr = tp[1]
@test instr.func === getindex
@test instr.input === (ta, idx)
@test samefields(instr.output, TrackedArray(varr[2:3,:], darr[2:3,:], tp))
@test instr.cache === nothing
empty!(tp)

ta_sub = ta[1:2,2:3]
idx = ReverseDiff.index_iterable(axes(ta), (1:2, 2:3))
@test collect(idx) == [(i, j) for i in 1:2, j in 2:3]
@test samefields(ta_sub, TrackedArray(varr[1:2,2:3], darr[1:2,2:3], tp))
@test length(tp) == 1
instr = tp[1]
@test instr.func === getindex
@test instr.input === (ta, idx)
@test samefields(instr.output, TrackedArray(varr[1:2,2:3], darr[1:2,2:3], tp))
@test instr.cache === nothing
empty!(tp)
for T in (UInt, Int)
ta_sub = ta[:,T(1):T(2)]
idx = ReverseDiff.index_iterable(axes(ta), (:, T(1):T(2)))
@test collect(idx) == [(i, j) for i in 1:3, j in 1:2]
@test samefields(ta_sub, TrackedArray(varr[:,1:2], darr[:,1:2], tp))
@test length(tp) == 1
instr = tp[1]
@test instr.func === getindex
@test instr.input === (ta, idx)
@test samefields(instr.output, TrackedArray(varr[:,1:2], darr[:,1:2], tp))
@test instr.cache === nothing
empty!(tp)

ta_sub = ta[T(2):T(3),:]
idx = ReverseDiff.index_iterable(axes(ta), (T(2):T(3), :))
@test collect(idx) == [(i, j) for i in 2:3, j in 1:3]
@test samefields(ta_sub, TrackedArray(varr[2:3,:], darr[2:3,:], tp))
@test length(tp) == 1
instr = tp[1]
@test instr.func === getindex
@test instr.input === (ta, idx)
@test samefields(instr.output, TrackedArray(varr[2:3,:], darr[2:3,:], tp))
@test instr.cache === nothing
empty!(tp)

S = T === UInt ? Int : UInt
for U in (S, T)
ta_sub = ta[S(1):S(2),T(2):T(3)]
idx = ReverseDiff.index_iterable(axes(ta), (S(1):S(2), T(2):T(3)))
@test collect(idx) == [(i, j) for i in 1:2, j in 2:3]
@test samefields(ta_sub, TrackedArray(varr[1:2,2:3], darr[1:2,2:3], tp))
@test length(tp) == 1
instr = tp[1]
@test instr.func === getindex
@test instr.input === (ta, idx)
@test samefields(instr.output, TrackedArray(varr[1:2,2:3], darr[1:2,2:3], tp))
@test instr.cache === nothing
empty!(tp)
end

ta_sub = ta[2:6]
idx = ReverseDiff.index_iterable(axes(ta), (2:6,))
@test collect(idx) == [(i,) for i in 2:6]
@test samefields(ta_sub, TrackedArray(varr[2:6], darr[2:6], tp))
@test length(tp) == 1
instr = tp[1]
@test instr.func === getindex
@test instr.input === (ta, idx)
@test samefields(instr.output, TrackedArray(varr[2:6], darr[2:6], tp))
@test instr.cache === nothing
empty!(tp)
ta_sub = ta[T(2):T(6)]
idx = ReverseDiff.index_iterable(axes(ta), (T(2):T(6),))
@test collect(idx) == [(i,) for i in 2:6]
@test samefields(ta_sub, TrackedArray(varr[2:6], darr[2:6], tp))
@test length(tp) == 1
instr = tp[1]
@test instr.func === getindex
@test instr.input === (ta, idx)
@test samefields(instr.output, TrackedArray(varr[2:6], darr[2:6], tp))
@test instr.cache === nothing
empty!(tp)
end

ta_sub = ta[:]
idx = ReverseDiff.index_iterable(axes(ta), (:,))
Expand Down

2 comments on commit 4a50321

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/61749

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.14.1 -m "<description of version>" 4a50321ab03eb59271d83863d420bdc73ac0f869
git push origin v1.14.1

Please sign in to comment.