Skip to content

Commit 21b640b

Browse files
dkarraschKristofferC
authored andcommitted
minor fixes in multiplication with Diagonals (#31443)
* minor fixes in multiplication with Diagonals * correct rmul!(A,D), revert changes in AdjTrans(x)*D * [r/l]mul!: replace conj by adjoint, add transpose * add tests * fix typo * relax some tests, added more tests * simplify tests, strict equality (cherry picked from commit a93185f)
1 parent bf854e1 commit 21b640b

File tree

2 files changed

+27
-12
lines changed

2 files changed

+27
-12
lines changed

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ end
172172

173173
function rmul!(A::AbstractMatrix, D::Diagonal)
174174
@assert !has_offset_axes(A)
175-
A .= A .* transpose(D.diag)
175+
A .= A .* permutedims(D.diag)
176176
return A
177177
end
178178

@@ -260,20 +260,20 @@ lmul!(A::Diagonal, B::Diagonal) = Diagonal(B.diag .= A.diag .* B.diag)
260260

261261
function lmul!(adjA::Adjoint{<:Any,<:Diagonal}, B::AbstractMatrix)
262262
A = adjA.parent
263-
return lmul!(conj(A.diag), B)
263+
return lmul!(adjoint(A), B)
264264
end
265265
function lmul!(transA::Transpose{<:Any,<:Diagonal}, B::AbstractMatrix)
266266
A = transA.parent
267-
return lmul!(A.diag, B)
267+
return lmul!(transpose(A), B)
268268
end
269269

270270
function rmul!(A::AbstractMatrix, adjB::Adjoint{<:Any,<:Diagonal})
271271
B = adjB.parent
272-
return rmul!(A, conj(B.diag))
272+
return rmul!(A, adjoint(B))
273273
end
274274
function rmul!(A::AbstractMatrix, transB::Transpose{<:Any,<:Diagonal})
275275
B = transB.parent
276-
return rmul!(A, B.diag)
276+
return rmul!(A, transpose(B))
277277
end
278278

279279
# Get ambiguous method if try to unify AbstractVector/AbstractMatrix here using AbstractVecOrMat
@@ -552,10 +552,9 @@ end
552552
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal) = Adjoint(map((t,s) -> t'*s, D.diag, parent(x)))
553553
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) =
554554
mapreduce(t -> t[1]*t[2]*t[3], +, zip(x, D.diag, y))
555-
*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal) = Transpose(map(*, D.diag, parent(x)))
555+
*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal) = Transpose(map((t,s) -> transpose(t)*s, D.diag, parent(x)))
556556
*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) =
557557
mapreduce(t -> t[1]*t[2]*t[3], +, zip(x, D.diag, y))
558-
# TODO: these methods will yield row matrices, rather than adjoint/transpose vectors
559558

560559
function cholesky!(A::Diagonal, ::Val{false} = Val(false); check::Bool = true)
561560
info = 0

stdlib/LinearAlgebra/test/diagonal.jl

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -461,10 +461,20 @@ end
461461
fullBB = copyto!(Matrix{Matrix{T}}(undef, 2, 2), BB)
462462
for (transform1, transform2) in ((identity, identity),
463463
(identity, adjoint ), (adjoint, identity ), (adjoint, adjoint ),
464-
(identity, transpose), (transpose, identity ), (transpose, transpose) )
464+
(identity, transpose), (transpose, identity ), (transpose, transpose),
465+
(identity, Adjoint ), (Adjoint, identity ), (Adjoint, Adjoint ),
466+
(identity, Transpose), (Transpose, identity ), (Transpose, Transpose))
465467
@test *(transform1(D), transform2(B))::typeof(D) *(transform1(Matrix(D)), transform2(Matrix(B))) atol=2 * eps()
466468
@test *(transform1(DD), transform2(BB))::typeof(DD) == *(transform1(fullDD), transform2(fullBB))
467469
end
470+
M = randn(T, 5, 5)
471+
MM = [randn(T, 2, 2) for _ in 1:2, _ in 1:2]
472+
for transform in (identity, adjoint, transpose, Adjoint, Transpose)
473+
@test lmul!(transform(D), copy(M)) == *(transform(Matrix(D)), M)
474+
@test rmul!(copy(M), transform(D)) == *(M, transform(Matrix(D)))
475+
@test lmul!(transform(DD), copy(MM)) == *(transform(fullDD), MM)
476+
@test rmul!(copy(MM), transform(DD)) == *(MM, transform(fullDD))
477+
end
468478
end
469479
end
470480

@@ -474,10 +484,16 @@ end
474484
end
475485

476486
@testset "Multiplication with Adjoint and Transpose vectors (#26863)" begin
477-
x = rand(5)
478-
D = Diagonal(rand(5))
479-
@test x'*D*x == (x'*D)*x == (x'*Array(D))*x
480-
@test Transpose(x)*D*x == (Transpose(x)*D)*x == (Transpose(x)*Array(D))*x
487+
x = collect(1:2)
488+
xt = transpose(x)
489+
A = reshape([[1 2; 3 4], zeros(Int,2,2), zeros(Int, 2, 2), [5 6; 7 8]], 2, 2)
490+
D = Diagonal(A)
491+
@test x'*D == x'*A == copy(x')*D == copy(x')*A
492+
@test xt*D == xt*A == copy(xt)*D == copy(xt)*A
493+
y = [x, x]
494+
yt = transpose(y)
495+
@test y'*D*y == (y'*D)*y == (y'*A)*y
496+
@test yt*D*y == (yt*D)*y == (yt*A)*y
481497
end
482498

483499
@testset "Triangular division by Diagonal #27989" begin

0 commit comments

Comments
 (0)