Skip to content

Commit

Permalink
more fix for gemv_batched!
Browse files Browse the repository at this point in the history
all the input dimensions should be identical for gemv_batched!
  • Loading branch information
kose-y authored Aug 28, 2024
1 parent 39c75df commit f5b810f
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions lib/cublas/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -416,15 +416,16 @@ for (fname, fname_64, eltyin, eltyout) in (
if length(A) != length(x) || length(A) != length(y)
throw(DimensionMismatch("Lengths of inputs must be the same"))
end
m = size(A[1], 1)
n = size(A[1], 2)
for (i, (As,xs,ys)) in enumerate(zip(A,x,y))
m,n = size(As)
if size(As) != (m, n)
throw(DimensionMismatch("A[$i] has different dimension from A[1]. Dimensions between A's should be identical."))
end
if length(xs) != (trans == 'N' ? n : m) || length(ys) != (trans == 'N' ? m : n)
throw(DimensionMismatch("Input $i: A has dimension $(size(As)), x has dimension $(size(xs)), y has dimension $(size(ys))"))
end
end

m = size(A[1], 1)
n = size(A[1], 2)
lda = max(1,stride(A[1],2))
incx = stride(x[1],1)
incy = stride(y[1],1)
Expand Down

0 comments on commit f5b810f

Please sign in to comment.