Skip to content

Commit dca9060

Browse files
committed
optimize for Symmetric,Hermitian,UpperTriangular as well as StridedArray
1 parent 0eea203 commit dca9060

File tree

3 files changed

+62
-24
lines changed

3 files changed

+62
-24
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "SweepOperator"
22
uuid = "7522ee7d-7047-56d0-94d9-4bc626e7058d"
3-
version = "0.3.2"
3+
version = "0.3.3"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/SweepOperator.jl

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,19 +39,47 @@ function sweep_with_buffer!(akk::AVec{T}, A::AMat{T}, k::Integer, inv::Bool = fa
3939
1 k p || throw(BoundsError(A, k))
4040
p == length(akk) || throw(DimensionError("Incorrect buffer size."))
4141
@inbounds @views begin
42-
d = one(T) / A[k, k] # pivot
43-
copy!(akk, Symmetric(A, :U)[:, k]) # akk = A[:, k]
44-
if A isa StridedMatrix{<:Union{LinearAlgebra.BlasFloat, LinearAlgebra.BlasComplex}}
45-
BLAS.syrk!('U', 'N', -d, akk, one(T), A) # everything not in col/row k
46-
else
47-
A .+= UpperTriangular(-d * akk * akk')
48-
end
49-
rmul!(akk, d * (-one(T)) ^ inv) # akk .* d (negated if inv=true)
50-
copy!(A[1:k-1,k], akk[1:k-1]) # col k
51-
copy!(A[k, k+1:end], akk[k+1:end]) # row k
52-
A[k, k] = -d # pivot element
42+
d = one(T) / A[k, k] # pivot
43+
copy!(akk, Symmetric(A, :U)[:, k]) # akk = A[:, k]
44+
syrk!(A, -d, akk) # everything not in row/col k
45+
rmul!(akk, d * (-one(T)) ^ inv) # akk .* d (negated if inv=true)
46+
setrowcol!(A, k, akk)
47+
A[k, k] = -d # pivot element
5348
end
5449
return A
5550
end
5651

52+
#-----------------------------------------------------------------------------# setrowcol!
53+
# Set upper triangle of: (A[k, :] = x; A[:, k] = x)
54+
function setrowcol!(A::StridedArray, k, x)
55+
@views copy!(A[1:k-1,k], x[1:k-1]) # col k
56+
@views copy!(A[k, k+1:end], x[k+1:end]) # row k
57+
end
58+
59+
setrowcol!(A::Union{Hermitian,Symmetric,UpperTriangular}, k, x) = setrowcol!(A.data, k, x)
60+
61+
#-----------------------------------------------------------------------------# syrk!
62+
const BlasNumber = Union{LinearAlgebra.BlasFloat, LinearAlgebra.BlasComplex}
63+
64+
# In-place update of (the upper triangle of) A + α * x * x'
65+
function syrk!(A::StridedMatrix{T}, α::T, x::AbstractArray{<:T}) where {T<:BlasNumber}
66+
BLAS.syrk!('U', 'N', α, x, one(T), A)
67+
end
68+
69+
function syrk!(A::Hermitian{T, S}, α::T, x::AbstractArray{<:T}) where {T<:BlasNumber, S<:StridedMatrix{T}}
70+
Hermitian(BLAS.syrk!('U', 'N', α, x, one(T), A.data))
71+
end
72+
73+
function syrk!(A::Symmetric{T, S}, α::T, x::AbstractArray{<:T}) where {T<:BlasNumber, S<:StridedMatrix{T}}
74+
Symmetric(BLAS.syrk!('U', 'N', α, x, one(T), A.data))
75+
end
76+
77+
function syrk!(A, α, x) where {T}
78+
p = checksquare(A)
79+
for i in 1:p, j in i:p
80+
@inbounds A[i,j] += α * x[i] * x[j]
81+
end
82+
end
83+
84+
5785
end # module

test/runtests.jl

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,41 +6,51 @@ x = randn(n, p)
66
xtx = x'x
77

88
@testset "Sweep One By One" begin
9-
A = deepcopy(xtx)
10-
B = deepcopy(xtx)
9+
A = copy(xtx)
10+
B = copy(xtx)
1111
for j in 1:p
1212
sweep!(A, j)
1313
sweep!(A, j, true)
1414
end
15-
@test A B
15+
@test UpperTriangular(A) UpperTriangular(B)
1616

17-
A = deepcopy(xtx)
18-
B = deepcopy(xtx)
17+
A = copy(xtx)
18+
B = copy(xtx)
1919
for j in 1:p
20-
sweep!(A, j, true)
20+
sweep!(A, j)
2121
end
2222
for j in 1:p
2323
sweep!(A, j, true)
2424
end
25-
@test A B
25+
@test UpperTriangular(A) UpperTriangular(B)
2626
end
2727

2828
@testset "Sweep All" begin
29-
A = deepcopy(xtx)
30-
B = deepcopy(xtx)
29+
A = copy(xtx)
30+
B = copy(xtx)
3131
sweep!(A, 1:p)
3232
sweep!(A, 1:p, true)
3333
@test A B
3434
end
3535

36-
@testset "Non-StridedArray" begin
37-
A = Diagonal(deepcopy(xtx))
38-
B = Diagonal(deepcopy(xtx))
36+
@testset "UpperTriangular" begin
37+
A = UpperTriangular(copy(xtx))
38+
B = UpperTriangular(copy(xtx))
3939
sweep!(A, 1:p)
4040
sweep!(A, 1:p, true)
4141
@test A B
4242
end
4343

44+
@testset "Hermitian/Symmetric" begin
45+
A = Hermitian(copy(xtx))
46+
B = Symmetric(copy(xtx))
47+
sweep!(A, 1:p)
48+
sweep!(A, 1:p, true)
49+
sweep!(B, 1:p)
50+
sweep!(B, 1:p, true)
51+
@test A B xtx
52+
end
53+
4454
@testset "Linear Regression" begin
4555
y = x * collect(1.:p) + randn(n)
4656
xy = [x y]

0 commit comments

Comments
 (0)