Skip to content

Commit

Permalink
Merge pull request #49 from JuliaGPU/tb/logicalindices_count
Browse files Browse the repository at this point in the history
Optimize LogicalIndices construction.
  • Loading branch information
maleadt authored Jan 6, 2022
2 parents 23a1541 + e493a03 commit e2aa20e
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ export WrappedArray

adapt_structure(to, A::SubArray) =
SubArray(adapt(to, Base.parent(A)), adapt(to, parentindices(A)))
adapt_structure(to, A::Base.LogicalIndex) =
Base.LogicalIndex(adapt(to, A.mask))
adapt_structure(to, A::PermutedDimsArray) =
PermutedDimsArray(adapt(to, Base.parent(A)), permutation(A))
adapt_structure(to, A::Base.ReshapedArray) =
Expand All @@ -24,6 +22,11 @@ else
adapt_structure(to, A::Base.ReinterpretArray) =
Base.reinterpret(Base.eltype(A), adapt(to, Base.parent(A)))
end
@eval function adapt_structure(to, A::Base.LogicalIndex{T}) where T
# prevent re-calculating the count of booleans during LogicalIndex construction
mask = adapt(to, A.mask)
$(Expr(:new, :(Base.LogicalIndex{T, typeof(mask)}), :mask, :(A.sum)))
end

adapt_structure(to, A::LinearAlgebra.Adjoint) =
LinearAlgebra.adjoint(adapt(to, Base.parent(A)))
Expand Down

0 comments on commit e2aa20e

Please sign in to comment.