Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kose-y committed Aug 28, 2024
1 parent f5b810f commit 8c1e036
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions test/libraries/cublas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ end
@test testf(*, rand(elty, m, n)', rand(elty, m))
x = rand(elty, m)
A = rand(elty, m, m + 1 )
y = rand(elty, m)
y = rand(elty, n)
dx = CuArray(x)
dA = CuArray(A)
dy = CuArray(y)
Expand All @@ -125,7 +125,7 @@ end
hy = collect(dy)
@test hy A * x
dy = CuArray(y)
dx = CUBLAS.gemv('T', alpha, dA, dy)
dx = CUBLAS.gemv(elty <: Real ? 'T' : 'C', alpha, dA, dy)
hx = collect(dx)
@test hx alpha * A' * y
end
Expand Down Expand Up @@ -158,11 +158,11 @@ end
for i=1:length(A)
push!(dy, CuArray(y[i]))
end
CUBLAS.gemv_batched!('T', alpha, dA, dy, beta, dx)
CUBLAS.gemv_batched!(elty <: Real ? 'T' : 'C', alpha, dA, dy, beta, dx)
for i=1:size(A, 3)
hx = collect(dx[:, i])
x[:, i] = alpha * transpose(A[:, :, i]) * y[:, i] + beta * y[:, i]
@test x[:, i] hx
hx = collect(dx[i])
x[i] = alpha * A[i]' * y[i] + beta * x[i]
@test x[i] hx
end
end
end
Expand All @@ -188,10 +188,10 @@ end
@test y[:, i] hy
end
dy = CuArray(y)
CUBLAS.gemv_strided_batched!('T', alpha, dA, dy, beta, dx)
CUBLAS.gemv_strided_batched!(elty <: Real ? 'T' : 'C', alpha, dA, dy, beta, dx)
for i=1:size(A, 3)
hx = collect(dx[:, i])
x[:, i] = alpha * transpose(A[:, :, i]) * y[:, i] + beta * y[:, i]
x[:, i] = alpha * A[:, :, i]' * y[:, i] + beta * x[:, i]
@test x[:, i] hx
end
end
Expand Down

0 comments on commit 8c1e036

Please sign in to comment.