Skip to content

Commit fe9ac99

Browse files
authored
Add sortperm with dims arg for AbstractArray, fixes #16273 (#45211)
1 parent e4c1b54 commit fe9ac99

File tree

2 files changed

+81
-32
lines changed

2 files changed

+81
-32
lines changed

base/sort.jl

Lines changed: 62 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ using .Base: copymutable, LinearIndices, length, (:), iterate, OneTo,
1111
AbstractMatrix, AbstractUnitRange, isless, identity, eltype, >, <, <=, >=, |, +, -, *, !,
1212
extrema, sub_with_overflow, add_with_overflow, oneunit, div, getindex, setindex!,
1313
length, resize!, fill, Missing, require_one_based_indexing, keytype, UnitRange,
14-
min, max, reinterpret, signed, unsigned, Signed, Unsigned, typemin, xor, Type, BitSigned
14+
min, max, reinterpret, signed, unsigned, Signed, Unsigned, typemin, xor, Type, BitSigned, Val
1515

1616
using .Base: >>>, !==
1717

@@ -1091,14 +1091,16 @@ end
10911091
## sortperm: the permutation to sort an array ##
10921092

10931093
"""
1094-
sortperm(v; alg::Algorithm=DEFAULT_UNSTABLE, lt=isless, by=identity, rev::Bool=false, order::Ordering=Forward)
1094+
sortperm(A; alg::Algorithm=DEFAULT_UNSTABLE, lt=isless, by=identity, rev::Bool=false, order::Ordering=Forward, [dims::Integer])
10951095
1096-
Return a permutation vector `I` that puts `v[I]` in sorted order. The order is specified
1096+
Return a permutation vector or array `I` that puts `A[I]` in sorted order along the given dimension.
1097+
If `A` has more than one dimension, then the `dims` keyword argument must be specified. The order is specified
10971098
using the same keywords as [`sort!`](@ref). The permutation is guaranteed to be stable even
10981099
if the sorting algorithm is unstable, meaning that indices of equal elements appear in
10991100
ascending order.
11001101
11011102
See also [`sortperm!`](@ref), [`partialsortperm`](@ref), [`invperm`](@ref), [`indexin`](@ref).
1103+
To sort slices of an array, refer to [`sortslices`](@ref).
11021104
11031105
# Examples
11041106
```jldoctest
@@ -1115,37 +1117,53 @@ julia> v[p]
11151117
1
11161118
2
11171119
3
1120+
1121+
julia> A = [8 7; 5 6]
1122+
2×2 Matrix{Int64}:
1123+
8 7
1124+
5 6
1125+
1126+
julia> sortperm(A, dims = 1)
1127+
2×2 Matrix{Int64}:
1128+
2 4
1129+
1 3
1130+
1131+
julia> sortperm(A, dims = 2)
1132+
2×2 Matrix{Int64}:
1133+
3 1
1134+
2 4
11181135
```
11191136
"""
1120-
function sortperm(v::AbstractVector;
1137+
function sortperm(A::AbstractArray;
11211138
alg::Algorithm=DEFAULT_UNSTABLE,
11221139
lt=isless,
11231140
by=identity,
11241141
rev::Union{Bool,Nothing}=nothing,
11251142
order::Ordering=Forward,
1126-
workspace::Union{AbstractVector{<:Integer}, Nothing}=nothing)
1143+
workspace::Union{AbstractVector{<:Integer}, Nothing}=nothing,
1144+
dims...) #to optionally specify dims argument
11271145
ordr = ord(lt,by,rev,order)
1128-
if ordr === Forward && isa(v,Vector) && eltype(v)<:Integer
1129-
n = length(v)
1146+
if ordr === Forward && isa(A,Vector) && eltype(A)<:Integer
1147+
n = length(A)
11301148
if n > 1
1131-
min, max = extrema(v)
1149+
min, max = extrema(A)
11321150
(diff, o1) = sub_with_overflow(max, min)
11331151
(rangelen, o2) = add_with_overflow(diff, oneunit(diff))
11341152
if !o1 && !o2 && rangelen < div(n,2)
1135-
return sortperm_int_range(v, rangelen, min)
1153+
return sortperm_int_range(A, rangelen, min)
11361154
end
11371155
end
11381156
end
1139-
p = copymutable(eachindex(v))
1140-
sort!(p, alg, Perm(ordr,v), workspace)
1157+
ix = copymutable(LinearIndices(A))
1158+
sort!(ix; alg, order = Perm(ordr, vec(A)), workspace, dims...)
11411159
end
11421160

11431161

