-
-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
diag of SparseMatrixCSC should always return SparseVector #23261
Changes from 5 commits
0c7863b
7af9ab2
00905db
a77a206
17eef91
5530e4d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3380,40 +3380,38 @@ function expandptr(V::Vector{<:Integer}) | |
res | ||
end | ||
|
||
## diag and related using an iterator | ||
|
||
mutable struct SpDiagIterator{Tv,Ti} | ||
A::SparseMatrixCSC{Tv,Ti} | ||
n::Int | ||
end | ||
SpDiagIterator(A::SparseMatrixCSC) = SpDiagIterator(A,minimum(size(A))) | ||
|
||
length(d::SpDiagIterator) = d.n | ||
start(d::SpDiagIterator) = 1 | ||
done(d::SpDiagIterator, j) = j > d.n | ||
|
||
function next(d::SpDiagIterator{Tv}, j) where Tv | ||
A = d.A | ||
r1 = Int(A.colptr[j]) | ||
r2 = Int(A.colptr[j+1]-1) | ||
(r1 > r2) && (return (zero(Tv), j+1)) | ||
r1 = searchsortedfirst(A.rowval, j, r1, r2, Forward) | ||
(((r1 > r2) || (A.rowval[r1] != j)) ? zero(Tv) : A.nzval[r1], j+1) | ||
function diag(A::SparseMatrixCSC{Tv,Ti}, d::Integer=0) where {Tv,Ti} | ||
m, n = size(A) | ||
if !(-m <= d <= n) | ||
throw(ArgumentError("requested diagonal, $d, out of bounds in matrix of size ($m, $n)")) | ||
end | ||
l = d < 0 ? min(m+d,n) : min(n-d,m) | ||
r, c = d <= 0 ? (-d, 0) : (0, d) # start row/col -1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Perhaps either type There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Converted |
||
ind = Vector{Ti}() | ||
val = Vector{Tv}() | ||
for i in 1:l | ||
r += 1; c += 1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Likewise here, either type There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also fixed by converting to |
||
r1 = Int(A.colptr[c]) | ||
r2 = Int(A.colptr[c+1]-1) | ||
r1 > r2 && continue | ||
r1 = searchsortedfirst(A.rowval, r, r1, r2, Forward) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IIRC |
||
((r1 > r2) || (A.rowval[r1] != r)) && continue | ||
push!(ind, i) | ||
push!(val, A.nzval[r1]) | ||
end | ||
return SparseVector{Tv,Ti}(l, ind, val) | ||
end | ||
|
||
function trace(A::SparseMatrixCSC{Tv}) where Tv | ||
if size(A,1) != size(A,2) | ||
throw(DimensionMismatch("expected square matrix")) | ||
end | ||
n = checksquare(A) | ||
s = zero(Tv) | ||
for d in SpDiagIterator(A) | ||
s += d | ||
for i in 1:n | ||
s += A[i,i] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this version perform similarly to the original? Sparse There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't it do pretty much exactly what the sparse iterator what doing? You still gotta search, right? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I expect so, at least roughly. But my conjectures often fail, so explicit verification has become my friend :). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Some benchmarking suggested this was faster. The iterator was even allocating some stuff. I will get back with number :) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. using BenchmarkTools
function trace2(A::SparseMatrixCSC{Tv}) where Tv
n = Base.LinAlg.checksquare(A)
s = zero(Tv)
for i in 1:n
s += A[i,i]
end
return s
end
for s in (1000, 5000), p in (0.1, 0.01, 0.005)
S1 = sprand(s, s, p)
S2 = S1 + speye(s, s) # typical case with values on all diagonal positions
println("trace")
@btime trace($S1)
@btime trace($S2)
println("trace2")
@btime trace2($S1)
@btime trace2($S2)
end with output:
|
||
end | ||
s | ||
return s | ||
end | ||
|
||
diag(A::SparseMatrixCSC{Tv}) where {Tv} = Tv[d for d in SpDiagIterator(A)] | ||
|
||
function diagm(v::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti} | ||
if size(v,1) != 1 && size(v,2) != 1 | ||
throw(DimensionMismatch("input should be nx1 or 1xn")) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1323,6 +1323,27 @@ end | |
@test diagm(sparse(ones(5,1))) == speye(5) | ||
end | ||
|
||
@testset "diag" begin | ||
for T in (Float64, Complex128) | ||
S1 = sprand(T, 5, 5, 0.5) | ||
S2 = sprand(T, 10, 5, 0.5) | ||
S3 = sprand(T, 5, 10, 0.5) | ||
for S in (S1, S2, S3) | ||
A = Matrix(S) | ||
@test diag(S)::SparseVector{T,Int} == diag(A) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Extra space before |
||
for k in -size(S,1):size(S,2) | ||
@test diag(S, k)::SparseVector{T,Int} == diag(A, k) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Extra space before |
||
end | ||
@test_throws ArgumentError diag(S, -size(S,1)-1) | ||
@test_throws ArgumentError diag(S, size(S,2)+1) | ||
end | ||
end | ||
# test that stored zeros are still stored zeros in the diagonal | ||
S = sparse([1,3],[1,3],[0.0,0.0]); V = diag(S) | ||
@test V.nzind == [1,3] | ||
@test V.nzval == [0.0,0.0] | ||
end | ||
|
||
@testset "expandptr" begin | ||
A = speye(5) | ||
@test Base.SparseArrays.expandptr(A.colptr) == collect(1:5) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this to avoid changing the behavior? I guess it should still work with a sparse vector, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yea it works, but will result in a
SparseMatrixCSC
, so the result is not inferable.