Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

diag of SparseMatrixCSC should always return SparseVector #23261

Merged
merged 6 commits into from
Aug 18, 2017
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,8 @@ Deprecated or removed
* `Base.cpad` has been removed; use an appropriate combination of `rpad` and `lpad`
instead ([#23187]).

* `Base.SparseArrays.SpDiagIterator` has been removed ([#23261]).

Command-line option changes
---------------------------

Expand Down
2 changes: 1 addition & 1 deletion base/sparse/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -879,7 +879,7 @@ for f in (:\, :Ac_ldiv_B, :At_ldiv_B)
if m == n
if istril(A)
if istriu(A)
return ($f)(Diagonal(A), B)
return ($f)(Diagonal(Vector(diag(A))), B)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this to avoid changing the behavior? I guess it should still work with a sparse vector, right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea it works, but will result in a SparseMatrixCSC, so the result is not inferable.

else
return ($f)(LowerTriangular(A), B)
end
Expand Down
50 changes: 24 additions & 26 deletions base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3380,40 +3380,38 @@ function expandptr(V::Vector{<:Integer})
res
end

## diag and related using an iterator

mutable struct SpDiagIterator{Tv,Ti}
A::SparseMatrixCSC{Tv,Ti}
n::Int
end
SpDiagIterator(A::SparseMatrixCSC) = SpDiagIterator(A,minimum(size(A)))

length(d::SpDiagIterator) = d.n
start(d::SpDiagIterator) = 1
done(d::SpDiagIterator, j) = j > d.n

function next(d::SpDiagIterator{Tv}, j) where Tv
A = d.A
r1 = Int(A.colptr[j])
r2 = Int(A.colptr[j+1]-1)
(r1 > r2) && (return (zero(Tv), j+1))
r1 = searchsortedfirst(A.rowval, j, r1, r2, Forward)
(((r1 > r2) || (A.rowval[r1] != j)) ? zero(Tv) : A.nzval[r1], j+1)
function diag(A::SparseMatrixCSC{Tv,Ti}, d::Integer=0) where {Tv,Ti}
m, n = size(A)
if !(-m <= d <= n)
throw(ArgumentError("requested diagonal, $d, out of bounds in matrix of size ($m, $n)"))
end
l = d < 0 ? min(m+d,n) : min(n-d,m)
r, c = d <= 0 ? (-d, 0) : (0, d) # start row/col -1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps either type r and c or use zero to on the right side to avoid potential type instability?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Converted d directly to Int so should be ok now :)

ind = Vector{Ti}()
val = Vector{Tv}()
for i in 1:l
r += 1; c += 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Likewise here, either type r and c or use oneunit on the right side to avoid potential type instability?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also fixed by converting to Int directly.

r1 = Int(A.colptr[c])
r2 = Int(A.colptr[c+1]-1)
r1 > r2 && continue
r1 = searchsortedfirst(A.rowval, r, r1, r2, Forward)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC searchsortedfirst always returns an Int?

((r1 > r2) || (A.rowval[r1] != r)) && continue
push!(ind, i)
push!(val, A.nzval[r1])
end
return SparseVector{Tv,Ti}(l, ind, val)
end

function trace(A::SparseMatrixCSC{Tv}) where Tv
if size(A,1) != size(A,2)
throw(DimensionMismatch("expected square matrix"))
end
n = checksquare(A)
s = zero(Tv)
for d in SpDiagIterator(A)
s += d
for i in 1:n
s += A[i,i]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this version perform similarly to the original? Sparse getindex is fairly complex / expensive.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't it do pretty much exactly what the sparse iterator what doing? You still gotta search, right?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I expect so, at least roughly. But my conjectures often fail, so explicit verification has become my friend :).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some benchmarking suggested this was faster. The iterator was even allocating some stuff. I will get back with number :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

using BenchmarkTools
function trace2(A::SparseMatrixCSC{Tv}) where Tv
    n = Base.LinAlg.checksquare(A)
    s = zero(Tv)
    for i in 1:n
        s += A[i,i]
    end
    return s
end

for s in (1000, 5000), p in (0.1, 0.01, 0.005)
    S1 = sprand(s, s, p)
    S2 = S1 + speye(s, s) # typical case with values on all diagonal positions
    println("trace")
    @btime trace($S1)
    @btime trace($S2)
    println("trace2")
    @btime trace2($S1)
    @btime trace2($S2)
end

with output:

trace
  24.820 μs (6 allocations: 176 bytes)
  26.273 μs (6 allocations: 176 bytes)
trace2
  25.804 μs (0 allocations: 0 bytes)
  25.010 μs (0 allocations: 0 bytes)
trace
  11.110 μs (6 allocations: 176 bytes)
  11.680 μs (6 allocations: 176 bytes)
trace2
  9.736 μs (0 allocations: 0 bytes)
  10.217 μs (0 allocations: 0 bytes)
trace
  9.370 μs (6 allocations: 176 bytes)
  9.862 μs (6 allocations: 176 bytes)
trace2
  7.443 μs (0 allocations: 0 bytes)
  7.938 μs (0 allocations: 0 bytes)
trace
  204.548 μs (6 allocations: 176 bytes)
  218.412 μs (6 allocations: 176 bytes)
trace2
  197.253 μs (0 allocations: 0 bytes)
  210.470 μs (0 allocations: 0 bytes)
trace
  117.837 μs (6 allocations: 176 bytes)
  120.147 μs (6 allocations: 176 bytes)
trace2
  112.441 μs (0 allocations: 0 bytes)
  113.419 μs (0 allocations: 0 bytes)
trace
  104.461 μs (6 allocations: 176 bytes)
  106.751 μs (6 allocations: 176 bytes)
trace2
  97.655 μs (0 allocations: 0 bytes)
  100.329 μs (0 allocations: 0 bytes)

end
s
return s
end

diag(A::SparseMatrixCSC{Tv}) where {Tv} = Tv[d for d in SpDiagIterator(A)]

function diagm(v::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
if size(v,1) != 1 && size(v,2) != 1
throw(DimensionMismatch("input should be nx1 or 1xn"))
Expand Down
21 changes: 21 additions & 0 deletions test/sparse/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1323,6 +1323,27 @@ end
@test diagm(sparse(ones(5,1))) == speye(5)
end

@testset "diag" begin
for T in (Float64, Complex128)
S1 = sprand(T, 5, 5, 0.5)
S2 = sprand(T, 10, 5, 0.5)
S3 = sprand(T, 5, 10, 0.5)
for S in (S1, S2, S3)
A = Matrix(S)
@test diag(S)::SparseVector{T,Int} == diag(A)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extra space before ==?

for k in -size(S,1):size(S,2)
@test diag(S, k)::SparseVector{T,Int} == diag(A, k)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extra space before ==?

end
@test_throws ArgumentError diag(S, -size(S,1)-1)
@test_throws ArgumentError diag(S, size(S,2)+1)
end
end
# test that stored zeros are still stored zeros in the diagonal
S = sparse([1,3],[1,3],[0.0,0.0]); V = diag(S)
@test V.nzind == [1,3]
@test V.nzval == [0.0,0.0]
end

@testset "expandptr" begin
A = speye(5)
@test Base.SparseArrays.expandptr(A.colptr) == collect(1:5)
Expand Down