Skip to content

Commit ea14b0c

Browse files
authored
Add tests for polar decomposition for GPU (#83)
1 parent 17ea798 commit ea14b0c

File tree

3 files changed

+176
-0
lines changed

3 files changed

+176
-0
lines changed

test/amd/polar.jl

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using TestExtras
4+
using StableRNGs
5+
using LinearAlgebra: LinearAlgebra, I, isposdef, Hermitian
6+
using MatrixAlgebraKit: PolarViaSVD
7+
using AMDGPU
8+
9+
@testset "left_polar! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
10+
rng = StableRNG(123)
11+
m = 54
12+
@testset "size ($m, $n)" for n in (37, m)
13+
k = min(m, n)
14+
svd_algs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi())
15+
@testset "algorithm $svd_alg" for svd_alg in svd_algs
16+
n < m && svd_alg isa ROCSOLVER_QRIteration && continue
17+
A = ROCArray(randn(rng, T, m, n))
18+
alg = PolarViaSVD(svd_alg)
19+
W, P = left_polar(A; alg)
20+
@test W isa ROCMatrix{T} && size(W) == (m, n)
21+
@test P isa ROCMatrix{T} && size(P) == (n, n)
22+
@test W * P A
23+
@test isisometric(W)
24+
# work around extremely strict Julia criteria for Hermiticity
25+
@test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) && isposdef(Hermitian(P))
26+
27+
Ac = similar(A)
28+
W2, P2 = @constinferred left_polar!(copy!(Ac, A), (W, P), alg)
29+
@test W2 === W
30+
@test P2 === P
31+
@test W * P A
32+
@test isisometric(W)
33+
# work around extremely strict Julia criteria for Hermiticity
34+
@test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) && isposdef(Hermitian(P))
35+
36+
noP = similar(P, (0, 0))
37+
W2, P2 = @constinferred left_polar!(copy!(Ac, A), (W, noP), alg)
38+
@test P2 === noP
39+
@test W2 === W
40+
@test isisometric(W)
41+
P = W' * A # compute P explicitly to verify W correctness
42+
@test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P))
43+
@test isposdef(Hermitian(project_hermitian!(P)))
44+
end
45+
end
46+
end
47+
48+
@testset "right_polar! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
49+
rng = StableRNG(123)
50+
n = 54
51+
@testset "size ($m, $n)" for m in (37, n)
52+
k = min(m, n)
53+
svd_algs = (ROCSOLVER_QRIteration(), ROCSOLVER_Jacobi())
54+
@testset "algorithm $svd_alg" for svd_alg in svd_algs
55+
n > m && svd_alg isa ROCSOLVER_QRIteration && continue
56+
A = ROCArray(randn(rng, T, m, n))
57+
alg = PolarViaSVD(svd_alg)
58+
P, Wᴴ = right_polar(A; alg)
59+
@test Wᴴ isa ROCMatrix{T} && size(Wᴴ) == (m, n)
60+
@test P isa ROCMatrix{T} && size(P) == (m, m)
61+
@test P * Wᴴ A
62+
@test isisometric(Wᴴ; side = :right)
63+
# work around extremely strict Julia criteria for Hermiticity
64+
@test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) && isposdef(Hermitian(P))
65+
66+
Ac = similar(A)
67+
P2, Wᴴ2 = @constinferred right_polar!(copy!(Ac, A), (P, Wᴴ), alg)
68+
@test P2 === P
69+
@test Wᴴ2 === Wᴴ
70+
@test P * Wᴴ A
71+
@test isisometric(Wᴴ; side = :right)
72+
# work around extremely strict Julia criteria for Hermiticity
73+
@test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) && isposdef(Hermitian(P))
74+
75+
noP = similar(P, (0, 0))
76+
P2, Wᴴ2 = @constinferred right_polar!(copy!(Ac, A), (noP, Wᴴ), alg)
77+
@test P2 === noP
78+
@test Wᴴ2 === Wᴴ
79+
@test isisometric(Wᴴ; side = :right)
80+
P = A * Wᴴ' # compute P explicitly to verify W correctness
81+
@test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P))
82+
@test isposdef(Hermitian(project_hermitian!(P)))
83+
end
84+
end
85+
end

