diff --git a/lib/cublas/wrappers.jl b/lib/cublas/wrappers.jl index a64c3b3901..c09a456000 100644 --- a/lib/cublas/wrappers.jl +++ b/lib/cublas/wrappers.jl @@ -416,15 +416,16 @@ for (fname, fname_64, eltyin, eltyout) in ( if length(A) != length(x) || length(A) != length(y) throw(DimensionMismatch("Lengths of inputs must be the same")) end + m = size(A[1], 1) + n = size(A[1], 2) for (i, (As,xs,ys)) in enumerate(zip(A,x,y)) - m,n = size(As) + if size(As) != (m, n) + throw(DimensionMismatch("A[$i] has different dimension from A[1]. Dimensions between A's should be identical.")) + end if length(xs) != (trans == 'N' ? n : m) || length(ys) != (trans == 'N' ? m : n) throw(DimensionMismatch("Input $i: A has dimension $(size(As)), x has dimension $(size(xs)), y has dimension $(size(ys))")) end end - - m = size(A[1], 1) - n = size(A[1], 2) lda = max(1,stride(A[1],2)) incx = stride(x[1],1) incy = stride(y[1],1)