Skip to content

Commit

Permalink
implement bandwidths for OneElement (#447)
Browse files Browse the repository at this point in the history
* implement bandwidths for OneElement

* make improvements

* fix sparse(::SparseMatrixCSC)

* fix bandwidths for SparseMatrixCSC, add for SparseVector

* add bandwidths(::Zeros) behaviour for empty sparse structures

* add unit tests

* cleanup bandwidths

* Update interfaceimpl.jl

---------

Co-authored-by: Sheehan Olver <solver@mac.com>
  • Loading branch information
max-vassili3v and dlfivefifty authored Jul 23, 2024
1 parent ea616cc commit a2649ce
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 15 deletions.
40 changes: 28 additions & 12 deletions ext/BandedMatricesSparseArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,53 @@ module BandedMatricesSparseArraysExt

using BandedMatrices
using BandedMatrices: _banded_rowval, _banded_colval, _banded_nzval
using SparseArrays
using SparseArrays, FillArrays
import SparseArrays: sparse

function sparse(B::BandedMatrix)
sparse(_banded_rowval(B), _banded_colval(B), _banded_nzval(B), size(B)...)
end

function BandedMatrices.bandwidths(A::SparseMatrixCSC)
l,u = -size(A,1),-size(A,2)

m,n = size(A)
l = u = -max(size(A,1),size(A,2))
n = size(A)[2]
rows = rowvals(A)
vals = nonzeros(A)

if isempty(vals)
return bandwidths(Zeros(1))
end

for j = 1:n
for ind in nzrange(A, j)
i = rows[ind]
# We skip non-structural zeros when computing the
# bandwidths.
iszero(vals[ind]) && continue
ij = abs(i-j)
if i j
l = max(l, ij)
u = max(u, -ij)
elseif i < j
l = max(l, -ij)
u = max(u, ij)
end
u = max(u, j-i)
l = max(l, i-j)
end
end

l,u
end

#Treat as n x 1 matrix
function BandedMatrices.bandwidths(A::SparseVector)
l = u = -size(A,1)
rows = rowvals(A)

if isempty(rows)
return bandwidths(Zeros(1))
end

for i in rows
iszero(i) && continue
u = max(u, 1-i)
l = max(l, i-1)
end

l,u
end

end
2 changes: 1 addition & 1 deletion src/BandedMatrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import ArrayLayouts: AbstractTridiagonalLayout, BidiagonalLayout, BlasMatLdivVec
symmetricuplo, transposelayout, triangulardata, triangularlayout, zero!,
QRPackedQLayout, AdjQRPackedQLayout

import FillArrays: AbstractFill, getindex_value, _broadcasted_zeros, unique_value, OneElement, RectDiagonal
import FillArrays: AbstractFill, getindex_value, _broadcasted_zeros, unique_value, OneElement, RectDiagonal, OneElementMatrix, OneElementVector

const libblas = LinearAlgebra.BLAS.libblas
const liblapack = LinearAlgebra.BLAS.liblapack
Expand Down
21 changes: 21 additions & 0 deletions src/interfaceimpl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,27 @@ bandwidths(::Tridiagonal) = (1,1)
sublayout(::AbstractTridiagonalLayout, ::Type{<:Tuple{AbstractUnitRange{Int},AbstractUnitRange{Int}}}) =
BandedLayout()

#Implement bandwidths for OneElement structure
function bandwidths(o::OneElementVector)
k = FillArrays.nzind(o)[1] # index of non-zero
n = length(o)
if k > n || k < 1
bandwidths(Zeros(o))
else
(k-1, 1-k)
end
end

function bandwidths(o::OneElementMatrix)
n,m = size(o)
k,j = Tuple(FillArrays.nzind(o)) # indices of non-zero entries
if k > n || j > m || k < 1 || j < 1
bandwidths(Zeros(o))
else
(k-j,j-k)
end
end

###
# rot180
###
Expand Down
16 changes: 14 additions & 2 deletions test/test_interface.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
module TestInterface

using BandedMatrices, LinearAlgebra, ArrayLayouts, FillArrays, Test
using BandedMatrices, LinearAlgebra, ArrayLayouts, FillArrays, Test, Random
import BandedMatrices: isbanded, AbstractBandedLayout, BandedStyle,
BandedColumns, bandeddata
import ArrayLayouts: OnesLayout, UnknownLayout
using InfiniteArrays
using InfiniteArrays, SparseArrays

struct PseudoBandedMatrix{T} <: AbstractMatrix{T}
data::Array{T}
Expand Down Expand Up @@ -310,6 +310,18 @@ end
@test layout_getindex(T,1:10,1:10) isa BandedMatrix
end

@testset "OneElement" begin
o = OneElement(1, 3, 5)
@test bandwidths(o) == (2,-2)
n,m = rand(1:10,2)
o = OneElement(1, (rand(1:n),rand(1:m)), (n, m))
@test bandwidths(o) == bandwidths(sparse(o))
o = OneElement(1, (n+1,m+1), (n, m))
@test bandwidths(o) == bandwidths(Zeros(o))
o = OneElement(1, 6, 5)
@test bandwidths(o) == bandwidths(Zeros(o))
end

@testset "rot180" begin
A = brand(5,5,1,2)
R = rot180(A)
Expand Down
9 changes: 9 additions & 0 deletions test/test_miscs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,17 @@ import BandedMatrices: _BandedMatrix, DefaultBandedMatrix
@test bA isa BandedMatrix
@test bA == A
@test bandwidths(bA) == min.((l,u),9)
v = sparsevec(brand(10, 1, l, u))
@test bandwidths(v) == (l, min(0, u))
end

l, u = -1, 0
A = brand(10, 10, l, u)
sA = sparse(A)
@test bandwidths(sA) == bandwidths(Zeros(1))
v = sparsevec(brand(10, 1, l, u))
@test bandwidths(v) == bandwidths(Zeros(1))

for diags = [(-1 => ones(Int, 5),),
(-2 => ones(Int, 5),),
(2 => ones(Int, 5),),
Expand Down

0 comments on commit a2649ce

Please sign in to comment.