Skip to content

Commit 0aa5201

Browse files
dkarraschKristofferC
authored andcommitted
Fix 3-arg dot for 1x1 structured matrices (#46473)
(cherry picked from commit c3d5009)
1 parent 0830823 commit 0aa5201

File tree

4 files changed

+49
-38
lines changed

4 files changed

+49
-38
lines changed

stdlib/LinearAlgebra/src/bidiag.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -702,14 +702,15 @@ function dot(x::AbstractVector, B::Bidiagonal, y::AbstractVector)
702702
require_one_based_indexing(x, y)
703703
nx, ny = length(x), length(y)
704704
(nx == size(B, 1) == ny) || throw(DimensionMismatch())
705-
if iszero(nx)
706-
return dot(zero(eltype(x)), zero(eltype(B)), zero(eltype(y)))
705+
if nx 1
706+
nx == 0 && return dot(zero(eltype(x)), zero(eltype(B)), zero(eltype(y)))
707+
return dot(x[1], B.dv[1], y[1])
707708
end
708709
ev, dv = B.ev, B.dv
709-
if B.uplo == 'U'
710+
@inbounds if B.uplo == 'U'
710711
x₀ = x[1]
711712
r = dot(x[1], dv[1], y[1])
712-
@inbounds for j in 2:nx-1
713+
for j in 2:nx-1
713714
x₋, x₀ = x₀, x[j]
714715
r += dot(adjoint(ev[j-1])*x₋ + adjoint(dv[j])*x₀, y[j])
715716
end
@@ -719,7 +720,7 @@ function dot(x::AbstractVector, B::Bidiagonal, y::AbstractVector)
719720
x₀ = x[1]
720721
x₊ = x[2]
721722
r = dot(adjoint(dv[1])*x₀ + adjoint(ev[1])*x₊, y[1])
722-
@inbounds for j in 2:nx-1
723+
for j in 2:nx-1
723724
x₀, x₊ = x₊, x[j+1]
724725
r += dot(adjoint(dv[j])*x₀ + adjoint(ev[j])*x₊, y[j])
725726
end

stdlib/LinearAlgebra/src/tridiag.jl

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -256,21 +256,24 @@ end
256256
function dot(x::AbstractVector, S::SymTridiagonal, y::AbstractVector)
257257
require_one_based_indexing(x, y)
258258
nx, ny = length(x), length(y)
259-
(nx == size(S, 1) == ny) || throw(DimensionMismatch())
260-
if iszero(nx)
261-
return dot(zero(eltype(x)), zero(eltype(S)), zero(eltype(y)))
259+
(nx == size(S, 1) == ny) || throw(DimensionMismatch("dot"))
260+
if nx 1
261+
nx == 0 && return dot(zero(eltype(x)), zero(eltype(S)), zero(eltype(y)))
262+
return dot(x[1], S.dv[1], y[1])
262263
end
263264
dv, ev = S.dv, S.ev
264-
x₀ = x[1]
265-
x₊ = x[2]
266-
sub = transpose(ev[1])
267-
r = dot(adjoint(dv[1])*x₀ + adjoint(sub)*x₊, y[1])
268-
@inbounds for j in 2:nx-1
269-
x₋, x₀, x₊ = x₀, x₊, x[j+1]
270-
sup, sub = transpose(sub), transpose(ev[j])
271-
r += dot(adjoint(sup)*x₋ + adjoint(dv[j])*x₀ + adjoint(sub)*x₊, y[j])
272-
end
273-
r += dot(adjoint(transpose(sub))*x₀ + adjoint(dv[nx])*x₊, y[nx])
265+
@inbounds begin
266+
x₀ = x[1]
267+
x₊ = x[2]
268+
sub = transpose(ev[1])
269+
r = dot(adjoint(dv[1])*x₀ + adjoint(sub)*x₊, y[1])
270+
for j in 2:nx-1
271+
x₋, x₀, x₊ = x₀, x₊, x[j+1]
272+
sup, sub = transpose(sub), transpose(ev[j])
273+
r += dot(adjoint(sup)*x₋ + adjoint(dv[j])*x₀ + adjoint(sub)*x₊, y[j])
274+
end
275+
r += dot(adjoint(transpose(sub))*x₀ + adjoint(dv[nx])*x₊, y[nx])
276+
end
274277
return r
275278
end
276279

@@ -841,18 +844,21 @@ function dot(x::AbstractVector, A::Tridiagonal, y::AbstractVector)
841844
require_one_based_indexing(x, y)
842845
nx, ny = length(x), length(y)
843846
(nx == size(A, 1) == ny) || throw(DimensionMismatch())
844-
if iszero(nx)
845-
return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y)))
846-
end
847-
x₀ = x[1]
848-
x₊ = x[2]
849-
dl, d, du = A.dl, A.d, A.du
850-
r = dot(adjoint(d[1])*x₀ + adjoint(dl[1])*x₊, y[1])
851-
@inbounds for j in 2:nx-1
852-
x₋, x₀, x₊ = x₀, x₊, x[j+1]
853-
r += dot(adjoint(du[j-1])*x₋ + adjoint(d[j])*x₀ + adjoint(dl[j])*x₊, y[j])
854-
end
855-
r += dot(adjoint(du[nx-1])*x₀ + adjoint(d[nx])*x₊, y[nx])
847+
if nx 1
848+
nx == 0 && return dot(zero(eltype(x)), zero(eltype(A)), zero(eltype(y)))
849+
return dot(x[1], A.d[1], y[1])
850+
end
851+
@inbounds begin
852+
x₀ = x[1]
853+
x₊ = x[2]
854+
dl, d, du = A.dl, A.d, A.du
855+
r = dot(adjoint(d[1])*x₀ + adjoint(dl[1])*x₊, y[1])
856+
for j in 2:nx-1
857+
x₋, x₀, x₊ = x₀, x₊, x[j+1]
858+
r += dot(adjoint(du[j-1])*x₋ + adjoint(d[j])*x₀ + adjoint(dl[j])*x₊, y[j])
859+
end
860+
r += dot(adjoint(du[nx-1])*x₀ + adjoint(d[nx])*x₊, y[nx])
861+
end
856862
return r
857863
end
858864

