Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions src/solver/highlevel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,61 @@ for (fname, elty) in (
end
end

# Generate one `geblttrf!` method per supported element type, each forwarding to the
# matching rocSOLVER `*geblttrf_npvt` binding.
for (fname, elty) in (
    (:rocsolver_sgeblttrf_npvt, :Float32),
    (:rocsolver_dgeblttrf_npvt, :Float64),
    (:rocsolver_cgeblttrf_npvt, :ComplexF32),
    (:rocsolver_zgeblttrf_npvt, :ComplexF64),
)
    @eval begin
        # geblttrf!(A, B, C) -> (B, C)
        #
        # A, B and C are 3-D device arrays whose slices along the third dimension are
        # the blocks handed to the rocSOLVER routine. All blocks must be square and of
        # one common size, and B must hold exactly one more block than A and C.
        #
        # Throws `DimensionMismatch` when the shapes disagree. The status word written
        # by the routine is copied back to the host and passed to `chkargsok`.
        function geblttrf!(A::ROCArray{$elty,3}, B::ROCArray{$elty,3}, C::ROCArray{$elty,3})
            rowsA, colsA, countA = size(A, 1), size(A, 2), size(A, 3)
            rowsB, colsB, countB = size(B, 1), size(B, 2), size(B, 3)
            rowsC, colsC, countC = size(C, 1), size(C, 2), size(C, 3)

            if !all(==(rowsA), (colsA, rowsB, colsB, rowsC, colsC))
                throw(DimensionMismatch("The first two dimensions of A, B and C must match"))
            end
            if !(countA == countC && countA == countB - 1)
                throw(DimensionMismatch("Inconsistency for the last dimension of A, B and C"))
            end

            # Leading dimension of each block: the column stride, clamped to at least 1.
            lda, ldb, ldc = map(blk -> max(1, stride(blk, 2)), (A, B, C))

            # One-element device buffer that receives the routine's status code.
            status = ROCArray{Cint}(undef, 1)
            $fname(rocBLAS.handle(), rowsB, countB, A, lda, B, ldb, C, ldc, status)
            code = AMDGPU.@allowscalar status[1]
            AMDGPU.unsafe_free!(status)
            chkargsok(BlasInt(code))
            return (B, C)
        end
    end
end

# Generate one `geblttrs!` method per supported element type, each forwarding to the
# matching rocSOLVER `*geblttrs_npvt` binding.
for (fname, elty) in (
    (:rocsolver_sgeblttrs_npvt, :Float32),
    (:rocsolver_dgeblttrs_npvt, :Float64),
    (:rocsolver_cgeblttrs_npvt, :ComplexF32),
    (:rocsolver_zgeblttrs_npvt, :ComplexF64),
)
    @eval begin
        # geblttrs!(A, B, C, X) -> X
        #
        # A, B and C are 3-D device arrays of square blocks of one common size, with B
        # holding one more block than A and C. X is indexed as
        # (row within block, block index, right-hand side); its first dimension must
        # equal the block size and its second dimension the number of blocks in B.
        #
        # Throws `DimensionMismatch` when the shapes disagree. X is passed to the
        # rocSOLVER routine and returned.
        function geblttrs!(A::ROCArray{$elty,3}, B::ROCArray{$elty,3}, C::ROCArray{$elty,3}, X::ROCArray{$elty,3})
            rowsA, colsA, countA = size(A, 1), size(A, 2), size(A, 3)
            rowsB, colsB, countB = size(B, 1), size(B, 2), size(B, 3)
            rowsC, colsC, countC = size(C, 1), size(C, 2), size(C, 3)
            rowsX, countX, nrhs = size(X, 1), size(X, 2), size(X, 3)

            if !all(==(rowsA), (colsA, rowsB, colsB, rowsC, colsC))
                throw(DimensionMismatch("The first two dimensions of A, B and C must match"))
            end
            if rowsX != rowsA
                throw(DimensionMismatch("The first dimension of X is inconsistent with first two dimensions of A, B and C"))
            end
            if !(countA == countC && countA == countB - 1 && countA == countX - 1)
                throw(DimensionMismatch("Inconsistency for the number of blocks in A, B, C and X"))
            end

            # Leading dimension of each operand: the column stride, clamped to at least 1.
            lda, ldb, ldc, ldx = map(blk -> max(1, stride(blk, 2)), (A, B, C, X))

            $fname(rocBLAS.handle(), rowsB, countB, nrhs, A, lda, B, ldb, C, ldc, X, ldx)
            return X
        end
    end
end

for (fname, elty, relty) in ((:rocsolver_sgebrd, :Float32 , :Float32),
(:rocsolver_dgebrd, :Float64 , :Float64),
(:rocsolver_cgebrd, :ComplexF32, :Float32),
Expand Down
75 changes: 75 additions & 0 deletions test/rocarray/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,81 @@ end
end
end

# Round-trip check of the block-tridiagonal factorization and solve: the factors
# returned by `geblttrf!` are reassembled into dense L and U, their product is compared
# with the dense input matrix, and the `geblttrs!` solution is compared with a dense
# host solve.
@testset "geblttrf! -- geblttrs!" begin
    @testset "elty = $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64]
        # NOTE(review): only block size 1 is exercised; sizes 8 and 16 were commented
        # out in the original and the reason is not visible here.
        @testset "n = $n" for n in (1, )
            @testset "nblocks = $nblocks" for nblocks in (4, 8, 16, 32)
                nrhs = 1
                dim = n * nblocks

                # Host-side random blocks: sub-diagonal, diagonal, super-diagonal, RHS.
                sub_h = rand(elty, n, n, nblocks-1)
                diag_h = rand(elty, n, n, nblocks)
                sup_h = rand(elty, n, n, nblocks-1)
                rhs_h = rand(elty, n, nblocks, nrhs)

                # Rows (and columns) occupied by block `k` in the dense matrices.
                blockrange(k) = (k-1)*n .+ (1:n)

                # Dense block-tridiagonal matrix and stacked right-hand side.
                M = zeros(elty, dim, dim)
                RHS = zeros(elty, dim, nrhs)
                for k in 1:nblocks
                    rk = blockrange(k)
                    M[rk, rk] = diag_h[:, :, k]
                    if k < nblocks
                        M[rk .+ n, rk] = sub_h[:, :, k]
                        M[rk, rk .+ n] = sup_h[:, :, k]
                    end
                    RHS[rk, :] = rhs_h[:, k, :]
                end

                d_A = ROCArray(sub_h)
                d_B = ROCArray(diag_h)
                d_C = ROCArray(sup_h)
                d_R = ROCArray(rhs_h)
                rocSOLVER.geblttrf!(d_A, d_B, d_C)

                # Rebuild the dense factors: L takes the returned diagonal blocks plus
                # the original sub-diagonal blocks; U has a unit diagonal plus the
                # returned super-diagonal blocks.
                Lfac = zeros(elty, dim, dim)
                Ufac = zeros(elty, dim, dim)
                diag_out = collect(d_B)
                sup_out = collect(d_C)
                for k in 1:nblocks
                    rk = blockrange(k)
                    Lfac[rk, rk] = diag_out[:, :, k]
                    for i in rk
                        Ufac[i, i] = one(elty)
                    end
                    if k < nblocks
                        Lfac[rk .+ n, rk] = sub_h[:, :, k]
                        Ufac[rk, rk .+ n] = sup_out[:, :, k]
                    end
                end
                N = Lfac * Ufac
                @test N ≈ M

                # Reference solution from the reassembled product, reshaped back to the
                # (row, block, rhs) layout used for the right-hand side.
                X = N \ RHS
                Y = similar(rhs_h)
                for k in 1:nblocks
                    Y[:, k, :] = X[blockrange(k), :]
                end
                rocSOLVER.geblttrs!(d_A, d_B, d_C, d_R)
                @test Y ≈ collect(d_R)
            end
        end
    end
end

@testset "gebrd!" begin
@testset "elty = $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64]
A = rand(elty,m,n)
Expand Down