Skip to content

Commit 40ecf69

Browse files
authored
LinearAlgbera: pass sizes to muldiag_size_check (#55378)
This will avoid having to specialize `_muldiag_size_check` on the matrix types, as we only need the sizes (and potentially `ndims`) for the error checks.
1 parent 996351f commit 40ecf69

File tree

2 files changed

+29
-32
lines changed

2 files changed

+29
-32
lines changed

stdlib/LinearAlgebra/src/bidiag.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,7 @@ const BiTri = Union{Bidiagonal,Tridiagonal}
461461

462462
# B .= A * B
463463
function lmul!(A::Bidiagonal, B::AbstractVecOrMat)
464-
_muldiag_size_check(A, B)
464+
_muldiag_size_check(size(A), size(B))
465465
(; dv, ev) = A
466466
if A.uplo == 'U'
467467
for k in axes(B,2)
@@ -482,7 +482,7 @@ function lmul!(A::Bidiagonal, B::AbstractVecOrMat)
482482
end
483483
# B .= D * B
484484
function lmul!(D::Diagonal, B::Bidiagonal)
485-
_muldiag_size_check(D, B)
485+
_muldiag_size_check(size(D), size(B))
486486
(; dv, ev) = B
487487
isL = B.uplo == 'L'
488488
dv[1] = D.diag[1] * dv[1]
@@ -494,7 +494,7 @@ function lmul!(D::Diagonal, B::Bidiagonal)
494494
end
495495
# B .= B * A
496496
function rmul!(B::AbstractMatrix, A::Bidiagonal)
497-
_muldiag_size_check(A, B)
497+
_muldiag_size_check(size(A), size(B))
498498
(; dv, ev) = A
499499
if A.uplo == 'U'
500500
for k in reverse(axes(dv,1)[2:end])
@@ -519,7 +519,7 @@ function rmul!(B::AbstractMatrix, A::Bidiagonal)
519519
end
520520
# B .= B * D
521521
function rmul!(B::Bidiagonal, D::Diagonal)
522-
_muldiag_size_check(B, D)
522+
_muldiag_size_check(size(B), size(D))
523523
(; dv, ev) = B
524524
isU = B.uplo == 'U'
525525
dv[1] *= D.diag[1]

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -293,42 +293,39 @@ Base.literal_pow(::typeof(^), D::Diagonal, valp::Val) =
293293
Diagonal(Base.literal_pow.(^, D.diag, valp)) # for speed
294294
Base.literal_pow(::typeof(^), D::Diagonal, ::Val{-1}) = inv(D) # for disambiguation
295295

296-
function _muldiag_size_check(A, B)
297-
nA = size(A, 2)
298-
mB = size(B, 1)
299-
@noinline throw_dimerr(::AbstractMatrix, nA, mB) = throw(DimensionMismatch(lazy"second dimension of A, $nA, does not match first dimension of B, $mB"))
300-
@noinline throw_dimerr(::AbstractVector, nA, mB) = throw(DimensionMismatch(lazy"second dimension of D, $nA, does not match length of V, $mB"))
301-
nA == mB || throw_dimerr(B, nA, mB)
296+
function _muldiag_size_check(szA::NTuple{2,Integer}, szB::Tuple{Integer,Vararg{Integer}})
297+
nA = szA[2]
298+
mB = szB[1]
299+
@noinline throw_dimerr(szB::NTuple{2}, nA, mB) = throw(DimensionMismatch(lazy"second dimension of A, $nA, does not match first dimension of B, $mB"))
300+
@noinline throw_dimerr(szB::NTuple{1}, nA, mB) = throw(DimensionMismatch(lazy"second dimension of D, $nA, does not match length of V, $mB"))
301+
nA == mB || throw_dimerr(szB, nA, mB)
302302
return nothing
303303
end
304304
# the output matrix should have the same size as the non-diagonal input matrix or vector
305305
@noinline throw_dimerr(szC, szA) = throw(DimensionMismatch(lazy"output matrix has size: $szC, but should have size $szA"))
306-
_size_check_out(C, ::Diagonal, A) = _size_check_out(C, A)
307-
_size_check_out(C, A, ::Diagonal) = _size_check_out(C, A)
308-
_size_check_out(C, A::Diagonal, ::Diagonal) = _size_check_out(C, A)
309-
function _size_check_out(C, A)
310-
szA = size(A)
311-
szC = size(C)
312-
szA == szC || throw_dimerr(szC, szA)
313-
return nothing
306+
function _size_check_out(szC::NTuple{2}, szA::NTuple{2}, szB::NTuple{2})
307+
(szC[1] == szA[1] && szC[2] == szB[2]) || throw_dimerr(szC, (szA[1], szB[2]))
308+
end
309+
function _size_check_out(szC::NTuple{1}, szA::NTuple{2}, szB::NTuple{1})
310+
szC[1] == szA[1] || throw_dimerr(szC, (szA[1],))
314311
end
315-
function _muldiag_size_check(C, A, B)
316-
_muldiag_size_check(A, B)
317-
_size_check_out(C, A, B)
312+
function _muldiag_size_check(szC::Tuple{Vararg{Integer}}, szA::Tuple{Vararg{Integer}}, szB::Tuple{Vararg{Integer}})
313+
_muldiag_size_check(szA, szB)
314+
_size_check_out(szC, szA, szB)
318315
end
319316

320317
function (*)(Da::Diagonal, Db::Diagonal)
321-
_muldiag_size_check(Da, Db)
318+
_muldiag_size_check(size(Da), size(Db))
322319
return Diagonal(Da.diag .* Db.diag)
323320
end
324321

325322
function (*)(D::Diagonal, V::AbstractVector)
326-
_muldiag_size_check(D, V)
323+
_muldiag_size_check(size(D), size(V))
327324
return D.diag .* V
328325
end
329326

330327
function rmul!(A::AbstractMatrix, D::Diagonal)
331-
_muldiag_size_check(A, D)
328+
_muldiag_size_check(size(A), size(D))
332329
for I in CartesianIndices(A)
333330
row, col = Tuple(I)
334331
@inbounds A[row, col] *= D.diag[col]
@@ -337,7 +334,7 @@ function rmul!(A::AbstractMatrix, D::Diagonal)
337334
end
338335
# T .= T * D
339336
function rmul!(T::Tridiagonal, D::Diagonal)
340-
_muldiag_size_check(T, D)
337+
_muldiag_size_check(size(T), size(D))
341338
(; dl, d, du) = T
342339
d[1] *= D.diag[1]
343340
for i in axes(dl,1)
@@ -349,7 +346,7 @@ function rmul!(T::Tridiagonal, D::Diagonal)
349346
end
350347

351348
function lmul!(D::Diagonal, B::AbstractVecOrMat)
352-
_muldiag_size_check(D, B)
349+
_muldiag_size_check(size(D), size(B))
353350
for I in CartesianIndices(B)
354351
row = I[1]
355352
@inbounds B[I] = D.diag[row] * B[I]
@@ -360,7 +357,7 @@ end
360357
# in-place multiplication with a diagonal
361358
# T .= D * T
362359
function lmul!(D::Diagonal, T::Tridiagonal)
363-
_muldiag_size_check(D, T)
360+
_muldiag_size_check(size(D), size(T))
364361
(; dl, d, du) = T
365362
d[1] = D.diag[1] * d[1]
366363
for i in axes(dl,1)
@@ -452,7 +449,7 @@ function __muldiag!(out, D1::Diagonal, D2::Diagonal, _add::MulAddMul{ais1,bis0})
452449
end
453450

454451
function _mul_diag!(out, A, B, _add)
455-
_muldiag_size_check(out, A, B)
452+
_muldiag_size_check(size(out), size(A), size(B))
456453
__muldiag!(out, A, B, _add)
457454
return out
458455
end
@@ -469,14 +466,14 @@ _mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, _add) =
469466
_mul_diag!(C, Da, Db, _add)
470467

471468
function (*)(Da::Diagonal, A::AbstractMatrix, Db::Diagonal)
472-
_muldiag_size_check(Da, A)
473-
_muldiag_size_check(A, Db)
469+
_muldiag_size_check(size(Da), size(A))
470+
_muldiag_size_check(size(A), size(Db))
474471
return broadcast(*, Da.diag, A, permutedims(Db.diag))
475472
end
476473

477474
function (*)(Da::Diagonal, Db::Diagonal, Dc::Diagonal)
478-
_muldiag_size_check(Da, Db)
479-
_muldiag_size_check(Db, Dc)
475+
_muldiag_size_check(size(Da), size(Db))
476+
_muldiag_size_check(size(Db), size(Dc))
480477
return Diagonal(Da.diag .* Db.diag .* Dc.diag)
481478
end
482479

0 commit comments

Comments
 (0)