test/cuda/polar.jl

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
using MatrixAlgebraKit
2+
using Test
3+
using TestExtras
4+
using StableRNGs
5+
using LinearAlgebra: LinearAlgebra, I, isposdef, Hermitian
6+
using MatrixAlgebraKit: PolarViaSVD
7+
using CUDA
8+
9+
@testset "left_polar! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
10+
rng = StableRNG(123)
11+
m = 54
12+
@testset "size ($m, $n)" for n in (37, m)
13+
k = min(m, n)
14+
svd_algs = (CUSOLVER_QRIteration(), CUSOLVER_Jacobi())
15+
@testset "algorithm $svd_alg" for svd_alg in svd_algs
16+
n < m && svd_alg isa CUSOLVER_QRIteration && continue
17+
A = CuArray(randn(rng, T, m, n))
18+
alg = PolarViaSVD(svd_alg)
19+
W, P = left_polar(A; alg)
20+
@test W isa CuMatrix{T} && size(W) == (m, n)
21+
@test P isa CuMatrix{T} && size(P) == (n, n)
22+
@test W * P A
23+
@test isisometric(W)
24+
# work around extremely strict Julia criteria for Hermiticity
25+
@test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) && isposdef(Hermitian(P))
26+
27+
Ac = similar(A)
28+
W2, P2 = @constinferred left_polar!(copy!(Ac, A), (W, P), alg)
29+
@test W2 === W
30+
@test P2 === P
31+
@test W * P A
32+
@test isisometric(W)
33+
# work around extremely strict Julia criteria for Hermiticity
34+
@test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) && isposdef(Hermitian(P))
35+
36+
noP = similar(P, (0, 0))
37+
W2, P2 = @constinferred left_polar!(copy!(Ac, A), (W, noP), alg)
38+
@test P2 === noP
39+
@test W2 === W
40+
@test isisometric(W)
41+
P = W' * A # compute P explicitly to verify W correctness
42+
@test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P))
43+
@test isposdef(Hermitian(project_hermitian!(P)))
44+
end
45+
end
46+
end
47+
48+
@testset "right_polar! for T = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
49+
rng = StableRNG(123)
50+
n = 54
51+
@testset "size ($m, $n)" for m in (37, n)
52+
k = min(m, n)
53+
svd_algs = (CUSOLVER_QRIteration(), CUSOLVER_Jacobi())
54+
@testset "algorithm $svd_alg" for svd_alg in svd_algs
55+
n > m && svd_alg isa CUSOLVER_QRIteration && continue
56+
A = CuArray(randn(rng, T, m, n))
57+
alg = PolarViaSVD(svd_alg)
58+
P, Wᴴ = right_polar(A; alg)
59+
@test Wᴴ isa CuMatrix{T} && size(Wᴴ) == (m, n)
60+
@test P isa CuMatrix{T} && size(P) == (m, m)
61+
@test P * Wᴴ A
62+
@test isisometric(Wᴴ; side = :right)
63+
# work around extremely strict Julia criteria for Hermiticity
64+
@test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) && isposdef(Hermitian(P))
65+
66+
Ac = similar(A)
67+
P2, Wᴴ2 = @constinferred right_polar!(copy!(Ac, A), (P, Wᴴ), alg)
68+
@test P2 === P
69+
@test Wᴴ2 === Wᴴ
70+
@test P * Wᴴ A
71+
@test isisometric(Wᴴ; side = :right)
72+
# work around extremely strict Julia criteria for Hermiticity
73+
@test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P)) && isposdef(Hermitian(P))
74+
75+
noP = similar(P, (0, 0))
76+
P2, Wᴴ2 = @constinferred right_polar!(copy!(Ac, A), (noP, Wᴴ), alg)
77+
@test P2 === noP
78+
@test Wᴴ2 === Wᴴ
79+
@test isisometric(Wᴴ; side = :right)
80+
P = A * Wᴴ' # compute P explicitly to verify W correctness
81+
@test ishermitian(P; rtol = MatrixAlgebraKit.defaulttol(P))
82+
@test isposdef(Hermitian(project_hermitian!(P)))
83+
end
84+
end
85+
end

test/runtests.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ if CUDA.functional()
7575
@safetestset "CUDA Hermitian Eigenvalue Decomposition" begin
7676
include("cuda/eigh.jl")
7777
end
78+
@safetestset "CUDA Polar Decomposition" begin
79+
include("cuda/polar.jl")
80+
end
7881
end
7982

8083
using AMDGPU
@@ -94,4 +97,7 @@ if AMDGPU.functional()
9497
@safetestset "AMDGPU Hermitian Eigenvalue Decomposition" begin
9598
include("amd/eigh.jl")
9699
end
100+
@safetestset "AMDGPU Polar Decomposition" begin
101+
include("amd/polar.jl")
102+
end
97103
end

0 commit comments

Comments
 (0)