Skip to content

Fix precision of 2x2 eigenvectors #1191

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "StaticArrays"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
version = "1.6.2"
version = "1.6.3"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
62 changes: 20 additions & 42 deletions src/eigen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,48 +159,26 @@ end
@inline function _eig(::Size{(2,2)}, A::LinearAlgebra.RealHermSymComplexHerm{T}, permute, scale) where {T <: Real}
a = A.data
TA = eltype(A)
@inbounds if A.uplo == 'U'
if !iszero(a[3]) # A is not diagonal
t_half = real(a[1] + a[4]) / 2
tmp = norm(SVector((a[1] - a[4])/2, a[3]'))
vals = SVector(t_half - tmp, t_half + tmp)

v11 = vals[1] - a[4]
n1 = sqrt(v11' * v11 + a[3]' * a[3])
v11 = v11 / n1
v12 = a[3]' / n1

v21 = vals[2] - a[4]
n2 = sqrt(v21' * v21 + a[3]' * a[3])
v21 = v21 / n2
v22 = a[3]' / n2

vecs = @SMatrix [ v11 v21 ;
v12 v22 ]

return Eigen(vals, vecs)
end
else # A.uplo == 'L'
if !iszero(a[2]) # A is not diagonal
t_half = real(a[1] + a[4]) / 2
tmp = norm(SVector((a[1] - a[4])/2, a[2]))
vals = SVector(t_half - tmp, t_half + tmp)

v11 = vals[1] - a[4]
n1 = sqrt(v11' * v11 + a[2]' * a[2])
v11 = v11 / n1
v12 = a[2] / n1

v21 = vals[2] - a[4]
n2 = sqrt(v21' * v21 + a[2]' * a[2])
v21 = v21 / n2
v22 = a[2] / n2

vecs = @SMatrix [ v11 v21 ;
v12 v22 ]

return Eigen(vals,vecs)
end
@inbounds a21 = A.uplo == 'U' ? a[3]' : a[2]
@inbounds if !iszero(a21) # A is not diagonal
t_half = real(a[1] + a[4]) / 2
diag_avg_diff = (a[1] - a[4])/2
tmp = norm(SVector(diag_avg_diff, a21))
vals = SVector(t_half - tmp, t_half + tmp)
v11 = -tmp + diag_avg_diff
n1 = sqrt(v11' * v11 + a21 * a21')
v11 = v11 / n1
v12 = a21 / n1

v21 = tmp + diag_avg_diff
n2 = sqrt(v21' * v21 + a21 * a21')
v21 = v21 / n2
v22 = a21 / n2

vecs = @SMatrix [ v11 v21 ;
v12 v22 ]

return Eigen(vals, vecs)
end

# A must be diagonal if we reached this point; treatment of uplo 'L' and 'U' is then identical
Expand Down
19 changes: 19 additions & 0 deletions test/eigen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,25 @@ using StaticArrays, Test, LinearAlgebra
@test eigvecs(E) * SDiagonal(eigvals(E)) * eigvecs(E)' ≈ A
end

# issue #956
v1, v2 = let A = SMatrix{2,2}((1.0, 1e-100, 1e-100, 1.0))
vecs = eigvecs(eigen(Hermitian(A)))
vecs[:,1], vecs[:,2]
end
@test norm(v1) ≈ 1
@test norm(v2) ≈ 1
@test acos(v1 ⋅ v2)*2 ≈ π

# Test that correct value is taken for upper or lower sym
for T in (Float64, Complex{Float64})
af = SMatrix{2,2}(rand(T, 4))
a = af + af' # Make hermitian
au = Hermitian(SMatrix{2,2}((a[1,1], zero(T), a[1,2], a[2,2])), 'U')
al = Hermitian(SMatrix{2,2}((a[1,1], a[2,1], zero(T), a[2,2])), 'L')
@test eigvals(eigen(a)) ≈ eigvals(eigen(au)) ≈ eigvals(eigen(al))
@test eigvals(eigen(a)) ≈ eigvals(eigen(au)) ≈ eigvals(eigen(al))
end

m1_a = randn(2,2)
m1_a = m1_a*m1_a'
m1 = SMatrix{2,2}(m1_a)
Expand Down
6 changes: 5 additions & 1 deletion test/initializers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,11 @@ if VERSION >= v"1.7.0"
@test SA[1 2;1 2 ;;; 1 2;1 2] === SArray{Tuple{2,2,2}}(Tuple([1 2;1 2 ;;; 1 2;1 2]))
@test_inlined SA_test_hvncat1(3)
@test_inlined SA_test_hvncat2(2)
@test_throws ArgumentError SA[1;2;;3]
if VERSION < v"1.10-DEV"
@test_throws ArgumentError SA[1;2;;3]
else
@test_throws DimensionMismatch SA[1;2;;3]
end
end

# https://github.com/JuliaArrays/StaticArrays.jl/pull/685
Expand Down