From e493a031beee4d0f5fc41a5360d5800ed528a8b5 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 5 Jan 2022 17:47:18 +0100 Subject: [PATCH] Optimize LogicalIndices construction. Copy the old sum field instead of recalculating it. --- src/wrappers.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/wrappers.jl b/src/wrappers.jl index b103569..42babad 100644 --- a/src/wrappers.jl +++ b/src/wrappers.jl @@ -9,8 +9,6 @@ export WrappedArray adapt_structure(to, A::SubArray) = SubArray(adapt(to, Base.parent(A)), adapt(to, parentindices(A))) -adapt_structure(to, A::Base.LogicalIndex) = - Base.LogicalIndex(adapt(to, A.mask)) adapt_structure(to, A::PermutedDimsArray) = PermutedDimsArray(adapt(to, Base.parent(A)), permutation(A)) adapt_structure(to, A::Base.ReshapedArray) = @@ -24,6 +22,11 @@ else adapt_structure(to, A::Base.ReinterpretArray) = Base.reinterpret(Base.eltype(A), adapt(to, Base.parent(A))) end +@eval function adapt_structure(to, A::Base.LogicalIndex{T}) where T + # prevent re-calculating the count of booleans during LogicalIndex construction + mask = adapt(to, A.mask) + $(Expr(:new, :(Base.LogicalIndex{T, typeof(mask)}), :mask, :(A.sum))) +end adapt_structure(to, A::LinearAlgebra.Adjoint) = LinearAlgebra.adjoint(adapt(to, Base.parent(A)))