Skip to content

Commit 2bbfc0e

Browse files
dkarraschlazarusA
authored andcommitted
Accomodate for rectangular matrices in copytrito! (JuliaLang#54587)
1 parent 65ae0e9 commit 2bbfc0e

File tree

4 files changed

+98
-19
lines changed

4 files changed

+98
-19
lines changed

stdlib/LinearAlgebra/src/generic.jl

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2014,19 +2014,24 @@ function copytrito!(B::AbstractMatrix, A::AbstractMatrix, uplo::AbstractChar)
20142014
BLAS.chkuplo(uplo)
20152015
m,n = size(A)
20162016
m1,n1 = size(B)
2017-
(m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least the same number of rows and columns than A of size ($m,$n)"))
20182017
A = Base.unalias(B, A)
20192018
if uplo == 'U'
2020-
for j=1:n
2021-
for i=1:min(j,m)
2022-
@inbounds B[i,j] = A[i,j]
2023-
end
2019+
if n < m
2020+
(m1 < n || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($n,$n)"))
2021+
else
2022+
(m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$n)"))
20242023
end
2025-
else # uplo == 'L'
2026-
for j=1:n
2027-
for i=j:m
2028-
@inbounds B[i,j] = A[i,j]
2029-
end
2024+
for j in 1:n, i in 1:min(j,m)
2025+
@inbounds B[i,j] = A[i,j]
2026+
end
2027+
else # uplo == 'L'
2028+
if m < n
2029+
(m1 < m || n1 < m) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$m)"))
2030+
else
2031+
(m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$n)"))
2032+
end
2033+
for j in 1:n, i in j:m
2034+
@inbounds B[i,j] = A[i,j]
20302035
end
20312036
end
20322037
return B

stdlib/LinearAlgebra/src/lapack.jl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7157,9 +7157,23 @@ for (fn, elty) in ((:dlacpy_, :Float64),
71577157
function lacpy!(B::AbstractMatrix{$elty}, A::AbstractMatrix{$elty}, uplo::AbstractChar)
71587158
require_one_based_indexing(A, B)
71597159
chkstride1(A, B)
7160-
m,n = size(A)
7161-
m1,n1 = size(B)
7162-
(m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least the same number of rows and columns than A of size ($m,$n)"))
7160+
m, n = size(A)
7161+
m1, n1 = size(B)
7162+
if uplo == 'U'
7163+
if n < m
7164+
(m1 < n || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($n,$n)"))
7165+
else
7166+
(m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$n)"))
7167+
end
7168+
elseif uplo == 'L'
7169+
if m < n
7170+
(m1 < m || n1 < m) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$m)"))
7171+
else
7172+
(m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$n)"))
7173+
end
7174+
else
7175+
(m1 < m || n1 < n) && throw(DimensionMismatch(lazy"B of size ($m1,$n1) should have at least size ($m,$n)"))
7176+
end
71637177
lda = max(1, stride(A, 2))
71647178
ldb = max(1, stride(B, 2))
71657179
ccall((@blasfunc($fn), libblastrampoline), Cvoid,

stdlib/LinearAlgebra/test/generic.jl

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -654,12 +654,54 @@ end
654654

655655
@testset "copytrito!" begin
656656
n = 10
657-
for A in (rand(n, n), rand(Int8, n, n)), uplo in ('L', 'U')
658-
for AA in (A, view(A, reverse.(axes(A))...))
659-
for B in (zeros(n, n), zeros(n+1, n+2))
660-
copytrito!(B, AA, uplo)
657+
@testset "square" begin
658+
for A in (rand(n, n), rand(Int8, n, n)), uplo in ('L', 'U')
659+
for AA in (A, view(A, reverse.(axes(A))...))
661660
C = uplo == 'L' ? tril(AA) : triu(AA)
662-
@test view(B, 1:n, 1:n) == C
661+
for B in (zeros(n, n), zeros(n+1, n+2))
662+
copytrito!(B, AA, uplo)
663+
@test view(B, 1:n, 1:n) == C
664+
end
665+
end
666+
end
667+
end
668+
@testset "wide" begin
669+
for A in (rand(n, 2n), rand(Int8, n, 2n))
670+
for AA in (A, view(A, reverse.(axes(A))...))
671+
C = tril(AA)
672+
for (M, N) in ((n, n), (n+1, n), (n, n+1), (n+1, n+1))
673+
B = zeros(M, N)
674+
copytrito!(B, AA, 'L')
675+
@test view(B, 1:n, 1:n) == view(C, 1:n, 1:n)
676+
end
677+
@test_throws DimensionMismatch copytrito!(zeros(n-1, 2n), AA, 'L')
678+
C = triu(AA)
679+
for (M, N) in ((n, 2n), (n+1, 2n), (n, 2n+1), (n+1, 2n+1))
680+
B = zeros(M, N)
681+
copytrito!(B, AA, 'U')
682+
@test view(B, 1:n, 1:2n) == view(C, 1:n, 1:2n)
683+
end
684+
@test_throws DimensionMismatch copytrito!(zeros(n+1, 2n-1), AA, 'U')
685+
end
686+
end
687+
end
688+
@testset "tall" begin
689+
for A in (rand(2n, n), rand(Int8, 2n, n))
690+
for AA in (A, view(A, reverse.(axes(A))...))
691+
C = triu(AA)
692+
for (M, N) in ((n, n), (n+1, n), (n, n+1), (n+1, n+1))
693+
B = zeros(M, N)
694+
copytrito!(B, AA, 'U')
695+
@test view(B, 1:n, 1:n) == view(C, 1:n, 1:n)
696+
end
697+
@test_throws DimensionMismatch copytrito!(zeros(n-1, n+1), AA, 'U')
698+
C = tril(AA)
699+
for (M, N) in ((2n, n), (2n, n+1), (2n+1, n), (2n+1, n+1))
700+
B = zeros(M, N)
701+
copytrito!(B, AA, 'L')
702+
@test view(B, 1:2n, 1:n) == view(C, 1:2n, 1:n)
703+
end
704+
@test_throws DimensionMismatch copytrito!(zeros(n-1, n+1), AA, 'L')
663705
end
664706
end
665707
end

stdlib/LinearAlgebra/test/lapack.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -805,8 +805,26 @@ end
805805
B = zeros(elty, n, n)
806806
LinearAlgebra.LAPACK.lacpy!(B, A, uplo)
807807
C = uplo == 'L' ? tril(A) : (uplo == 'U' ? triu(A) : A)
808-
@test B C
808+
@test B == C
809+
B = zeros(elty, n+1, n+1)
810+
LinearAlgebra.LAPACK.lacpy!(B, A, uplo)
811+
C = uplo == 'L' ? tril(A) : (uplo == 'U' ? triu(A) : A)
812+
@test view(B, 1:n, 1:n) == C
809813
end
814+
A = rand(elty, n, n+1)
815+
B = zeros(elty, n, n)
816+
LinearAlgebra.LAPACK.lacpy!(B, A, 'L')
817+
@test B == view(tril(A), 1:n, 1:n)
818+
B = zeros(elty, n, n+1)
819+
LinearAlgebra.LAPACK.lacpy!(B, A, 'U')
820+
@test B == triu(A)
821+
A = rand(elty, n+1, n)
822+
B = zeros(elty, n, n)
823+
LinearAlgebra.LAPACK.lacpy!(B, A, 'U')
824+
@test B == view(triu(A), 1:n, 1:n)
825+
B = zeros(elty, n+1, n)
826+
LinearAlgebra.LAPACK.lacpy!(B, A, 'L')
827+
@test B == tril(A)
810828
end
811829
end
812830

0 commit comments

Comments
 (0)