Add the recursive blocked Schur algorithm for matrix square root #40239
Conversation
The recursive version is currently applied for matrices with size greater than […]. This benchmark includes […]
Does this have a noticeable effect on accuracy? If so, in which direction?
Good question. I'm not certain how to determine the answer. The existing test suite passes, and in Section 3 of the original paper they work out that the blocked algorithms satisfy the same error bounds as the original point algorithm.
```julia
try
    _, scale = LAPACK.trsyl!('N', 'N', A, B, C)
    rmul!(C, -inv(scale))
catch e
    if !(catcherr && e isa LAPACKException)
        throw(e)
    end
end
```
For just about any large real triangular matrix `A` created as `A = UpperTriangular(randn(n, n))^2`, `sqrt(A)^2` fails to be approximately equal to `A` (the top right terms explode). This is true for the point algorithm in 1.6 and on master as well. In the blocked version, `LAPACK.trsyl!` will throw a `LAPACKException` for these matrices, I guess indicating that it could not solve Sylvester's equation. But if that exception is caught, the result seems to be equivalent to what the point algorithm would have returned. Hence the `try`/`catch`.
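For context, the "point" algorithm referred to here computes the square root entry by entry via the standard recurrence for triangular matrices. Here is a minimal pure-Python sketch of the upper triangular case with positive diagonal (function names are hypothetical and illustrative only, unrelated to the Julia implementation):

```python
import math

def sqrt_upper_triangular(T):
    """Square root of an upper triangular matrix with positive diagonal,
    via the point recurrence:
        R[j][j] = sqrt(T[j][j])
        R[i][j] = (T[i][j] - sum_k R[i][k]*R[k][j]) / (R[i][i] + R[j][j])
    where k runs over i < k < j."""
    n = len(T)
    R = [[0.0] * n for _ in range(n)]
    for j in range(n):
        R[j][j] = math.sqrt(T[j][j])
        for i in range(j - 1, -1, -1):
            s = sum(R[i][k] * R[k][j] for k in range(i + 1, j))
            R[i][j] = (T[i][j] - s) / (R[i][i] + R[j][j])
    return R

T = [[4.0, 2.0, 1.0],
     [0.0, 9.0, 3.0],
     [0.0, 0.0, 16.0]]
R = sqrt_upper_triangular(T)

# Verify that R*R recovers T.
T2 = [[sum(R[i][k] * R[k][j] for k in range(3)) for j in range(3)]
      for i in range(3)]
```

Each off-diagonal entry depends only on already computed entries of the same column and earlier columns, which is what lets the blocked and recursive variants regroup this work into matrix-matrix products.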
Maybe add a comment to justify ignoring `LAPACKException`? Should it emit a warning when convergence fails?
I'll add a comment. As for a warning, perhaps? I'm not really certain what the conventions are in the stdlib regarding warnings.
While writing the tests for `triangular.jl` I realized how poorly conditioned `UpperTriangular(randn(n, n))` is. It's an easy way to generate a triangular test matrix, but it's often not representative of real-world triangular matrices, which are typically the result of a matrix factorization. E.g.

```julia
julia> mean(cond(lu(randn(100, 100)).U) for i in 1:10000)
2309.908474551785

julia> mean(cond(triu(randn(100, 100))) for i in 1:10000)
2.915729043528958e19
```

so I generally think it would be better to construct triangular test matrices from an LU.

A separate question is the handling of non-zero `info` values in LAPACK (or almost any other mathematical function defined in C/Fortran). I'm not sure we made the right decision when we made all of these throw instead of returning the exit status in some form. However, that is a very big issue to tackle, so your current solution is probably fine (although I generally find `try`/`catch` pretty crude).
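To make the conditioning point concrete, here is a small pure-Python sketch (not Julia's `cond`; all names are illustrative) that computes the infinity-norm condition number of an upper triangular matrix by explicit back substitution. The 2x2 example shows how a single large off-diagonal entry squares into the condition number:

```python
import random

def inv_upper(U):
    """Invert an upper triangular matrix by solving U*X = I
    column by column with back substitution."""
    n = len(U)
    X = [[0.0] * n for _ in range(n)]
    for j in range(n):
        for i in range(n - 1, -1, -1):
            b = 1.0 if i == j else 0.0
            s = sum(U[i][k] * X[k][j] for k in range(i + 1, n))
            X[i][j] = (b - s) / U[i][i]
    return X

def norm_inf(A):
    # Induced infinity norm: maximum absolute row sum.
    return max(sum(abs(x) for x in row) for row in A)

def cond_inf(U):
    return norm_inf(U) * norm_inf(inv_upper(U))

# One big off-diagonal entry gives cond on the order of 1000^2.
cond_small = cond_inf([[1.0, 1000.0], [0.0, 1.0]])  # 1001 * 1001 = 1002001

# A seeded random upper triangular matrix; for matrices like this the
# condition number is typically enormous, echoing the averages above.
random.seed(0)
n = 20
Rnd = [[random.gauss(0, 1) if j >= i else 0.0 for j in range(n)]
       for i in range(n)]
cond_rand = cond_inf(Rnd)
```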
That all sounds good. It might be a good idea to run this on a few random matrices of various structure and compare with |
Are you thinking something like this?

```julia
n = 256
A = rand(ComplexF64, n, n)^2
T = schur(A).T
Tbig = Complex{BigFloat}.(T)
@test LinearAlgebra.sqrt_quasitriu(T) ≈ LinearAlgebra.sqrt_quasitriu(Tbig)
```

We unfortunately cannot check the real quasi-triangular case this way because |
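A similar higher-precision cross-check can be sketched outside Julia with Python's `decimal` module standing in for `BigFloat` (illustrative only; the generic `point_sqrt` helper below is hypothetical, not the Julia code):

```python
import math
from decimal import Decimal, getcontext

getcontext().prec = 50  # ~50 significant digits, playing the BigFloat role

def point_sqrt(T, sqrt):
    """Point recurrence for the square root of an upper triangular matrix,
    generic over the element type (float or Decimal)."""
    n = len(T)
    zero = 0 * T[0][0]
    R = [[zero] * n for _ in range(n)]
    for j in range(n):
        R[j][j] = sqrt(T[j][j])
        for i in range(j - 1, -1, -1):
            s = sum((R[i][k] * R[k][j] for k in range(i + 1, j)), zero)
            R[i][j] = (T[i][j] - s) / (R[i][i] + R[j][j])
    return R

Tf = [[4.0, 2.0, 1.0], [0.0, 9.0, 3.0], [0.0, 0.0, 16.0]]
Td = [[Decimal(x) for x in row] for row in Tf]

Rf = point_sqrt(Tf, math.sqrt)
Rd = point_sqrt(Td, lambda d: d.sqrt())

# Element-wise disagreement between double precision and 50-digit precision
# should be on the order of double-precision rounding error.
err = max(abs(float(Rd[i][j]) - Rf[i][j]) for i in range(3) for j in range(3))
```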
For completeness, here's a benchmark of the blocked recursive Sylvester solver used here vs `LAPACK.trsyl!`:

```julia
using LinearAlgebra, BenchmarkTools, Random, Plots

ns = vcat(1, 64, 65, 128, 256, 512, 768, 1024)
rng = MersenneTwister(42)

time_complex = map(ns) do n
    A = schur(rand(rng, ComplexF64, n, n)).T
    B = schur(rand(rng, ComplexF64, n, n)).T
    C = rand(rng, ComplexF64, n, n)
    told = @belapsed $(LAPACK.trsyl!)('N', 'N', $A, $B, C) setup=(C=copy($C)) samples=100
    tnew = @belapsed $(LinearAlgebra._sylvester_quasitriu!)($A, $B, C) setup=(C=copy($C)) samples=100
    @show n, told, tnew
    told, tnew
end

time_real = map(ns) do n
    A = schur(rand(rng, n, n)).T
    B = schur(rand(rng, n, n)).T
    C = rand(rng, n, n)
    told = @belapsed $(LAPACK.trsyl!)('N', 'N', $A, $B, C) setup=(C=copy($C)) samples=100
    tnew = @belapsed $(LinearAlgebra._sylvester_quasitriu!)($A, $B, C) setup=(C=copy($C)) samples=100
    @show n, told, tnew
    told, tnew
end

kwargs = (label=["trsyl!" "recursive"], ylabel="time (s)", xlabel="n",
          legend=:topleft, yscale=:log10, linewidth=2)
plot(ns, [first.(time_complex) last.(time_complex)]; kwargs...)
plot(ns, [first.(time_real) last.(time_real)]; kwargs...)
```
@dkarrasch, would you be able to review this by any chance?
Co-authored-by: Mathieu Besançon <mathieu.besancon@gmail.com>
At @RalphAS's suggestion, I verified the correctness of the blocked square root and blocked Sylvester solver against the unblocked versions:

```julia
using GenericSchur, LinearAlgebra, Test

n = 65

@testset for (T, Tbig) in ((ComplexF64, Complex{BigFloat}),)
    Abig = rand(Tbig, n, n)
    schurAbig = GenericSchur.gschur(Abig)
    sqrtAbig = schurAbig.Z * LinearAlgebra.sqrt_quasitriu(schurAbig.T; blockwidth=Inf) * schurAbig.Z'
    A = T.(Abig)
    schurA = schur(A)
    sqrtA = schurA.Z * LinearAlgebra.sqrt_quasitriu(schurA.T; blockwidth=16) * schurA.Z'
    @test sqrtA ≈ sqrtAbig
end

@testset for (T, Tbig) in ((ComplexF64, Complex{BigFloat}),)
    Abig = GenericSchur.gschur(rand(Tbig, n, n)).T
    Bbig = GenericSchur.gschur(rand(Tbig, n, n)).T
    Cbig = rand(Tbig, n, n)
    Xbig, scale = GenericSchur.trsylvester!(Abig, -Bbig, -copy(Cbig))
    rmul!(Xbig, inv(scale))
    A = T.(Abig)
    B = T.(Bbig)
    C = T.(Cbig)
    X = LinearAlgebra._sylvester_quasitriu!(A, B, copy(C); blockwidth=16)
    @test X ≈ Xbig
end
```

These tests pass.
@dkarrasch can you review this? |
Sorry, I don't think I'm competent to review here. |
@andreasnoack, would you by any chance be able to review this? Or have a better idea who might?
This looks good to me provided that all branches are exercised by the tests.
I think one thing that still needs to be resolved here is how to handle #40239 (comment). The size threshold at which the blocked version is used in the paper ( |
IMO, just change the cutoff for now. If we can lower the cutoff later, great, but that shouldn't block the already implemented improvements.
Alright, I chose a cutoff from the benchmark (256 for |
Add the recursive blocked Schur algorithm for matrix square root (JuliaLang#40239) Co-authored-by: Mathieu Besançon <mathieu.besancon@gmail.com>
Fixes JuliaLang/LinearAlgebra.jl#829 by implementing the recursive version of the blocked Schur algorithm for the matrix square root. The speed-ups come from greater use of Level 3 BLAS routines.
It also adds a recursive quasitriangular Sylvester solver, which for large matrices is much faster than `LAPACK.trsyl!`. `sylvester` should probably call this function, but that could be a future PR.

Benchmark

This benchmark compares the "point" algorithm (what we currently do) with the recursive algorithm (this PR). For a 4000x4000 upper triangular matrix, the recursive algorithm is nearly 2 orders of magnitude faster. This is a greater improvement than in the paper, where they saw up to a 6x speed-up; it seems this recursive implementation is faster than the paper's.
[Benchmark plots: `sqrt(::UpperTriangular{ComplexF64})` and `sqrt(::UpperTriangular{Float64})`]
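Putting the pieces together, the core recursion of the blocked Schur square root can be sketched in pure Python for the upper triangular, positive-diagonal case (an illustrative toy with hypothetical names, not the PR's implementation, which handles quasi-triangular real Schur forms and uses a tuned blockwidth):

```python
import math

def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(len(B)))
             for j in range(len(B[0]))] for i in range(len(A))]

def matsub(A, B):
    return [[a - b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]

def sylvester_triu(A, B, C):
    # Solve A*X + X*B = C for upper triangular A, B (divide and conquer).
    m, n = len(A), len(B)
    if m == 1 and n == 1:
        return [[C[0][0] / (A[0][0] + B[0][0])]]
    if m >= n:
        k = m // 2
        A11 = [row[:k] for row in A[:k]]
        A12 = [row[k:] for row in A[:k]]
        A22 = [row[k:] for row in A[k:]]
        X2 = sylvester_triu(A22, B, C[k:])
        X1 = sylvester_triu(A11, B, matsub(C[:k], matmul(A12, X2)))
        return X1 + X2
    k = n // 2
    B11 = [row[:k] for row in B[:k]]
    B12 = [row[k:] for row in B[:k]]
    B22 = [row[k:] for row in B[k:]]
    X1 = sylvester_triu(A, B11, [row[:k] for row in C])
    X2 = sylvester_triu(A, B22,
                        matsub([row[k:] for row in C], matmul(X1, B12)))
    return [r1 + r2 for r1, r2 in zip(X1, X2)]

def point_sqrt(T):
    # Unblocked ("point") recurrence, used below the blockwidth cutoff.
    n = len(T)
    R = [[0.0] * n for _ in range(n)]
    for j in range(n):
        R[j][j] = math.sqrt(T[j][j])
        for i in range(j - 1, -1, -1):
            s = sum(R[i][k] * R[k][j] for k in range(i + 1, j))
            R[i][j] = (T[i][j] - s) / (R[i][i] + R[j][j])
    return R

def sqrt_triu(T, blockwidth=2):
    # T = [[T11, T12], [0, T22]] has square root R with R11 = sqrt(T11),
    # R22 = sqrt(T22), and R11*R12 + R12*R22 = T12 (a Sylvester equation).
    n = len(T)
    if n <= blockwidth:
        return point_sqrt(T)
    k = n // 2
    T11 = [row[:k] for row in T[:k]]
    T12 = [row[k:] for row in T[:k]]
    T22 = [row[k:] for row in T[k:]]
    R11 = sqrt_triu(T11, blockwidth)
    R22 = sqrt_triu(T22, blockwidth)
    R12 = sylvester_triu(R11, R22, T12)
    return ([r1 + r2 for r1, r2 in zip(R11, R12)]
            + [[0.0] * k + r for r in R22])

# Check R*R ≈ T on a 5x5 upper triangular matrix with positive diagonal.
T = [[float(4 + i) if i == j else (1.0 if j > i else 0.0) for j in range(5)]
     for i in range(5)]
R = sqrt_triu(T, blockwidth=2)
R2 = matmul(R, R)
err = max(abs(R2[i][j] - T[i][j]) for i in range(5) for j in range(5))
```

The two half-size square roots and the Sylvester solve mirror the structure described above; in the real implementation the off-diagonal coupling is what turns into Level 3 BLAS work.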