@@ -293,42 +293,39 @@ Base.literal_pow(::typeof(^), D::Diagonal, valp::Val) =
293293 Diagonal (Base. literal_pow .(^ , D. diag, valp)) # for speed
294294Base. literal_pow (:: typeof (^ ), D:: Diagonal , :: Val{-1} ) = inv (D) # for disambiguation
295295
296- function _muldiag_size_check (A, B )
297- nA = size (A, 2 )
298- mB = size (B, 1 )
299- @noinline throw_dimerr (:: AbstractMatrix , nA, mB) = throw (DimensionMismatch (lazy " second dimension of A, $nA, does not match first dimension of B, $mB" ))
300- @noinline throw_dimerr (:: AbstractVector , nA, mB) = throw (DimensionMismatch (lazy " second dimension of D, $nA, does not match length of V, $mB" ))
301- nA == mB || throw_dimerr (B , nA, mB)
296+ function _muldiag_size_check (szA :: NTuple{2,Integer} , szB :: Tuple{Integer,Vararg{Integer}} )
297+ nA = szA[ 2 ]
298+ mB = szB[ 1 ]
299+ @noinline throw_dimerr (szB :: NTuple{2} , nA, mB) = throw (DimensionMismatch (lazy " second dimension of A, $nA, does not match first dimension of B, $mB" ))
300+ @noinline throw_dimerr (szB :: NTuple{1} , nA, mB) = throw (DimensionMismatch (lazy " second dimension of D, $nA, does not match length of V, $mB" ))
301+ nA == mB || throw_dimerr (szB , nA, mB)
302302 return nothing
303303end
304304# the output matrix should have the same size as the non-diagonal input matrix or vector
305305@noinline throw_dimerr (szC, szA) = throw (DimensionMismatch (lazy " output matrix has size: $szC, but should have size $szA" ))
306- _size_check_out (C, :: Diagonal , A) = _size_check_out (C, A)
307- _size_check_out (C, A, :: Diagonal ) = _size_check_out (C, A)
308- _size_check_out (C, A:: Diagonal , :: Diagonal ) = _size_check_out (C, A)
309- function _size_check_out (C, A)
310- szA = size (A)
311- szC = size (C)
312- szA == szC || throw_dimerr (szC, szA)
313- return nothing
306+ function _size_check_out (szC:: NTuple{2} , szA:: NTuple{2} , szB:: NTuple{2} )
307+ (szC[1 ] == szA[1 ] && szC[2 ] == szB[2 ]) || throw_dimerr (szC, (szA[1 ], szB[2 ]))
308+ end
309+ function _size_check_out (szC:: NTuple{1} , szA:: NTuple{2} , szB:: NTuple{1} )
310+ szC[1 ] == szA[1 ] || throw_dimerr (szC, (szA[1 ],))
314311end
315- function _muldiag_size_check (C, A, B )
316- _muldiag_size_check (A, B )
317- _size_check_out (C, A, B )
312+ function _muldiag_size_check (szC :: Tuple{Vararg{Integer}} , szA :: Tuple{Vararg{Integer}} , szB :: Tuple{Vararg{Integer}} )
313+ _muldiag_size_check (szA, szB )
314+ _size_check_out (szC, szA, szB )
318315end
319316
320317function (* )(Da:: Diagonal , Db:: Diagonal )
321- _muldiag_size_check (Da, Db )
318+ _muldiag_size_check (size (Da), size (Db) )
322319 return Diagonal (Da. diag .* Db. diag)
323320end
324321
325322function (* )(D:: Diagonal , V:: AbstractVector )
326- _muldiag_size_check (D, V )
323+ _muldiag_size_check (size (D), size (V) )
327324 return D. diag .* V
328325end
329326
330327function rmul! (A:: AbstractMatrix , D:: Diagonal )
331- _muldiag_size_check (A, D )
328+ _muldiag_size_check (size (A), size (D) )
332329 for I in CartesianIndices (A)
333330 row, col = Tuple (I)
334331 @inbounds A[row, col] *= D. diag[col]
@@ -337,7 +334,7 @@ function rmul!(A::AbstractMatrix, D::Diagonal)
337334end
338335# T .= T * D
339336function rmul! (T:: Tridiagonal , D:: Diagonal )
340- _muldiag_size_check (T, D )
337+ _muldiag_size_check (size (T), size (D) )
341338 (; dl, d, du) = T
342339 d[1 ] *= D. diag[1 ]
343340 for i in axes (dl,1 )
@@ -349,7 +346,7 @@ function rmul!(T::Tridiagonal, D::Diagonal)
349346end
350347
351348function lmul! (D:: Diagonal , B:: AbstractVecOrMat )
352- _muldiag_size_check (D, B )
349+ _muldiag_size_check (size (D), size (B) )
353350 for I in CartesianIndices (B)
354351 row = I[1 ]
355352 @inbounds B[I] = D. diag[row] * B[I]
360357# in-place multiplication with a diagonal
361358# T .= D * T
362359function lmul! (D:: Diagonal , T:: Tridiagonal )
363- _muldiag_size_check (D, T )
360+ _muldiag_size_check (size (D), size (T) )
364361 (; dl, d, du) = T
365362 d[1 ] = D. diag[1 ] * d[1 ]
366363 for i in axes (dl,1 )
@@ -452,7 +449,7 @@ function __muldiag!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul{ais1,bis0})
452449end
453450
454451function _mul_diag! (out, A, B, _add)
455- _muldiag_size_check (out, A, B )
452+ _muldiag_size_check (size ( out), size (A), size (B) )
456453 __muldiag! (out, A, B, _add)
457454 return out
458455end
@@ -469,14 +466,14 @@ _mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, _add) =
469466 _mul_diag! (C, Da, Db, _add)
470467
471468function (* )(Da:: Diagonal , A:: AbstractMatrix , Db:: Diagonal )
472- _muldiag_size_check (Da, A )
473- _muldiag_size_check (A, Db )
469+ _muldiag_size_check (size (Da), size (A) )
470+ _muldiag_size_check (size (A), size (Db) )
474471 return broadcast (* , Da. diag, A, permutedims (Db. diag))
475472end
476473
477474function (* )(Da:: Diagonal , Db:: Diagonal , Dc:: Diagonal )
478- _muldiag_size_check (Da, Db )
479- _muldiag_size_check (Db, Dc )
475+ _muldiag_size_check (size (Da), size (Db) )
476+ _muldiag_size_check (size (Db), size (Dc) )
480477 return Diagonal (Da. diag .* Db. diag .* Dc. diag)
481478end
482479
0 commit comments