Skip to content

Commit 7c1c298

Browse files
committed
use pairs in findmin and findmax, supporting all indexable collections
return `CartesianIndex` for n-d arrays in findmin, findmax, indmin, indmax more compact printing of `CartesianIndex` change sparse `_findr` macro to a function
1 parent baa9a70 commit 7c1c298

File tree

8 files changed

+130
-119
lines changed

8 files changed

+130
-119
lines changed

NEWS.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,10 @@ This section lists changes that do not have deprecation warnings.
191191
This avoids stack overflows in the common case of definitions like
192192
`f(x, y) = f(promote(x, y)...)` ([#22801]).
193193

194+
* `findmin`, `findmax`, `indmin`, and `indmax` used to always return linear indices.
195+
They now return `CartesianIndex`es for all but 1-d arrays, and in general return
196+
the `keys` of indexed collections (e.g. dictionaries) ([#22907]).
197+
194198
Library improvements
195199
--------------------
196200

base/array.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2072,13 +2072,13 @@ function findmax(a)
20722072
if isempty(a)
20732073
throw(ArgumentError("collection must be non-empty"))
20742074
end
2075-
s = start(a)
2076-
mi = i = 1
2077-
m, s = next(a, s)
2078-
while !done(a, s)
2075+
p = pairs(a)
2076+
s = start(p)
2077+
(mi, m), s = next(p, s)
2078+
i = mi
2079+
while !done(p, s)
20792080
m != m && break
2080-
ai, s = next(a, s)
2081-
i += 1
2081+
(i, ai), s = next(p, s)
20822082
if ai != ai || isless(m, ai)
20832083
m = ai
20842084
mi = i
@@ -2113,13 +2113,13 @@ function findmin(a)
21132113
if isempty(a)
21142114
throw(ArgumentError("collection must be non-empty"))
21152115
end
2116-
s = start(a)
2117-
mi = i = 1
2118-
m, s = next(a, s)
2119-
while !done(a, s)
2116+
p = pairs(a)
2117+
s = start(p)
2118+
(mi, m), s = next(p, s)
2119+
i = mi
2120+
while !done(p, s)
21202121
m != m && break
2121-
ai, s = next(a, s)
2122-
i += 1
2122+
(i, ai), s = next(p, s)
21232123
if ai != ai || isless(ai, m)
21242124
m = ai
21252125
mi = i

base/multidimensional.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
module IteratorsMD
55
import Base: eltype, length, size, start, done, next, first, last, in, getindex,
66
setindex!, IndexStyle, min, max, zero, one, isless, eachindex,
7-
ndims, iteratorsize, convert
7+
ndims, iteratorsize, convert, show
88

99
import Base: +, -, *
1010
import Base: simd_outer_range, simd_inner_length, simd_index
@@ -80,6 +80,7 @@ module IteratorsMD
8080
@inline _flatten(i, I...) = (i, _flatten(I...)...)
8181
@inline _flatten(i::CartesianIndex, I...) = (i.I..., _flatten(I...)...)
8282
CartesianIndex(index::Tuple{Vararg{Union{Integer, CartesianIndex}}}) = CartesianIndex(index...)
83+
show(io::IO, i::CartesianIndex) = (print(io, "CartesianIndex"); show(io, i.I))
8384

8485
# length
8586
length(::CartesianIndex{N}) where {N} = N

base/reducedim.jl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -635,20 +635,22 @@ function findminmax!(f, Rval, Rind, A::AbstractArray{T,N}) where {T,N}
635635
# Otherwise, keep the result in Rval/Rind so that we traverse A in storage order.
636636
indsAt, indsRt = safe_tail(indices(A)), safe_tail(indices(Rval))
637637
keep, Idefault = Broadcast.shapeindexer(indsAt, indsRt)
638-
k = 0
638+
ks = keys(A)
639+
k, kss = next(ks, start(ks))
640+
zi = zero(eltype(ks))
639641
if reducedim1(Rval, A)
640642
i1 = first(indices1(Rval))
641643
@inbounds for IA in CartesianRange(indsAt)
642644
IR = Broadcast.newindex(IA, keep, Idefault)
643645
tmpRv = Rval[i1,IR]
644646
tmpRi = Rind[i1,IR]
645647
for i in indices(A,1)
646-
k += 1
647648
tmpAv = A[i,IA]
648-
if tmpRi == 0 || (tmpRv == tmpRv && (tmpAv != tmpAv || f(tmpAv, tmpRv)))
649+
if tmpRi == zi || (tmpRv == tmpRv && (tmpAv != tmpAv || f(tmpAv, tmpRv)))
649650
tmpRv = tmpAv
650651
tmpRi = k
651652
end
653+
k, kss = next(ks, kss)
652654
end
653655
Rval[i1,IR] = tmpRv
654656
Rind[i1,IR] = tmpRi
@@ -657,14 +659,14 @@ function findminmax!(f, Rval, Rind, A::AbstractArray{T,N}) where {T,N}
657659
@inbounds for IA in CartesianRange(indsAt)
658660
IR = Broadcast.newindex(IA, keep, Idefault)
659661
for i in indices(A, 1)
660-
k += 1
661662
tmpAv = A[i,IA]
662663
tmpRv = Rval[i,IR]
663664
tmpRi = Rind[i,IR]
664-
if tmpRi == 0 || (tmpRv == tmpRv && (tmpAv != tmpAv || f(tmpAv, tmpRv)))
665+
if tmpRi == zi || (tmpRv == tmpRv && (tmpAv != tmpAv || f(tmpAv, tmpRv)))
665666
Rval[i,IR] = tmpAv
666667
Rind[i,IR] = k
667668
end
669+
k, kss = next(ks, kss)
668670
end
669671
end
670672
end
@@ -680,7 +682,7 @@ dimensions of `rval` and `rind`, and store the results in `rval` and `rind`.
680682
"""
681683
function findmin!(rval::AbstractArray, rind::AbstractArray, A::AbstractArray;
682684
init::Bool=true)
683-
findminmax!(isless, init && !isempty(A) ? fill!(rval, first(A)) : rval, fill!(rind,0), A)
685+
findminmax!(isless, init && !isempty(A) ? fill!(rval, first(A)) : rval, fill!(rind,zero(eltype(keys(A)))), A)
684686
end
685687

686688
"""
@@ -709,10 +711,10 @@ function findmin(A::AbstractArray{T}, region) where T
709711
if prod(map(length, reduced_indices(A, region))) != 0
710712
throw(ArgumentError("collection slices must be non-empty"))
711713
end
712-
(similar(A, ri), similar(dims->zeros(Int, dims), ri))
714+
(similar(A, ri), similar(dims->zeros(eltype(keys(A)), dims), ri))
713715
else
714716
findminmax!(isless, fill!(similar(A, ri), first(A)),
715-
similar(dims->zeros(Int, dims), ri), A)
717+
similar(dims->zeros(eltype(keys(A)), dims), ri), A)
716718
end
717719
end
718720

@@ -727,7 +729,7 @@ dimensions of `rval` and `rind`, and store the results in `rval` and `rind`.
727729
"""
728730
function findmax!(rval::AbstractArray, rind::AbstractArray, A::AbstractArray;
729731
init::Bool=true)
730-
findminmax!(isgreater, init && !isempty(A) ? fill!(rval, first(A)) : rval, fill!(rind,0), A)
732+
findminmax!(isgreater, init && !isempty(A) ? fill!(rval, first(A)) : rval, fill!(rind,zero(eltype(keys(A)))), A)
731733
end
732734

733735
"""
@@ -756,10 +758,10 @@ function findmax(A::AbstractArray{T}, region) where T
756758
if prod(map(length, reduced_indices(A, region))) != 0
757759
throw(ArgumentError("collection slices must be non-empty"))
758760
end
759-
similar(A, ri), similar(dims->zeros(Int, dims), ri)
761+
similar(A, ri), similar(dims->zeros(eltype(keys(A)), dims), ri)
760762
else
761763
findminmax!(isgreater, fill!(similar(A, ri), first(A)),
762-
similar(dims->zeros(Int, dims), ri), A)
764+
similar(dims->zeros(eltype(keys(A)), dims), ri), A)
763765
end
764766
end
765767

base/sparse/sparsematrix.jl

Lines changed: 39 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1815,101 +1815,100 @@ function _findz(A::SparseMatrixCSC{Tv,Ti}, rows=1:A.m, cols=1:A.n) where {Tv,Ti}
18151815
row = 0
18161816
rowmin = rows[1]; rowmax = rows[end]
18171817
allrows = (rows == 1:A.m)
1818-
@inbounds for col in cols
1818+
@inbounds for col in cols
18191819
r1::Int = colptr[col]
18201820
r2::Int = colptr[col+1] - 1
18211821
if !allrows && (r1 <= r2)
18221822
r1 = searchsortedfirst(rowval, rowmin, r1, r2, Forward)
18231823
(r1 <= r2 ) && (r2 = searchsortedlast(rowval, rowmax, r1, r2, Forward))
18241824
end
18251825
row = rowmin
1826-
18271826
while (r1 <= r2) && (row == rowval[r1]) && (nzval[r1] != zval)
18281827
r1 += 1
18291828
row += 1
18301829
end
1831-
(row <= rowmax) && (return sub2ind(size(A), row, col))
1830+
(row <= rowmax) && (return CartesianIndex(row, col))
18321831
end
1833-
return 0
1832+
return CartesianIndex(0, 0)
18341833
end
18351834

1836-
macro _findr(op, A, region, Tv, Ti)
1837-
esc(quote
1838-
N = nnz($A)
1839-
L = length($A)
1835+
function _findr(op, A, region, Tv)
1836+
Ti = eltype(keys(A))
1837+
i1 = first(keys(A))
1838+
N = nnz(A)
1839+
L = length(A)
18401840
if L == 0
1841-
if prod(map(length, Base.reduced_indices($A, $region))) != 0
1841+
if prod(map(length, Base.reduced_indices(A, region))) != 0
18421842
throw(ArgumentError("array slices must be non-empty"))
18431843
else
1844-
ri = Base.reduced_indices0($A, $region)
1845-
return (similar($A, ri), similar(dims->zeros(Int, dims), ri))
1844+
ri = Base.reduced_indices0(A, region)
1845+
return (similar(A, ri), similar(dims->zeros(Ti, dims), ri))
18461846
end
18471847
end
18481848

1849-
colptr = $A.colptr; rowval = $A.rowval; nzval = $A.nzval; m = $A.m; n = $A.n
1850-
zval = zero($Tv)
1851-
szA = size($A)
1849+
colptr = A.colptr; rowval = A.rowval; nzval = A.nzval; m = A.m; n = A.n
1850+
zval = zero(Tv)
1851+
szA = size(A)
18521852

1853-
if $region == 1 || $region == (1,)
1854-
(N == 0) && (return (fill(zval,1,n), fill(convert($Ti,1),1,n)))
1855-
S = Vector{$Tv}(n); I = Vector{$Ti}(n)
1853+
if region == 1 || region == (1,)
1854+
(N == 0) && (return (fill(zval,1,n), fill(i1,1,n)))
1855+
S = Vector{Tv}(n); I = Vector{Ti}(n)
18561856
@inbounds for i = 1 : n
1857-
Sc = zval; Ic = _findz($A, 1:m, i:i)
1858-
if Ic == 0
1857+
Sc = zval; Ic = _findz(A, 1:m, i:i)
1858+
if Ic == CartesianIndex(0, 0)
18591859
j = colptr[i]
1860-
Ic = sub2ind(szA, rowval[j], i)
1860+
Ic = CartesianIndex(rowval[j], i)
18611861
Sc = nzval[j]
18621862
end
18631863
for j = colptr[i] : colptr[i+1]-1
1864-
if ($op)(nzval[j], Sc)
1864+
if op(nzval[j], Sc)
18651865
Sc = nzval[j]
1866-
Ic = sub2ind(szA, rowval[j], i)
1866+
Ic = CartesianIndex(rowval[j], i)
18671867
end
18681868
end
18691869
S[i] = Sc; I[i] = Ic
18701870
end
18711871
return(reshape(S,1,n), reshape(I,1,n))
1872-
elseif $region == 2 || $region == (2,)
1873-
(N == 0) && (return (fill(zval,m,1), fill(convert($Ti,1),m,1)))
1874-
S = Vector{$Tv}(m); I = Vector{$Ti}(m)
1872+
elseif region == 2 || region == (2,)
1873+
(N == 0) && (return (fill(zval,m,1), fill(i1,m,1)))
1874+
S = Vector{Tv}(m); I = Vector{Ti}(m)
18751875
@inbounds for row in 1:m
1876-
S[row] = zval; I[row] = _findz($A, row:row, 1:n)
1877-
if I[row] == 0
1878-
I[row] = sub2ind(szA, row, 1)
1876+
S[row] = zval; I[row] = _findz(A, row:row, 1:n)
1877+
if I[row] == CartesianIndex(0, 0)
1878+
I[row] = CartesianIndex(row, 1)
18791879
S[row] = A[row,1]
18801880
end
18811881
end
18821882
@inbounds for i = 1 : n, j = colptr[i] : colptr[i+1]-1
18831883
row = rowval[j]
1884-
if ($op)(nzval[j], S[row])
1884+
if op(nzval[j], S[row])
18851885
S[row] = nzval[j]
1886-
I[row] = sub2ind(szA, row, i)
1886+
I[row] = CartesianIndex(row, i)
18871887
end
18881888
end
18891889
return (reshape(S,m,1), reshape(I,m,1))
1890-
elseif $region == (1,2)
1891-
(N == 0) && (return (fill(zval,1,1), fill(convert($Ti,1),1,1)))
1892-
hasz = nnz($A) != length($A)
1890+
elseif region == (1,2)
1891+
(N == 0) && (return (fill(zval,1,1), fill(i1,1,1)))
1892+
hasz = nnz(A) != length(A)
18931893
Sv = hasz ? zval : nzval[1]
1894-
Iv::($Ti) = hasz ? _findz($A) : 1
1895-
@inbounds for i = 1 : $A.n, j = colptr[i] : (colptr[i+1]-1)
1896-
if ($op)(nzval[j], Sv)
1894+
Iv::(Ti) = hasz ? _findz(A) : i1
1895+
@inbounds for i = 1 : A.n, j = colptr[i] : (colptr[i+1]-1)
1896+
if op(nzval[j], Sv)
18971897
Sv = nzval[j]
1898-
Iv = sub2ind(szA, rowval[j], i)
1898+
Iv = CartesianIndex(rowval[j], i)
18991899
end
19001900
end
19011901
return (fill(Sv,1,1), fill(Iv,1,1))
19021902
else
19031903
throw(ArgumentError("invalid value for region; must be 1, 2, or (1,2)"))
19041904
end
1905-
end) #quote
19061905
end
19071906

19081907
_isless_fm(a, b) = b == b && ( a != a || isless(a, b) )
19091908
_isgreater_fm(a, b) = b == b && ( a != a || isless(b, a) )
19101909

1911-
findmin(A::SparseMatrixCSC{Tv,Ti}, region) where {Tv,Ti} = @_findr(_isless_fm, A, region, Tv, Ti)
1912-
findmax(A::SparseMatrixCSC{Tv,Ti}, region) where {Tv,Ti} = @_findr(_isgreater_fm, A, region, Tv, Ti)
1910+
findmin(A::SparseMatrixCSC{Tv,Ti}, region) where {Tv,Ti} = _findr(_isless_fm, A, region, Tv)
1911+
findmax(A::SparseMatrixCSC{Tv,Ti}, region) where {Tv,Ti} = _findr(_isgreater_fm, A, region, Tv)
19131912
findmin(A::SparseMatrixCSC) = (r=findmin(A,(1,2)); (r[1][1], r[2][1]))
19141913
findmax(A::SparseMatrixCSC) = (r=findmax(A,(1,2)); (r[1][1], r[2][1]))
19151914

test/arrayops.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,7 @@ end
503503
@test indmin(5:-2:1) == 3
504504

505505
#23094
506-
@test findmax(Set(["abc"])) === ("abc", 1)
506+
@test_throws MethodError findmax(Set(["abc"]))
507507
@test findmin(["abc", "a"]) === ("a", 2)
508508
@test_throws MethodError findmax([Set([1]), Set([2])])
509509
@test findmin([0.0, -0.0]) === (-0.0, 2)
@@ -1814,6 +1814,11 @@ s, si = findmax(S)
18141814
@test a == b == s
18151815
@test ai == bi == si
18161816

1817+
for X in (A, B, S)
1818+
@test findmin(X) == findmin(Dict(pairs(X)))
1819+
@test findmax(X) == findmax(Dict(pairs(X)))
1820+
end
1821+
18171822
fill!(B, 2)
18181823
@test all(x->x==2, B)
18191824

0 commit comments

Comments
 (0)