Skip to content

Commit 1810952

Browse files
Add the recursive blocked Schur algorithm for matrix square root (#40239)
Co-authored-by: Mathieu Besançon <mathieu.besancon@gmail.com>
1 parent da59cdb commit 1810952

File tree

2 files changed

+140
-3
lines changed

2 files changed

+140
-3
lines changed

stdlib/LinearAlgebra/src/triangular.jl

Lines changed: 105 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2323,7 +2323,7 @@ sqrt(A::UnitLowerTriangular) = copy(transpose(sqrt(copy(transpose(A)))))
23232323
# Auxiliary functions for matrix square root
23242324

23252325
# square root of upper triangular or real upper quasitriangular matrix
2326-
function sqrt_quasitriu(A0)
2326+
function sqrt_quasitriu(A0; blockwidth = eltype(A0) <: Complex ? 512 : 256)
23272327
n = checksquare(A0)
23282328
T = eltype(A0)
23292329
Tr = typeof(sqrt(real(zero(T))))
@@ -2350,7 +2350,7 @@ function sqrt_quasitriu(A0)
23502350
A = A0
23512351
R = zeros(Tc, n, n)
23522352
end
2353-
_sqrt_quasitriu!(R, A)
2353+
_sqrt_quasitriu!(R, A; blockwidth=blockwidth, n=n)
23542354
Rc = eltype(A0) <: Real ? R : complex(R)
23552355
if A0 isa UpperTriangular
23562356
return UpperTriangular(Rc)
@@ -2361,7 +2361,32 @@ function sqrt_quasitriu(A0)
23612361
end
23622362
end
23632363

2364-
function _sqrt_quasitriu!(R, A)
2364+
# in-place recursive sqrt of upper quasi-triangular matrix A from
2365+
# Deadman E., Higham N.J., Ralha R. (2013) Blocked Schur Algorithms for Computing the Matrix
2366+
# Square Root. Applied Parallel and Scientific Computing. PARA 2012. Lecture Notes in
2367+
# Computer Science, vol 7782. https://doi.org/10.1007/978-3-642-36803-5_12
2368+
function _sqrt_quasitriu!(R, A; blockwidth=64, n=checksquare(A))
2369+
if n blockwidth || !(eltype(R) <: BlasFloat) # base case, perform "point" algorithm
2370+
_sqrt_quasitriu_block!(R, A)
2371+
else # compute blockwise recursion
2372+
split = div(n, 2)
2373+
iszero(A[split+1, split]) || (split += 1) # don't split 2x2 diagonal block
2374+
r1 = 1:split
2375+
r2 = (split + 1):n
2376+
n1, n2 = split, n - split
2377+
A11, A12, A22 = @views A[r1,r1], A[r1,r2], A[r2,r2]
2378+
R11, R12, R22 = @views R[r1,r1], R[r1,r2], R[r2,r2]
2379+
# solve diagonal blocks recursively
2380+
_sqrt_quasitriu!(R11, A11; blockwidth=blockwidth, n=n1)
2381+
_sqrt_quasitriu!(R22, A22; blockwidth=blockwidth, n=n2)
2382+
# solve off-diagonal block
2383+
R12 .= .- A12
2384+
_sylvester_quasitriu!(R11, R22, R12; blockwidth=blockwidth, nA=n1, nB=n2, raise=false)
2385+
end
2386+
return R
2387+
end
2388+
2389+
function _sqrt_quasitriu_block!(R, A)
23652390
_sqrt_quasitriu_diag_block!(R, A)
23662391
_sqrt_quasitriu_offdiag_block!(R, A)
23672392
return R
@@ -2514,6 +2539,83 @@ Base.@propagate_inbounds function _sqrt_quasitriu_offdiag_block_2x2!(R, A, i, j)
25142539
return R
25152540
end
25162541

2542+
# solve Sylvester's equation AX + XB = -C using blockwise recursion until the dimension of
2543+
# A and B are no greater than blockwidth, based on Algorithm 1 from
2544+
# Jonsson I, Kågström B. Recursive blocked algorithms for solving triangular systems—
2545+
# Part I: one-sided and coupled Sylvester-type matrix equations. (2002) ACM Trans Math Softw.
2546+
# 28(4), https://doi.org/10.1145/592843.592845.
2547+
# specify raise=false to avoid breaking the recursion if a LAPACKException is thrown when
2548+
# computing one of the blocks.
2549+
function _sylvester_quasitriu!(A, B, C; blockwidth=64, nA=checksquare(A), nB=checksquare(B), raise=true)
2550+
if 1 nA blockwidth && 1 nB blockwidth
2551+
_sylvester_quasitriu_base!(A, B, C; raise=raise)
2552+
elseif nA 2nB 2
2553+
_sylvester_quasitriu_split1!(A, B, C; blockwidth=blockwidth, nA=nA, nB=nB, raise=raise)
2554+
elseif nB 2nA 2
2555+
_sylvester_quasitriu_split2!(A, B, C; blockwidth=blockwidth, nA=nA, nB=nB, raise=raise)
2556+
else
2557+
_sylvester_quasitriu_splitall!(A, B, C; blockwidth=blockwidth, nA=nA, nB=nB, raise=raise)
2558+
end
2559+
return C
2560+
end
2561+
function _sylvester_quasitriu_base!(A, B, C; raise=true)
2562+
try
2563+
_, scale = LAPACK.trsyl!('N', 'N', A, B, C)
2564+
rmul!(C, -inv(scale))
2565+
catch e
2566+
if !(e isa LAPACKException) || raise
2567+
throw(e)
2568+
end
2569+
end
2570+
return C
2571+
end
2572+
function _sylvester_quasitriu_split1!(A, B, C; nA=checksquare(A), kwargs...)
2573+
iA = div(nA, 2)
2574+
iszero(A[iA + 1, iA]) || (iA += 1) # don't split 2x2 diagonal block
2575+
rA1, rA2 = 1:iA, (iA + 1):nA
2576+
nA1, nA2 = iA, nA-iA
2577+
A11, A12, A22 = @views A[rA1,rA1], A[rA1,rA2], A[rA2,rA2]
2578+
C1, C2 = @views C[rA1,:], C[rA2,:]
2579+
_sylvester_quasitriu!(A22, B, C2; nA=nA2, kwargs...)
2580+
mul!(C1, A12, C2, true, true)
2581+
_sylvester_quasitriu!(A11, B, C1; nA=nA1, kwargs...)
2582+
return C
2583+
end
2584+
function _sylvester_quasitriu_split2!(A, B, C; nB=checksquare(B), kwargs...)
2585+
iB = div(nB, 2)
2586+
iszero(B[iB + 1, iB]) || (iB += 1) # don't split 2x2 diagonal block
2587+
rB1, rB2 = 1:iB, (iB + 1):nB
2588+
nB1, nB2 = iB, nB-iB
2589+
B11, B12, B22 = @views B[rB1,rB1], B[rB1,rB2], B[rB2,rB2]
2590+
C1, C2 = @views C[:,rB1], C[:,rB2]
2591+
_sylvester_quasitriu!(A, B11, C1; nB=nB1, kwargs...)
2592+
mul!(C2, C1, B12, true, true)
2593+
_sylvester_quasitriu!(A, B22, C2; nB=nB2, kwargs...)
2594+
return C
2595+
end
2596+
function _sylvester_quasitriu_splitall!(A, B, C; nA=checksquare(A), nB=checksquare(B), kwargs...)
2597+
iA = div(nA, 2)
2598+
iszero(A[iA + 1, iA]) || (iA += 1) # don't split 2x2 diagonal block
2599+
iB = div(nB, 2)
2600+
iszero(B[iB + 1, iB]) || (iB += 1) # don't split 2x2 diagonal block
2601+
rA1, rA2 = 1:iA, (iA + 1):nA
2602+
nA1, nA2 = iA, nA-iA
2603+
rB1, rB2 = 1:iB, (iB + 1):nB
2604+
nB1, nB2 = iB, nB-iB
2605+
A11, A12, A22 = @views A[rA1,rA1], A[rA1,rA2], A[rA2,rA2]
2606+
B11, B12, B22 = @views B[rB1,rB1], B[rB1,rB2], B[rB2,rB2]
2607+
C11, C21, C12, C22 = @views C[rA1,rB1], C[rA2,rB1], C[rA1,rB2], C[rA2,rB2]
2608+
_sylvester_quasitriu!(A22, B11, C21; nA=nA2, nB=nB1, kwargs...)
2609+
mul!(C11, A12, C21, true, true)
2610+
_sylvester_quasitriu!(A11, B11, C11; nA=nA1, nB=nB1, kwargs...)
2611+
mul!(C22, C21, B12, true, true)
2612+
_sylvester_quasitriu!(A22, B22, C22; nA=nA2, nB=nB2, kwargs...)
2613+
mul!(C12, A12, C22, true, true)
2614+
mul!(C12, C11, B12, true, true)
2615+
_sylvester_quasitriu!(A11, B22, C12; nA=nA1, nB=nB2, kwargs...)
2616+
return C
2617+
end
2618+
25172619
# End of auxiliary functions for matrix square root
25182620

25192621
# Generic eigensystems

stdlib/LinearAlgebra/test/triangular.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,41 @@ Atu = UnitUpperTriangular([1 1 2; 0 1 2; 0 0 1])
513513
@test typeof(sqrt(Atu)[1,1]) <: Real
514514
@test typeof(sqrt(complex(Atu))[1,1]) <: Complex
515515

516+
@testset "matrix square root quasi-triangular blockwise" begin
517+
@testset for T in (Float32, Float64, ComplexF32, ComplexF64)
518+
A = schur(rand(T, 100, 100)^2).T
519+
@test LinearAlgebra.sqrt_quasitriu(A; blockwidth=16)^2 A
520+
end
521+
n = 256
522+
A = rand(ComplexF64, n, n)
523+
U = schur(A).T
524+
Ubig = Complex{BigFloat}.(U)
525+
@test LinearAlgebra.sqrt_quasitriu(U; blockwidth=64) LinearAlgebra.sqrt_quasitriu(Ubig; blockwidth=64)
526+
end
527+
528+
@testset "sylvester quasi-triangular blockwise" begin
529+
@testset for T in (Float32, Float64, ComplexF32, ComplexF64), m in (15, 40), n in (15, 45)
530+
A = schur(rand(T, m, m)).T
531+
B = schur(rand(T, n, n)).T
532+
C = randn(T, m, n)
533+
Ccopy = copy(C)
534+
X = LinearAlgebra._sylvester_quasitriu!(A, B, C; blockwidth=16)
535+
@test X === C
536+
@test A * X + X * B -Ccopy
537+
538+
@testset "test raise=false does not break recursion" begin
539+
Az = zero(A)
540+
Bz = zero(B)
541+
C2 = copy(Ccopy)
542+
@test_throws LAPACKException LinearAlgebra._sylvester_quasitriu!(Az, Bz, C2; blockwidth=16)
543+
m == n || @test any(C2 .== Ccopy) # recursion broken
544+
C3 = copy(Ccopy)
545+
X3 = LinearAlgebra._sylvester_quasitriu!(Az, Bz, C3; blockwidth=16, raise=false)
546+
@test !any(X3 .== Ccopy) # recursion not broken
547+
end
548+
end
549+
end
550+
516551
@testset "check matrix logarithm type-inferrable" for elty in (Float32,Float64,ComplexF32,ComplexF64)
517552
A = UpperTriangular(exp(triu(randn(elty, n, n))))
518553
@inferred Union{typeof(A),typeof(complex(A))} log(A)

0 commit comments

Comments
 (0)