From 8c1e0366ff6ff23e6ecbbc750d9451eb1071e1a4 Mon Sep 17 00:00:00 2001 From: Seyoon Ko Date: Wed, 28 Aug 2024 13:05:17 -0700 Subject: [PATCH] fix tests --- test/libraries/cublas.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/test/libraries/cublas.jl b/test/libraries/cublas.jl index 65cef41686..de39b7acc3 100644 --- a/test/libraries/cublas.jl +++ b/test/libraries/cublas.jl @@ -105,7 +105,7 @@ end @test testf(*, rand(elty, m, n)', rand(elty, m)) x = rand(elty, m) A = rand(elty, m, m + 1 ) - y = rand(elty, m) + y = rand(elty, n) dx = CuArray(x) dA = CuArray(A) dy = CuArray(y) @@ -125,7 +125,7 @@ end hy = collect(dy) @test hy ≈ A * x dy = CuArray(y) - dx = CUBLAS.gemv('T', alpha, dA, dy) + dx = CUBLAS.gemv(elty <: Real ? 'T' : 'C', alpha, dA, dy) hx = collect(dx) @test hx ≈ alpha * A' * y end @@ -158,11 +158,11 @@ end for i=1:length(A) push!(dy, CuArray(y[i])) end - CUBLAS.gemv_batched!('T', alpha, dA, dy, beta, dx) + CUBLAS.gemv_batched!(elty <: Real ? 'T' : 'C', alpha, dA, dy, beta, dx) for i=1:size(A, 3) - hx = collect(dx[:, i]) - x[:, i] = alpha * transpose(A[:, :, i]) * y[:, i] + beta * y[:, i] - @test x[:, i] ≈ hx + hx = collect(dx[i]) + x[i] = alpha * A[i]' * y[i] + beta * x[i] + @test x[i] ≈ hx end end end @@ -188,10 +188,10 @@ end @test y[:, i] ≈ hy end dy = CuArray(y) - CUBLAS.gemv_strided_batched!('T', alpha, dA, dy, beta, dx) + CUBLAS.gemv_strided_batched!(elty <: Real ? 'T' : 'C', alpha, dA, dy, beta, dx) for i=1:size(A, 3) hx = collect(dx[:, i]) - x[:, i] = alpha * transpose(A[:, :, i]) * y[:, i] + beta * y[:, i] + x[:, i] = alpha * A[:, :, i]' * y[:, i] + beta * x[:, i] @test x[:, i] ≈ hx end end