diff --git a/base/multidimensional.jl b/base/multidimensional.jl index d8646e4c261ed..0ca2fbc36e6df 100644 --- a/base/multidimensional.jl +++ b/base/multidimensional.jl @@ -1897,39 +1897,25 @@ julia> sortslices(reshape([5; 4; 3; 2; 1], (1,1,5)), dims=3, by=x->x[1,1]) ``` """ function sortslices(A::AbstractArray; dims::Union{Integer, Tuple{Vararg{Integer}}}, kws...) - _sortslices(A, Val{dims}(); kws...) -end + if A isa Matrix && dims isa Integer && dims == 1 + # TODO: remove once the generic version becomes as fast or faster + perm = sortperm(eachslice(A; dims); kws...) + return A[perm, :] + end -# Works around inference's lack of ability to recognize partial constness -struct DimSelector{dims, T} - A::T + B = similar(A) + _sortslices!(B, A, Val{dims}(); kws...) + B end -DimSelector{dims}(x::T) where {dims, T} = DimSelector{dims, T}(x) -(ds::DimSelector{dims, T})(i) where {dims, T} = i in dims ? axes(ds.A, i) : (:,) - -_negdims(n, dims) = filter(i->!(i in dims), 1:n) -function compute_itspace(A, ::Val{dims}) where {dims} - negdims = _negdims(ndims(A), dims) - axs = Iterators.product(ntuple(DimSelector{dims}(A), ndims(A))...) - vec(permutedims(collect(axs), (dims..., negdims...))) -end +function _sortslices!(B, A, ::Val{dims}; kws...) where dims + ves = vec(eachslice(A; dims)) + perm = sortperm(ves; kws...) + bes = eachslice(B; dims) -function _sortslices(A::AbstractArray, d::Val{dims}; kws...) where dims - itspace = compute_itspace(A, d) - vecs = map(its->view(A, its...), itspace) - p = sortperm(vecs; kws...) - if ndims(A) == 2 && isa(dims, Integer) && isa(A, Array) - # At the moment, the performance of the generic version is subpar - # (about 5x slower). Hardcode a fast-path until we're able to - # optimize this. - return dims == 1 ? A[p, :] : A[:, p] - else - B = similar(A) - for (x, its) in zip(p, itspace) - B[its...] = vecs[x] - end - B + # TODO for further optimization: traverse in memory order + for (slice, i) in zip(eachslice(B; dims), perm) + slice .= ves[i] end end diff --git a/test/arrayops.jl b/test/arrayops.jl index 30f00ca5d5a46..9c9aac987929d 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -1446,6 +1446,15 @@ end @test sortslices(B, dims=(1,3)) == B end +@testset "sortslices inference (#52019)" begin + x = rand(3, 2) + @inferred sortslices(x, dims=1) + @inferred sortslices(x, dims=(2,)) + x = rand(1, 2, 3) + @inferred sortslices(x, dims=(1,2)) + @inferred sortslices(x, dims=3, by=sum) +end + @testset "fill" begin @test fill!(Float64[1.0], -0.0)[1] === -0.0 A = fill(1.,3,3)