Skip to content

Commit 4be9339

Browse files
KlausCandreasnoack
authored andcommitted
proper diagonal in copytri! (fix #30055) (#30066)
* proper diagonal in copytri! (fix #30055) * added sprandn methods with Type * additional parameter in copytri! for diagonal * @inline copytri! to enforce constant propagation
1 parent 34f7a4a commit 4be9339

File tree

4 files changed

+36
-13
lines changed

4 files changed

+36
-13
lines changed

stdlib/LinearAlgebra/src/matmul.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -327,14 +327,16 @@ function mul!(C::AbstractMatrix, adjA::Adjoint{<:Any,<:AbstractVecOrMat}, transB
327327
end
328328
# Supporting functions for matrix multiplication
329329

330-
function copytri!(A::AbstractMatrix, uplo::AbstractChar, conjugate::Bool=false)
330+
# copy transposed(adjoint) of upper(lower) side-digonals. Optionally include diagonal.
331+
@inline function copytri!(A::AbstractMatrix, uplo::AbstractChar, conjugate::Bool=false, diag::Bool=false)
331332
n = checksquare(A)
333+
off = diag ? 0 : 1
332334
if uplo == 'U'
333-
for i = 1:(n-1), j = (i+1):n
335+
for i = 1:n, j = (i+off):n
334336
A[j,i] = conjugate ? adjoint(A[i,j]) : transpose(A[i,j])
335337
end
336338
elseif uplo == 'L'
337-
for i = 1:(n-1), j = (i+1):n
339+
for i = 1:n, j = (i+off):n
338340
A[i,j] = conjugate ? adjoint(A[j,i]) : transpose(A[j,i])
339341
end
340342
else

stdlib/LinearAlgebra/src/symmetric.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,14 +237,14 @@ similar(A::Union{Symmetric,Hermitian}, ::Type{T}, dims::Dims{N}) where {T,N} = s
237237
function Matrix(A::Symmetric)
238238
B = copytri!(convert(Matrix, copy(A.data)), A.uplo)
239239
for i = 1:size(A, 1)
240-
B[i,i] = symmetric(B[i,i], sym_uplo(A.uplo))::symmetric_type(eltype(A.data))
240+
B[i,i] = symmetric(A[i,i], sym_uplo(A.uplo))::symmetric_type(eltype(A.data))
241241
end
242242
return B
243243
end
244244
function Matrix(A::Hermitian)
245245
B = copytri!(convert(Matrix, copy(A.data)), A.uplo, true)
246246
for i = 1:size(A, 1)
247-
B[i,i] = hermitian(B[i,i], sym_uplo(A.uplo))::hermitian_type(eltype(A.data))
247+
B[i,i] = hermitian(A[i,i], sym_uplo(A.uplo))::hermitian_type(eltype(A.data))
248248
end
249249
return B
250250
end

stdlib/LinearAlgebra/src/triangular.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -346,14 +346,14 @@ Base.copy(A::Transpose{<:Any,<:UpperTriangular}) = transpose!(copy(A.parent))
346346
Base.copy(A::Transpose{<:Any,<:UnitLowerTriangular}) = transpose!(copy(A.parent))
347347
Base.copy(A::Transpose{<:Any,<:UnitUpperTriangular}) = transpose!(copy(A.parent))
348348

349-
transpose!(A::LowerTriangular) = UpperTriangular(copytri!(A.data, 'L'))
350-
transpose!(A::UnitLowerTriangular) = UnitUpperTriangular(copytri!(A.data, 'L'))
351-
transpose!(A::UpperTriangular) = LowerTriangular(copytri!(A.data, 'U'))
352-
transpose!(A::UnitUpperTriangular) = UnitLowerTriangular(copytri!(A.data, 'U'))
353-
adjoint!(A::LowerTriangular) = UpperTriangular(copytri!(A.data, 'L' , true))
354-
adjoint!(A::UnitLowerTriangular) = UnitUpperTriangular(copytri!(A.data, 'L' , true))
355-
adjoint!(A::UpperTriangular) = LowerTriangular(copytri!(A.data, 'U' , true))
356-
adjoint!(A::UnitUpperTriangular) = UnitLowerTriangular(copytri!(A.data, 'U' , true))
349+
transpose!(A::LowerTriangular) = UpperTriangular(copytri!(A.data, 'L', false, true))
350+
transpose!(A::UnitLowerTriangular) = UnitUpperTriangular(copytri!(A.data, 'L', false, true))
351+
transpose!(A::UpperTriangular) = LowerTriangular(copytri!(A.data, 'U', false, true))
352+
transpose!(A::UnitUpperTriangular) = UnitLowerTriangular(copytri!(A.data, 'U', false, true))
353+
adjoint!(A::LowerTriangular) = UpperTriangular(copytri!(A.data, 'L' , true, true))
354+
adjoint!(A::UnitLowerTriangular) = UnitUpperTriangular(copytri!(A.data, 'L' , true, true))
355+
adjoint!(A::UpperTriangular) = LowerTriangular(copytri!(A.data, 'U' , true, true))
356+
adjoint!(A::UnitUpperTriangular) = UnitLowerTriangular(copytri!(A.data, 'U' , true, true))
357357

358358
diag(A::LowerTriangular) = diag(A.data)
359359
diag(A::UnitLowerTriangular) = fill(one(eltype(A)), size(A,1))

stdlib/LinearAlgebra/test/matmul.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,27 @@ end
301301

302302
@test_throws ArgumentError LinearAlgebra.copytri!(Matrix{Float64}(undef,10,10),'Z')
303303

304+
@testset "Issue 30055" begin
305+
B = [1+im 2+im 3+im; 4+im 5+im 6+im; 7+im 9+im im]
306+
A = UpperTriangular(B)
307+
@test copy(transpose(A)) == transpose(A)
308+
@test copy(A') == A'
309+
A = LowerTriangular(B)
310+
@test copy(transpose(A)) == transpose(A)
311+
@test copy(A') == A'
312+
B = Matrix{Matrix{Complex{Int}}}(undef, 2, 2)
313+
B[1,1] = [1+im 2+im; 3+im 4+im]
314+
B[2,1] = [1+2im 1+3im;1+3im 1+4im]
315+
B[1,2] = [7+im 8+2im; 9+3im 4im]
316+
B[2,2] = [9+im 8+im; 7+im 6+im]
317+
A = UpperTriangular(B)
318+
@test copy(transpose(A)) == transpose(A)
319+
@test copy(A') == A'
320+
A = LowerTriangular(B)
321+
@test copy(transpose(A)) == transpose(A)
322+
@test copy(A') == A'
323+
end
324+
304325
@testset "gemv! and gemm_wrapper for $elty" for elty in [Float32,Float64,ComplexF64,ComplexF32]
305326
A10x10, x10, x11 = Array{elty}.(undef, ((10,10), 10, 11))
306327
@test_throws DimensionMismatch LinearAlgebra.gemv!(x10,'N',A10x10,x11)

0 commit comments

Comments
 (0)