Skip to content

Commit 0c620d8

Browse files
Fix performance issue with diagonal multiplication
Co-authored-by: Dilum Aluthge <dilum@aluthge.com>
1 parent 1e64682 commit 0c620d8

File tree

2 files changed

+75
-39
lines changed

2 files changed

+75
-39
lines changed

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 74 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -276,38 +276,91 @@ function *(D::Diagonal, transA::Transpose{<:Any,<:AbstractMatrix})
276276
lmul!(D, At)
277277
end
278278

279-
@inline function __muldiag!(out, D::Diagonal, B, alpha, beta)
280-
if iszero(beta)
281-
out .= (D.diag .* B) .*ₛ alpha
279+
# in __muldiag! below we unroll the loops manually, since broadcasting may be unable to
280+
# prove that they are vectorizable
281+
function __muldiag!(out, D::Diagonal, B, alpha, beta)
282+
# TODO: check if this code can be replaced by a single line
283+
# out .= (D.diag .* B) .*ₛ alpha .+ out .*ₛ beta
284+
require_one_based_indexing(out)
285+
if iszero(alpha)
286+
_rmul_or_fill!(out, beta)
282287
else
283-
out .= (D.diag .* B) .*ₛ alpha .+ out .* beta
288+
if iszero(beta)
289+
@inbounds for j in axes(B, 2)
290+
@simd for i in axes(B, 1)
291+
out[i,j] = D.diag[i] * B[i,j] * alpha
292+
end
293+
end
294+
else
295+
@inbounds for j in axes(B, 2)
296+
@simd for i in axes(B, 1)
297+
out[i,j] = D.diag[i] * B[i,j] * alpha + out[i,j] * beta
298+
end
299+
end
300+
end
284301
end
285302
return out
286303
end
287-
288-
@inline function __muldiag!(out, A, D::Diagonal, alpha, beta)
289-
if iszero(beta)
290-
out .= (A .* permutedims(D.diag)) .*ₛ alpha
304+
function __muldiag!(out, A, D::Diagonal, alpha, beta)
305+
# TODO: check if this code can be replaced by a single line
306+
# out .= (B .* permutedims(D.diag)) .*ₛ alpha .+ out .*ₛ beta
307+
require_one_based_indexing(out)
308+
if iszero(alpha)
309+
_rmul_or_fill!(out, beta)
291310
else
292-
out .= (A .* permutedims(D.diag)) .*ₛ alpha .+ out .* beta
311+
if iszero(beta)
312+
@inbounds for j in axes(A, 2)
313+
dja = D.diag[j] * alpha
314+
@simd for i in axes(A, 1)
315+
out[i,j] = A[i,j] * dja
316+
end
317+
end
318+
else
319+
@inbounds for j in axes(A, 2)
320+
dja = D.diag[j] * alpha
321+
@simd for i in axes(A, 1)
322+
out[i,j] = A[i,j] * dja + out[i,j] * beta
323+
end
324+
end
325+
end
293326
end
294327
return out
295328
end
296-
297-
@inline function __muldiag!(out::Diagonal, D1::Diagonal, D2::Diagonal, alpha, beta)
298-
if iszero(beta)
299-
out.diag .= (D1.diag .* D2.diag) .*ₛ alpha
329+
function __muldiag!(out::Diagonal, D1::Diagonal, D2::Diagonal, alpha, beta)
330+
# TODO: check if this code can be replaced by a single line
331+
# out.diag .= (D1.diag .* D2.diag) .*ₛ alpha .+ out.diag .*ₛ beta
332+
d1 = D1.diag
333+
d2 = D2.diag
334+
if iszero(alpha)
335+
_rmul_or_fill!(out.diag, beta)
300336
else
301-
out.diag .= (D1.diag .* D2.diag) .*ₛ alpha .+ out.diag .* beta
337+
if iszero(beta)
338+
@inbounds @simd for i in eachindex(out.diag)
339+
out.diag[i] = d1[i] * d2[i] * alpha
340+
end
341+
else
342+
@inbounds @simd for i in eachindex(out.diag)
343+
out.diag[i] = d1[i] * d2[i] * alpha + out.diag[i] * beta
344+
end
345+
end
346+
end
347+
return out
348+
end
349+
function __muldiag!(out, D1::Diagonal, D2::Diagonal, alpha, beta)
350+
require_one_based_indexing(out)
351+
mA = size(D1, 1)
352+
d1 = D1.diag
353+
d2 = D2.diag
354+
_rmul_or_fill!(out, beta)
355+
if !iszero(alpha)
356+
@inbounds @simd for i in 1:mA
357+
out[i,i] += d1[i] * d2[i] * alpha
358+
end
302359
end
303360
return out
304361
end
305362

306-
# only needed for ambiguity resolution, as mul! is explicitly defined for these arguments
307-
@inline __muldiag!(out, D1::Diagonal, D2::Diagonal, alpha, beta) =
308-
mul!(out, D1, D2, alpha, beta)
309-
310-
@inline function _muldiag!(out, A, B, alpha, beta)
363+
function _muldiag!(out, A, B, alpha, beta)
311364
_muldiag_size_check(out, A, B)
312365
__muldiag!(out, A, B, alpha, beta)
313366
return out
@@ -332,24 +385,8 @@ end
332385
@inline mul!(C::Diagonal, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number) =
333386
_muldiag!(C, Da, Db, alpha, beta)
334387

335-
function mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number)
336-
_muldiag_size_check(C, Da, Db)
337-
require_one_based_indexing(C)
338-
mA = size(Da, 1)
339-
da = Da.diag
340-
db = Db.diag
341-
_rmul_or_fill!(C, beta)
342-
if iszero(beta)
343-
@inbounds @simd for i in 1:mA
344-
C[i,i] = Ref(da[i] * db[i]) .*ₛ alpha
345-
end
346-
else
347-
@inbounds @simd for i in 1:mA
348-
C[i,i] += Ref(da[i] * db[i]) .*ₛ alpha
349-
end
350-
end
351-
return C
352-
end
388+
mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number) =
389+
_muldiag!(C, Da, Db, alpha, beta)
353390

354391
_init(op, A::AbstractArray{<:Number}, B::AbstractArray{<:Number}) =
355392
(_ -> zero(typeof(op(oneunit(eltype(A)), oneunit(eltype(B))))))

stdlib/LinearAlgebra/src/generic.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
# inside this function.
99
function *end
1010
Broadcast.broadcasted(::typeof(*ₛ), out, beta) =
11-
iszero(beta::Number) ? false :
12-
isone(beta::Number) ? broadcasted(identity, out) : broadcasted(*, out, beta)
11+
iszero(beta::Number) ? false : broadcasted(*, out, beta)
1312

1413
"""
1514
MulAddMul(alpha, beta)

0 commit comments

Comments
 (0)