stdlib/LinearAlgebra/test/bidiag.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -623,22 +623,22 @@ end
623623
end
624624

625625
@testset "generalized dot" begin
626-
for elty in (Float64, ComplexF64)
627-
dv = randn(elty, 5)
628-
ev = randn(elty, 4)
629-
x = randn(elty, 5)
630-
y = randn(elty, 5)
626+
for elty in (Float64, ComplexF64), n in (5, 1)
627+
dv = randn(elty, n)
628+
ev = randn(elty, n-1)
629+
x = randn(elty, n)
630+
y = randn(elty, n)
631631
for uplo in (:U, :L)
632632
B = Bidiagonal(dv, ev, uplo)
633-
@test dot(x, B, y) dot(B'x, y) dot(x, Matrix(B), y)
633+
@test dot(x, B, y) dot(B'x, y) dot(x, B*y) dot(x, Matrix(B), y)
634634
end
635635
dv = Vector{elty}(undef, 0)
636636
ev = Vector{elty}(undef, 0)
637637
x = Vector{elty}(undef, 0)
638638
y = Vector{elty}(undef, 0)
639639
for uplo in (:U, :L)
640640
B = Bidiagonal(dv, ev, uplo)
641-
@test dot(x, B, y) dot(zero(elty), zero(elty), zero(elty))
641+
@test dot(x, B, y) === zero(elty)
642642
end
643643
end
644644
end

stdlib/LinearAlgebra/test/tridiag.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,11 @@ end
434434
@testset "generalized dot" begin
435435
x = fill(convert(elty, 1), n)
436436
y = fill(convert(elty, 1), n)
437-
@test dot(x, A, y) dot(A'x, y)
437+
@test dot(x, A, y) dot(A'x, y) dot(x, A*y)
438+
@test dot([1], SymTridiagonal([1], Int[]), [1]) == 1
439+
@test dot([1], Tridiagonal(Int[], [1], Int[]), [1]) == 1
440+
@test dot(Int[], SymTridiagonal(Int[], Int[]), Int[]) === 0
441+
@test dot(Int[], Tridiagonal(Int[], Int[], Int[]), Int[]) === 0
438442
end
439443
end
440444
end

0 commit comments

Comments
 (0)