Skip to content

Commit 9eb7a0c

Browse files
authored
Improve error message in inplace transpose (#54669)
1 parent fa038d9 commit 9eb7a0c

File tree

2 files changed

+20
-5
lines changed

2 files changed

+20
-5
lines changed

stdlib/LinearAlgebra/src/transpose.jl

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,27 +74,32 @@ julia> A
7474
```
7575
"""
7676
adjoint!(B::AbstractMatrix, A::AbstractMatrix) = transpose_f!(adjoint, B, A)
77+
78+
@noinline function check_transpose_axes(axesA, axesB)
79+
axesB == reverse(axesA) || throw(DimensionMismatch("axes of the destination are incompatible with that of the source"))
80+
end
81+
7782
function transpose!(B::AbstractVector, A::AbstractMatrix)
78-
axes(B,1) == axes(A,2) && axes(A,1) == 1:1 || throw(DimensionMismatch("transpose"))
83+
check_transpose_axes((axes(B,1), axes(B,2)), axes(A))
7984
copyto!(B, A)
8085
end
8186
function transpose!(B::AbstractMatrix, A::AbstractVector)
82-
axes(B,2) == axes(A,1) && axes(B,1) == 1:1 || throw(DimensionMismatch("transpose"))
87+
check_transpose_axes(axes(B), (axes(A,1), axes(A,2)))
8388
copyto!(B, A)
8489
end
8590
function adjoint!(B::AbstractVector, A::AbstractMatrix)
86-
axes(B,1) == axes(A,2) && axes(A,1) == 1:1 || throw(DimensionMismatch("transpose"))
91+
check_transpose_axes((axes(B,1), axes(B,2)), axes(A))
8792
ccopy!(B, A)
8893
end
8994
function adjoint!(B::AbstractMatrix, A::AbstractVector)
90-
axes(B,2) == axes(A,1) && axes(B,1) == 1:1 || throw(DimensionMismatch("transpose"))
95+
check_transpose_axes(axes(B), (axes(A,1), axes(A,2)))
9196
ccopy!(B, A)
9297
end
9398

9499
const transposebaselength=64
95100
function transpose_f!(f, B::AbstractMatrix, A::AbstractMatrix)
96101
inds = axes(A)
97-
axes(B,1) == inds[2] && axes(B,2) == inds[1] || throw(DimensionMismatch(string(f)))
102+
check_transpose_axes(axes(B), inds)
98103

99104
m, n = length(inds[1]), length(inds[2])
100105
if m*n<=4*transposebaselength

stdlib/LinearAlgebra/test/adjtrans.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -703,4 +703,14 @@ end
703703
@test B == At
704704
end
705705

706+
@testset "error message in transpose" begin
707+
v = zeros(2)
708+
A = zeros(1,1)
709+
B = zeros(2,3)
710+
for (t1, t2) in Any[(A, v), (v, A), (A, B)]
711+
@test_throws "axes of the destination are incompatible with that of the source" transpose!(t1, t2)
712+
@test_throws "axes of the destination are incompatible with that of the source" adjoint!(t1, t2)
713+
end
714+
end
715+
706716
end # module TestAdjointTranspose

0 commit comments

Comments
 (0)