11441162
"""
1145-
sortperm!(ix, v; alg::Algorithm=DEFAULT_UNSTABLE, lt=isless, by=identity, rev::Bool=false, order::Ordering=Forward, initialized::Bool=false)
1163+
sortperm!(ix, A; alg::Algorithm=DEFAULT_UNSTABLE, lt=isless, by=identity, rev::Bool=false, order::Ordering=Forward, initialized::Bool=false, [dims::Integer])
11461164
1147-
Like [`sortperm`](@ref), but accepts a preallocated index vector `ix`. If `initialized` is `false`
1148-
(the default), `ix` is initialized to contain the values `1:length(v)`.
1165+
Like [`sortperm`](@ref), but accepts a preallocated index vector or array `ix` with the same `axes` as `A`. If `initialized` is `false`
1166+
(the default), `ix` is initialized to contain the values `LinearIndices(A)`.
11491167
11501168
# Examples
11511169
```jldoctest
@@ -1162,25 +1180,36 @@ julia> v[p]
11621180
1
11631181
2
11641182
3
1183+
1184+
julia> A = [8 7; 5 6]; p = zeros(Int,2, 2);
1185+
1186+
julia> sortperm!(p, A; dims=1); p
1187+
2×2 Matrix{Int64}:
1188+
2 4
1189+
1 3
1190+
1191+
julia> sortperm!(p, A; dims=2); p
1192+
2×2 Matrix{Int64}:
1193+
3 1
1194+
2 4
11651195
```
11661196
"""
1167-
function sortperm!(x::AbstractVector{T}, v::AbstractVector;
1197+
function sortperm!(ix::AbstractArray{T}, A::AbstractArray;
11681198
alg::Algorithm=DEFAULT_UNSTABLE,
11691199
lt=isless,
11701200
by=identity,
11711201
rev::Union{Bool,Nothing}=nothing,
11721202
order::Ordering=Forward,
11731203
initialized::Bool=false,
1174-
workspace::Union{AbstractVector{T}, Nothing}=nothing) where T <: Integer
1175-
if axes(x,1) != axes(v,1)
1176-
throw(ArgumentError("index vector must have the same length/indices as the source vector, $(axes(x,1)) != $(axes(v,1))"))
1177-
end
1204+
workspace::Union{AbstractVector{T}, Nothing}=nothing,
1205+
dims...) where T <: Integer #to optionally specify dims argument
1206+
(typeof(A) <: AbstractVector) == (:dims in keys(dims)) && throw(ArgumentError("Dims argument incorrect for type $(typeof(A))"))
1207+
axes(ix) == axes(A) || throw(ArgumentError("index array must have the same size/axes as the source array, $(axes(ix)) != $(axes(A))"))
1208+
11781209
if !initialized
1179-
@inbounds for i in eachindex(v)
1180-
x[i] = i
1181-
end
1210+
ix .= LinearIndices(A)
11821211
end
1183-
sort!(x, alg, Perm(ord(lt,by,rev,order),v), workspace)
1212+
sort!(ix; alg, order = Perm(ord(lt, by, rev, order), vec(A)), workspace, dims...)
11841213
end
11851214

11861215
# sortperm for vectors of few unique integers
@@ -1307,16 +1336,20 @@ function sort!(A::AbstractArray{T};
13071336
rev::Union{Bool,Nothing}=nothing,
13081337
order::Ordering=Forward,
13091338
workspace::Union{AbstractVector{T}, Nothing}=similar(A, size(A, dims))) where T
1310-
ordr = ord(lt, by, rev, order)
1339+
_sort!(A, Val(dims), alg, ord(lt, by, rev, order), workspace)
1340+
end
1341+
function _sort!(A::AbstractArray{T}, ::Val{K},
1342+
alg::Algorithm,
1343+
order::Ordering,
1344+
workspace::Union{AbstractVector{T}, Nothing}) where {K,T}
13111345
nd = ndims(A)
1312-
k = dims
13131346

1314-
1 <= k <= nd || throw(ArgumentError("dimension out of range"))
1347+
1 <= K <= nd || throw(ArgumentError("dimension out of range"))
13151348

1316-
remdims = ntuple(i -> i == k ? 1 : axes(A, i), nd)
1349+
remdims = ntuple(i -> i == K ? 1 : axes(A, i), nd)
13171350
for idx in CartesianIndices(remdims)
1318-
Av = view(A, ntuple(i -> i == k ? Colon() : idx[i], nd)...)
1319-
sort!(Av, alg, ordr, workspace)
1351+
Av = view(A, ntuple(i -> i == K ? Colon() : idx[i], nd)...)
1352+
sort!(Av, alg, order, workspace)
13201353
end
13211354
A
13221355
end

test/sorting.jl

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,25 @@ end
4747
@test r == [3,1,2]
4848
@test r === s
4949
end
50-
@test_throws ArgumentError sortperm!(view([1,2,3,4], 1:4), [2,3,1])
51-
@test sortperm(OffsetVector([8.0,-2.0,0.5], -4)) == OffsetVector([-2, -1, -3], -4)
52-
@test sortperm!(Int32[1,2], [2.0, 1.0]) == Int32[2, 1]
50+
@test_throws ArgumentError sortperm!(view([1, 2, 3, 4], 1:4), [2, 3, 1])
51+
@test sortperm(OffsetVector([8.0, -2.0, 0.5], -4)) == OffsetVector([-2, -1, -3], -4)
52+
@test sortperm!(Int32[1, 2], [2.0, 1.0]) == Int32[2, 1]
53+
@test_throws ArgumentError sortperm!(Int32[1, 2], [2.0, 1.0]; dims=1)
54+
let A = rand(4, 4, 4)
55+
for dims = 1:3
56+
perm = sortperm(A; dims)
57+
sorted = sort(A; dims)
58+
@test A[perm] == sorted
59+
60+
perm_idx = similar(Array{Int}, axes(A))
61+
sortperm!(perm_idx, A; dims)
62+
@test perm_idx == perm
63+
end
64+
end
65+
@test_throws ArgumentError sortperm!(zeros(Int, 3, 3), rand(3, 3);)
66+
@test_throws ArgumentError sortperm!(zeros(Int, 3, 3), rand(3, 3); dims=3)
67+
@test_throws ArgumentError sortperm!(zeros(Int, 3, 4), rand(4, 4); dims=1)
68+
@test_throws ArgumentError sortperm!(OffsetArray(zeros(Int, 4, 4), -4:-1, 1:4), rand(4, 4); dims=1)
5369
end
5470

5571
@testset "misc sorting" begin

0 commit comments

Comments
 (0)