Add the recursive blocked Schur algorithm for matrix square root #40239


Merged
merged 18 commits into JuliaLang:master on Jul 1, 2021

Conversation

sethaxen
Contributor

@sethaxen sethaxen commented Mar 27, 2021

Fixes JuliaLang/LinearAlgebra.jl#829 by implementing the recursive version of the blocked Schur algorithm for the matrix square root. The speed-ups come from greater use of Level 3 BLAS routines.
It also adds a recursive quasitriangular Sylvester solver, which for large matrices is much faster than LAPACK.trsyl!. sylvester should probably call this function, but that could be a future PR.
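The core recursion can be sketched as follows. This is a simplified illustration only, not the PR's actual `sqrt_quasitriu` (which also handles real quasi-triangular matrices and works in place); it assumes a complex upper triangular input. The idea: split the matrix into four blocks, take square roots of the two diagonal blocks recursively, and recover the off-diagonal block from a Sylvester equation, so most of the work lands in Level 3 BLAS.

```julia
using LinearAlgebra

# Illustrative sketch of the recursive blocked square root; the function
# name `sqrt_recursive` is hypothetical, not from the PR.
function sqrt_recursive(T::UpperTriangular; blockwidth=4)
    n = size(T, 1)
    n <= blockwidth && return sqrt(T)  # small blocks: fall back to the point algorithm
    k = n ÷ 2
    # Split T = [T11 T12; 0 T22] and recurse on the diagonal blocks.
    R11 = sqrt_recursive(UpperTriangular(Matrix(T)[1:k, 1:k]); blockwidth)
    R22 = sqrt_recursive(UpperTriangular(Matrix(T)[k+1:n, k+1:n]); blockwidth)
    # The off-diagonal block satisfies R11*X + X*R22 = T12. `sylvester`
    # solves A*X + X*B + C = 0, hence the sign flip on T12.
    R12 = sylvester(Matrix(R11), Matrix(R22), -Matrix(T)[1:k, k+1:n])
    return UpperTriangular([Matrix(R11) R12; zeros(eltype(R12), n - k, k) Matrix(R22)])
end
```

Each level of the recursion replaces triangular back-substitution with a dense Sylvester solve on half-size blocks, which is where the Level 3 BLAS speed-up comes from.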

Benchmark

This benchmark compares the "point" algorithm (what we currently do) with the recursive algorithm (this PR). For a 4000x4000 upper triangular matrix, the recursive algorithm is nearly 2 orders of magnitude faster. This is a greater improvement than in the paper, where they saw up to a 6x speed-up; it seems this recursive implementation is faster than the paper's.

using LinearAlgebra, BenchmarkTools, Random

ns =  vcat(1, 250:250:4000)

rng = MersenneTwister(42)
time_complex = map(ns) do n
    A = UpperTriangular(rand(rng, ComplexF64, n, n))
    if n <= 1000
        t = @belapsed sqrt($A)
    else
        t = @elapsed sqrt(A)
    end
    @show n, t
    t
end

time_real = map(ns) do n
    A = UpperTriangular(rand(rng, n, n))^2
    if n <= 1000
        t = @belapsed sqrt($A)
    else
        t = @elapsed sqrt(A)
    end
    @show n, t
    t
end
using Plots
# NB: the `*_point` and `*_recur` arrays plotted below come from running the
# benchmark above twice: once on master (point) and once on this PR (recursive).
kwargs = (label=["point" "recursive"], ylabel="time (s)", xlabel="n", legend=:topleft, yscale=:log10, linewidth=2)
plot(ns, [time_complex_point time_complex_recur]; kwargs...)
plot(ns, [time_real_point time_real_recur]; kwargs...)

[plot: sqrt(::UpperTriangular{ComplexF64}), point vs. recursive timings]

[plot: sqrt(::UpperTriangular{Float64}), point vs. recursive timings]

@sethaxen
Contributor Author

sethaxen commented Mar 27, 2021

The recursive version is currently applied for matrices with size greater than (64, 64). These plots show that the point version is still faster at size (65, 65), so unless we can speed up the recursive version, it should probably only be used for matrices with size greater than or equal to (512, 512).

This benchmark includes n in [64, 65, 128] (the dashed line marks n=65):

[plot: sqrt(::UpperTriangular{ComplexF64}), point/recursive time ratio]

[plot: sqrt(::UpperTriangular{Float64}), point/recursive time ratio]

@dkarrasch dkarrasch added the `linear algebra` and `performance` labels Mar 28, 2021
@sethaxen sethaxen closed this Apr 4, 2021
@sethaxen sethaxen reopened this Apr 4, 2021
@oscardssmith
Member

Does this have a noticeable effect on accuracy? If so, in which direction?

@sethaxen
Contributor Author

sethaxen commented Apr 4, 2021

Does this have a noticeable effect on accuracy? If so, in which direction?

Good question. I'm not certain how to determine the answer. The existing test suite passes, and in Section 3 of the original paper the authors worked out that the blocked algorithms satisfy the same error bounds as the original point algorithm.

Comment on lines 2562 to 2569
try
    _, scale = LAPACK.trsyl!('N', 'N', A, B, C)
    rmul!(C, -inv(scale))
catch e
    if !(catcherr && e isa LAPACKException)
        throw(e)
    end
end
Contributor Author

For just about any large real triangular matrix A created as A=UpperTriangular(randn(n, n))^2, sqrt(A)^2 fails to be approximately equal to A (the top-right entries explode). This is true for the point algorithm in 1.6 and on master as well. In the blocked version, LAPACK.trsyl! will throw a LAPACKException for these matrices, I guess indicating that it could not solve Sylvester's equation. But if that exception is caught, the result seems to be equivalent to what the point algorithm would have returned. Hence the try/catch.

Contributor

Maybe add a comment to justify ignoring the LAPACKException? Should it emit a warning when convergence fails?

Contributor Author

I'll add a comment. Re a warning: perhaps? I'm not really certain what the conventions are in the stdlib regarding warnings.

Member

While writing the tests for triangular.jl I realized how poorly conditioned UpperTriangular(randn(n, n)) is. It's an easy way to generate a triangular test matrix, but it's often not representative of real-world triangular matrices, which typically result from a matrix factorization. E.g.

julia> mean(cond(lu(randn(100, 100)).U) for i in 1:10000)
2309.908474551785

julia> mean(cond(triu(randn(100, 100))) for i in 1:10000)
2.915729043528958e19

so I generally think it would be better to construct triangular test matrices from an LU.

A separate question is the handling of non-zero info values in LAPACK (or almost any other mathematical function defined in C/Fortran). I'm not sure we made the right decision when we made all of these throw instead of returning the exit status in some form. However, that is a very big issue to tackle, so your current solution is probably fine (although I generally find try/catch pretty crude).
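A sketch of the suggested construction (illustrative only; the seed and size are arbitrary):

```julia
using LinearAlgebra, Random

rng = MersenneTwister(0)
n = 100
# Take the triangular test matrix from an LU factorization, as suggested,
# instead of from the raw upper triangle of a random matrix.
U_lu = UpperTriangular(lu(randn(rng, n, n)).U)
U_raw = UpperTriangular(randn(rng, n, n))
cond(U_lu), cond(U_raw)  # cond(U_lu) is typically many orders of magnitude smaller
```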

@oscardssmith
Member

That all sounds good. It might be a good idea to run this on a few random matrices of various structure and compare with BigFloat to confirm.

@sethaxen
Contributor Author

sethaxen commented Apr 4, 2021

That all sounds good. It might be a good idea to run this on a few random matrices of various structure and compare with BigFloat to confirm.

Are you thinking something like this?

n = 256
A = rand(ComplexF64, n, n)^2
T = schur(A).T
Tbig = Complex{BigFloat}.(T)
@test LinearAlgebra.sqrt_quasitriu(T) ≈ LinearAlgebra.sqrt_quasitriu(Tbig)

We unfortunately cannot check the real quasi-triangular case this way because schur and LAPACK.trsyl! are not implemented for BigFloat eltypes. So I think we could only test real and complex upper-triangular matrices this way.

@sethaxen
Contributor Author

sethaxen commented Apr 9, 2021

For completeness, here's a benchmark of the blocked recursive Sylvester solver used here vs LAPACK.trsyl!:

using LinearAlgebra, BenchmarkTools, Random, Plots

ns =  vcat(1, 64, 65, 128, 256, 512, 768, 1024)

rng = MersenneTwister(42)
time_complex = map(ns) do n
    A = schur(rand(rng, ComplexF64, n, n)).T
    B = schur(rand(rng, ComplexF64, n, n)).T
    C = rand(rng, ComplexF64, n, n)
    told = @belapsed $(LAPACK.trsyl!)('N', 'N', $A, $B, C) setup=(C=copy($C)) samples=100
    tnew = @belapsed $(LinearAlgebra._sylvester_quasitriu!)($A, $B, C) setup=(C=copy($C)) samples=100
    @show n, told, tnew
    told, tnew
end

time_real = map(ns) do n
    A = schur(rand(rng, n, n)).T
    B = schur(rand(rng, n, n)).T
    C = rand(rng, n, n)
    told = @belapsed $(LAPACK.trsyl!)('N', 'N', $A, $B, C) setup=(C=copy($C)) samples=100
    tnew = @belapsed $(LinearAlgebra._sylvester_quasitriu!)($A, $B, C) setup=(C=copy($C)) samples=100
    @show n, told, tnew
    told, tnew
end
kwargs = (label=["trsyl!" "recursive"], ylabel="time (s)", xlabel="n", legend=:topleft, yscale=:log10, linewidth=2)
plot(ns, [first.(time_complex) last.(time_complex)]; kwargs...)

[plot: complex Sylvester solve, trsyl! vs. recursive timings]

plot(ns, [first.(time_real) last.(time_real)]; kwargs...)

[plot: real Sylvester solve, trsyl! vs. recursive timings]

@StefanKarpinski
Member

@dkarrasch, you would be able to review this by any chance?

@sethaxen
Contributor Author

At @RalphAS's suggestion, I verified the correctness of the blocked square root and blocked Sylvester against unblocked BigFloat versions using GenericSchur.jl's Schur decomposition and upper triangular Sylvester solver:

using GenericSchur, LinearAlgebra, Test
n = 65

@testset for (T, Tbig) in ((ComplexF64, Complex{BigFloat}),)
    Abig = rand(Tbig, n, n)
    schurAbig = GenericSchur.gschur(Abig)
    sqrtAbig = schurAbig.Z * LinearAlgebra.sqrt_quasitriu(schurAbig.T, blockwidth=Inf) * schurAbig.Z'

    A = T.(Abig)
    schurA = schur(A)
    sqrtA = schurA.Z * LinearAlgebra.sqrt_quasitriu(schurA.T; blockwidth=16) * schurA.Z'

    @test sqrtA ≈ sqrtAbig
end

@testset for (T, Tbig) in ((ComplexF64, Complex{BigFloat}),)
    Abig = GenericSchur.gschur(rand(Tbig, n, n)).T
    Bbig = GenericSchur.gschur(rand(Tbig, n, n)).T
    Cbig = rand(Tbig, n, n)
    Xbig, scale = GenericSchur.trsylvester!(Abig, -Bbig, -copy(Cbig))
    rmul!(Xbig, inv(scale))

    A = T.(Abig)
    B = T.(Bbig)
    C = T.(Cbig)
    X = LinearAlgebra._sylvester_quasitriu!(A, B, copy(C); blockwidth=16)

    @test X ≈ Xbig
end

These tests pass.

@sethaxen
Contributor Author

sethaxen commented May 3, 2021

@dkarrasch can you review this?

@dkarrasch
Member

@dkarrasch can you review this?

Sorry, I don't think I'm competent to review here.

sethaxen and others added 2 commits May 3, 2021 16:58
Co-authored-by: Mathieu Besançon <mathieu.besancon@gmail.com>
@StefanKarpinski
Member

@andreasnoack, would you by any chance be able to review this? Or have a better idea who might?

Member

@andreasnoack andreasnoack left a comment

This looks good to me provided that all branches are exercised by the tests.

@sethaxen
Contributor Author

I think one thing that still needs to be resolved here is how to handle #40239 (comment). The size threshold at which the blocked version is used in the paper (n=64) is quite a bit too low in this case. I was wondering if there was room for improvement so our cutoff would be similar to the paper's. If not, we should probably increase the threshold to something like n=256 or n=512.

@oscardssmith
Member

oscardssmith commented Jun 16, 2021

IMO, just change the cutoff for now. If we can lower it later, great, but that shouldn't block the already implemented improvements.

@sethaxen
Contributor Author

Alright, I chose a cutoff from the benchmark (256 for Real eltypes and 512 for Complex).
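For illustration, an eltype-dependent cutoff could be expressed by dispatch like this (the helper name `sqrt_blockwidth` is hypothetical, not the PR's actual code; the thresholds mirror those chosen from the benchmarks above):

```julia
# Hypothetical helper: the blocked recursion only pays off above these sizes,
# per the benchmarks, so pick the cutoff by element type.
sqrt_blockwidth(::Type{<:Real}) = 256
sqrt_blockwidth(::Type{<:Complex}) = 512
```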

@oscardssmith oscardssmith added the `merge me` label (PR is reviewed. Merge when all tests are passing) Jun 22, 2021
@dkarrasch dkarrasch merged commit 1810952 into JuliaLang:master Jul 1, 2021
@oscardssmith oscardssmith removed the `merge me` label Jul 1, 2021
@sethaxen sethaxen deleted the sqrtblock branch July 1, 2021 19:34
johanmon pushed a commit to johanmon/julia that referenced this pull request Jul 5, 2021
…iaLang#40239)

Co-authored-by: Mathieu Besançon <mathieu.besancon@gmail.com>
Successfully merging this pull request may close these issues.

Use blocked Schur algorithm for matrix square root
6 participants