@@ -276,38 +276,91 @@ function *(D::Diagonal, transA::Transpose{<:Any,<:AbstractMatrix})
276
276
lmul! (D, At)
277
277
end
278
278
279
- @inline function __muldiag! (out, D:: Diagonal , B, alpha, beta)
280
- if iszero (beta)
281
- out .= (D. diag .* B) .* ₛ alpha
279
+ # in __muldiag! below we unroll the loops manually, since broadcasting may be unable to
280
+ # prove that they are vectorizable
281
+ function __muldiag! (out, D:: Diagonal , B, alpha, beta)
282
+ # TODO : check if this code can be replaced by a single line
283
+ # out .= (D.diag .* B) .*ₛ alpha .+ out .*ₛ beta
284
+ require_one_based_indexing (out)
285
+ if iszero (alpha)
286
+ _rmul_or_fill! (out, beta)
282
287
else
283
- out .= (D. diag .* B) .* ₛ alpha .+ out .* beta
288
+ if iszero (beta)
289
+ @inbounds for j in axes (B, 2 )
290
+ @simd for i in axes (B, 1 )
291
+ out[i,j] = D. diag[i] * B[i,j] * alpha
292
+ end
293
+ end
294
+ else
295
+ @inbounds for j in axes (B, 2 )
296
+ @simd for i in axes (B, 1 )
297
+ out[i,j] = D. diag[i] * B[i,j] * alpha + out[i,j] * beta
298
+ end
299
+ end
300
+ end
284
301
end
285
302
return out
286
303
end
287
-
288
- @inline function __muldiag! (out, A, D:: Diagonal , alpha, beta)
289
- if iszero (beta)
290
- out .= (A .* permutedims (D. diag)) .* ₛ alpha
304
+ function __muldiag! (out, A, D:: Diagonal , alpha, beta)
305
+ # TODO : check if this code can be replaced by a single line
306
+ # out .= (B .* permutedims(D.diag)) .*ₛ alpha .+ out .*ₛ beta
307
+ require_one_based_indexing (out)
308
+ if iszero (alpha)
309
+ _rmul_or_fill! (out, beta)
291
310
else
292
- out .= (A .* permutedims (D. diag)) .* ₛ alpha .+ out .* beta
311
+ if iszero (beta)
312
+ @inbounds for j in axes (A, 2 )
313
+ dja = D. diag[j] * alpha
314
+ @simd for i in axes (A, 1 )
315
+ out[i,j] = A[i,j] * dja
316
+ end
317
+ end
318
+ else
319
+ @inbounds for j in axes (A, 2 )
320
+ dja = D. diag[j] * alpha
321
+ @simd for i in axes (A, 1 )
322
+ out[i,j] = A[i,j] * dja + out[i,j] * beta
323
+ end
324
+ end
325
+ end
293
326
end
294
327
return out
295
328
end
296
-
297
- @inline function __muldiag! (out:: Diagonal , D1:: Diagonal , D2:: Diagonal , alpha, beta)
298
- if iszero (beta)
299
- out. diag .= (D1. diag .* D2. diag) .* ₛ alpha
329
+ function __muldiag! (out:: Diagonal , D1:: Diagonal , D2:: Diagonal , alpha, beta)
330
+ # TODO : check if this code can be replaced by a single line
331
+ # out.diag .= (D1.diag .* D2.diag) .*ₛ alpha .+ out.diag .*ₛ beta
332
+ d1 = D1. diag
333
+ d2 = D2. diag
334
+ if iszero (alpha)
335
+ _rmul_or_fill! (out. diag, beta)
300
336
else
301
- out. diag .= (D1. diag .* D2. diag) .* ₛ alpha .+ out. diag .* beta
337
+ if iszero (beta)
338
+ @inbounds @simd for i in eachindex (out. diag)
339
+ out. diag[i] = d1[i] * d2[i] * alpha
340
+ end
341
+ else
342
+ @inbounds @simd for i in eachindex (out. diag)
343
+ out. diag[i] = d1[i] * d2[i] * alpha + out. diag[i] * beta
344
+ end
345
+ end
346
+ end
347
+ return out
348
+ end
349
+ function __muldiag! (out, D1:: Diagonal , D2:: Diagonal , alpha, beta)
350
+ require_one_based_indexing (out)
351
+ mA = size (D1, 1 )
352
+ d1 = D1. diag
353
+ d2 = D2. diag
354
+ _rmul_or_fill! (out, beta)
355
+ if ! iszero (alpha)
356
+ @inbounds @simd for i in 1 : mA
357
+ out[i,i] += d1[i] * d2[i] * alpha
358
+ end
302
359
end
303
360
return out
304
361
end
305
362
306
- # only needed for ambiguity resolution, as mul! is explicitly defined for these arguments
307
- @inline __muldiag! (out, D1:: Diagonal , D2:: Diagonal , alpha, beta) =
308
- mul! (out, D1, D2, alpha, beta)
309
-
310
- @inline function _muldiag! (out, A, B, alpha, beta)
363
+ function _muldiag! (out, A, B, alpha, beta)
311
364
_muldiag_size_check (out, A, B)
312
365
__muldiag! (out, A, B, alpha, beta)
313
366
return out
332
385
@inline mul! (C:: Diagonal , Da:: Diagonal , Db:: Diagonal , alpha:: Number , beta:: Number ) =
333
386
_muldiag! (C, Da, Db, alpha, beta)
334
387
335
- function mul! (C:: AbstractMatrix , Da:: Diagonal , Db:: Diagonal , alpha:: Number , beta:: Number )
336
- _muldiag_size_check (C, Da, Db)
337
- require_one_based_indexing (C)
338
- mA = size (Da, 1 )
339
- da = Da. diag
340
- db = Db. diag
341
- _rmul_or_fill! (C, beta)
342
- if iszero (beta)
343
- @inbounds @simd for i in 1 : mA
344
- C[i,i] = Ref (da[i] * db[i]) .* ₛ alpha
345
- end
346
- else
347
- @inbounds @simd for i in 1 : mA
348
- C[i,i] += Ref (da[i] * db[i]) .* ₛ alpha
349
- end
350
- end
351
- return C
352
- end
388
+ mul! (C:: AbstractMatrix , Da:: Diagonal , Db:: Diagonal , alpha:: Number , beta:: Number ) =
389
+ _muldiag! (C, Da, Db, alpha, beta)
353
390
354
391
_init (op, A:: AbstractArray{<:Number} , B:: AbstractArray{<:Number} ) =
355
392
(_ -> zero (typeof (op (oneunit (eltype (A)), oneunit (eltype (B))))))
0 commit comments