Skip to content

Add the recursive blocked Schur algorithm for matrix square root #40239

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Jul 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 105 additions & 3 deletions stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2314,7 +2314,7 @@ sqrt(A::UnitLowerTriangular) = copy(transpose(sqrt(copy(transpose(A)))))
# Auxiliary functions for matrix square root

# square root of upper triangular or real upper quasitriangular matrix
function sqrt_quasitriu(A0)
function sqrt_quasitriu(A0; blockwidth = eltype(A0) <: Complex ? 512 : 256)
n = checksquare(A0)
T = eltype(A0)
Tr = typeof(sqrt(real(zero(T))))
Expand All @@ -2341,7 +2341,7 @@ function sqrt_quasitriu(A0)
A = A0
R = zeros(Tc, n, n)
end
_sqrt_quasitriu!(R, A)
_sqrt_quasitriu!(R, A; blockwidth=blockwidth, n=n)
Rc = eltype(A0) <: Real ? R : complex(R)
if A0 isa UpperTriangular
return UpperTriangular(Rc)
Expand All @@ -2352,7 +2352,32 @@ function sqrt_quasitriu(A0)
end
end

function _sqrt_quasitriu!(R, A)
# in-place recursive sqrt of upper quasi-triangular matrix A from
# Deadman E., Higham N.J., Ralha R. (2013) Blocked Schur Algorithms for Computing the Matrix
# Square Root. Applied Parallel and Scientific Computing. PARA 2012. Lecture Notes in
# Computer Science, vol 7782. https://doi.org/10.1007/978-3-642-36803-5_12
function _sqrt_quasitriu!(R, A; blockwidth=64, n=checksquare(A))
if n ≤ blockwidth || !(eltype(R) <: BlasFloat) # base case, perform "point" algorithm
_sqrt_quasitriu_block!(R, A)
else # compute blockwise recursion
split = div(n, 2)
iszero(A[split+1, split]) || (split += 1) # don't split 2x2 diagonal block
r1 = 1:split
r2 = (split + 1):n
n1, n2 = split, n - split
A11, A12, A22 = @views A[r1,r1], A[r1,r2], A[r2,r2]
R11, R12, R22 = @views R[r1,r1], R[r1,r2], R[r2,r2]
# solve diagonal blocks recursively
_sqrt_quasitriu!(R11, A11; blockwidth=blockwidth, n=n1)
_sqrt_quasitriu!(R22, A22; blockwidth=blockwidth, n=n2)
# solve off-diagonal block
R12 .= .- A12
_sylvester_quasitriu!(R11, R22, R12; blockwidth=blockwidth, nA=n1, nB=n2, raise=false)
end
return R
end

function _sqrt_quasitriu_block!(R, A)
_sqrt_quasitriu_diag_block!(R, A)
_sqrt_quasitriu_offdiag_block!(R, A)
return R
Expand Down Expand Up @@ -2505,6 +2530,83 @@ Base.@propagate_inbounds function _sqrt_quasitriu_offdiag_block_2x2!(R, A, i, j)
return R
end

