Skip to content

Commit

Permalink
Rewrite sortslices from scratch and add inference tests (#52039)
Browse files Browse the repository at this point in the history
  • Loading branch information
LilithHafner authored Nov 14, 2023
1 parent 9754dbb commit 2d449d4
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 29 deletions.
44 changes: 15 additions & 29 deletions base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1897,39 +1897,25 @@ julia> sortslices(reshape([5; 4; 3; 2; 1], (1,1,5)), dims=3, by=x->x[1,1])
```
"""
function sortslices(A::AbstractArray; dims::Union{Integer, Tuple{Vararg{Integer}}}, kws...)
_sortslices(A, Val{dims}(); kws...)
end
if A isa Matrix && dims isa Integer && dims == 1
# TODO: remove once the generic version becomes as fast or faster
perm = sortperm(eachslice(A; dims); kws...)
return A[perm, :]
end

# Works around inference's lack of ability to recognize partial constness
struct DimSelector{dims, T}
A::T
B = similar(A)
_sortslices!(B, A, Val{dims}(); kws...)
B
end
DimSelector{dims}(x::T) where {dims, T} = DimSelector{dims, T}(x)
(ds::DimSelector{dims, T})(i) where {dims, T} = i in dims ? axes(ds.A, i) : (:,)

_negdims(n, dims) = filter(i->!(i in dims), 1:n)

function compute_itspace(A, ::Val{dims}) where {dims}
negdims = _negdims(ndims(A), dims)
axs = Iterators.product(ntuple(DimSelector{dims}(A), ndims(A))...)
vec(permutedims(collect(axs), (dims..., negdims...)))
end
function _sortslices!(B, A, ::Val{dims}; kws...) where dims
ves = vec(eachslice(A; dims))
perm = sortperm(ves; kws...)
bes = eachslice(B; dims)

function _sortslices(A::AbstractArray, d::Val{dims}; kws...) where dims
itspace = compute_itspace(A, d)
vecs = map(its->view(A, its...), itspace)
p = sortperm(vecs; kws...)
if ndims(A) == 2 && isa(dims, Integer) && isa(A, Array)
# At the moment, the performance of the generic version is subpar
# (about 5x slower). Hardcode a fast-path until we're able to
# optimize this.
return dims == 1 ? A[p, :] : A[:, p]
else
B = similar(A)
for (x, its) in zip(p, itspace)
B[its...] = vecs[x]
end
B
# TODO for further optimization: traverse in memory order
for (slice, i) in zip(eachslice(B; dims), perm)
slice .= ves[i]
end
end

Expand Down
9 changes: 9 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1446,6 +1446,15 @@ end
@test sortslices(B, dims=(1,3)) == B
end

@testset "sortslices inference (#52019)" begin
x = rand(3, 2)
@inferred sortslices(x, dims=1)
@inferred sortslices(x, dims=(2,))
x = rand(1, 2, 3)
@inferred sortslices(x, dims=(1,2))
@inferred sortslices(x, dims=3, by=sum)
end

@testset "fill" begin
@test fill!(Float64[1.0], -0.0)[1] === -0.0
A = fill(1.,3,3)
Expand Down

0 comments on commit 2d449d4

Please sign in to comment.