# solve Sylvester's equation AX + XB = -C using blockwise recursion until the dimension of
# A and B are no greater than blockwidth, based on Algorithm 1 from
# Jonsson I, Kågström B. Recursive blocked algorithms for solving triangular systems—
# Part I: one-sided and coupled Sylvester-type matrix equations. (2002) ACM Trans Math Softw.
# 28(4), https://doi.org/10.1145/592843.592845.
# specify raise=false to avoid breaking the recursion if a LAPACKException is thrown when
# computing one of the blocks.
function _sylvester_quasitriu!(A, B, C; blockwidth=64, nA=checksquare(A), nB=checksquare(B), raise=true)
if 1 ≤ nA ≤ blockwidth && 1 ≤ nB ≤ blockwidth
_sylvester_quasitriu_base!(A, B, C; raise=raise)
elseif nA ≥ 2nB ≥ 2
_sylvester_quasitriu_split1!(A, B, C; blockwidth=blockwidth, nA=nA, nB=nB, raise=raise)
elseif nB ≥ 2nA ≥ 2
_sylvester_quasitriu_split2!(A, B, C; blockwidth=blockwidth, nA=nA, nB=nB, raise=raise)
else
_sylvester_quasitriu_splitall!(A, B, C; blockwidth=blockwidth, nA=nA, nB=nB, raise=raise)
end
return C
end
function _sylvester_quasitriu_base!(A, B, C; raise=true)
try
_, scale = LAPACK.trsyl!('N', 'N', A, B, C)
rmul!(C, -inv(scale))
catch e
if !(e isa LAPACKException) || raise
throw(e)
end
end
return C
end
function _sylvester_quasitriu_split1!(A, B, C; nA=checksquare(A), kwargs...)
iA = div(nA, 2)
iszero(A[iA + 1, iA]) || (iA += 1) # don't split 2x2 diagonal block
rA1, rA2 = 1:iA, (iA + 1):nA
nA1, nA2 = iA, nA-iA
A11, A12, A22 = @views A[rA1,rA1], A[rA1,rA2], A[rA2,rA2]
C1, C2 = @views C[rA1,:], C[rA2,:]
_sylvester_quasitriu!(A22, B, C2; nA=nA2, kwargs...)
mul!(C1, A12, C2, true, true)
_sylvester_quasitriu!(A11, B, C1; nA=nA1, kwargs...)
return C
end
function _sylvester_quasitriu_split2!(A, B, C; nB=checksquare(B), kwargs...)
iB = div(nB, 2)
iszero(B[iB + 1, iB]) || (iB += 1) # don't split 2x2 diagonal block
rB1, rB2 = 1:iB, (iB + 1):nB
nB1, nB2 = iB, nB-iB
B11, B12, B22 = @views B[rB1,rB1], B[rB1,rB2], B[rB2,rB2]
C1, C2 = @views C[:,rB1], C[:,rB2]
_sylvester_quasitriu!(A, B11, C1; nB=nB1, kwargs...)
mul!(C2, C1, B12, true, true)
_sylvester_quasitriu!(A, B22, C2; nB=nB2, kwargs...)
return C
end
function _sylvester_quasitriu_splitall!(A, B, C; nA=checksquare(A), nB=checksquare(B), kwargs...)
iA = div(nA, 2)
iszero(A[iA + 1, iA]) || (iA += 1) # don't split 2x2 diagonal block
iB = div(nB, 2)
iszero(B[iB + 1, iB]) || (iB += 1) # don't split 2x2 diagonal block
rA1, rA2 = 1:iA, (iA + 1):nA
nA1, nA2 = iA, nA-iA
rB1, rB2 = 1:iB, (iB + 1):nB
nB1, nB2 = iB, nB-iB
A11, A12, A22 = @views A[rA1,rA1], A[rA1,rA2], A[rA2,rA2]
B11, B12, B22 = @views B[rB1,rB1], B[rB1,rB2], B[rB2,rB2]
C11, C21, C12, C22 = @views C[rA1,rB1], C[rA2,rB1], C[rA1,rB2], C[rA2,rB2]
_sylvester_quasitriu!(A22, B11, C21; nA=nA2, nB=nB1, kwargs...)
mul!(C11, A12, C21, true, true)
_sylvester_quasitriu!(A11, B11, C11; nA=nA1, nB=nB1, kwargs...)
mul!(C22, C21, B12, true, true)
_sylvester_quasitriu!(A22, B22, C22; nA=nA2, nB=nB2, kwargs...)
mul!(C12, A12, C22, true, true)
mul!(C12, C11, B12, true, true)
_sylvester_quasitriu!(A11, B22, C12; nA=nA1, nB=nB2, kwargs...)
return C
end

# End of auxiliary functions for matrix square root

# Generic eigensystems
Expand Down
35 changes: 35 additions & 0 deletions stdlib/LinearAlgebra/test/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,41 @@ Atu = UnitUpperTriangular([1 1 2; 0 1 2; 0 0 1])
@test typeof(sqrt(Atu)[1,1]) <: Real
@test typeof(sqrt(complex(Atu))[1,1]) <: Complex

@testset "matrix square root quasi-triangular blockwise" begin
@testset for T in (Float32, Float64, ComplexF32, ComplexF64)
A = schur(rand(T, 100, 100)^2).T
@test LinearAlgebra.sqrt_quasitriu(A; blockwidth=16)^2 ≈ A
end
n = 256
A = rand(ComplexF64, n, n)
U = schur(A).T
Ubig = Complex{BigFloat}.(U)
@test LinearAlgebra.sqrt_quasitriu(U; blockwidth=64) ≈ LinearAlgebra.sqrt_quasitriu(Ubig; blockwidth=64)
end

@testset "sylvester quasi-triangular blockwise" begin
@testset for T in (Float32, Float64, ComplexF32, ComplexF64), m in (15, 40), n in (15, 45)
A = schur(rand(T, m, m)).T
B = schur(rand(T, n, n)).T
C = randn(T, m, n)
Ccopy = copy(C)
X = LinearAlgebra._sylvester_quasitriu!(A, B, C; blockwidth=16)
@test X === C
@test A * X + X * B ≈ -Ccopy

@testset "test raise=false does not break recursion" begin
Az = zero(A)
Bz = zero(B)
C2 = copy(Ccopy)
@test_throws LAPACKException LinearAlgebra._sylvester_quasitriu!(Az, Bz, C2; blockwidth=16)
m == n || @test any(C2 .== Ccopy) # recursion broken
C3 = copy(Ccopy)
X3 = LinearAlgebra._sylvester_quasitriu!(Az, Bz, C3; blockwidth=16, raise=false)
@test !any(X3 .== Ccopy) # recursion not broken
end
end
end

@testset "check matrix logarithm type-inferrable" for elty in (Float32,Float64,ComplexF32,ComplexF64)
A = UpperTriangular(exp(triu(randn(elty, n, n))))
@inferred Union{typeof(A),typeof(complex(A))} log(A)
Expand Down