Skip to content

Commit 880a9fe

Browse files
authored
Add _sym_uplo to skip validation (#1441)
1 parent 8d6ca14 commit 880a9fe

File tree

7 files changed

+74
-48
lines changed

7 files changed

+74
-48
lines changed

src/LinearAlgebra.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,14 @@ function char_uplo(uplo::Symbol)
364364
end
365365
end
366366

367+
"""
368+
sym_uplo(uplo::Char)
369+
370+
Return the `Symbol` corresponding the `uplo` by checking for validity.
371+
"""
367372
function sym_uplo(uplo::Char)
373+
# This method is called by other packages, and isn't used within LinearAlgebra
374+
# It's retained here for backward compatibility.
368375
if uplo == 'U'
369376
return :U
370377
elseif uplo == 'L'
@@ -373,6 +380,13 @@ function sym_uplo(uplo::Char)
373380
throw_uplo()
374381
end
375382
end
383+
"""
384+
_sym_uplo(uplo::Char)
385+
386+
Return the `Symbol` corresponding to `uplo` without checking for validity.
387+
See also `sym_uplo`, which checks for validity.
388+
"""
389+
_sym_uplo(uplo::Char) = uplo == 'U' ? (:U) : (:L)
376390

377391
@noinline throw_uplo() = throw(ArgumentError("uplo argument must be either :U (upper) or :L (lower)"))
378392

src/bidiag.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ function show(io::IO, M::Bidiagonal)
302302
print(io, ", ")
303303
show(io, M.ev)
304304
print(io, ", ")
305-
show(io, sym_uplo(M.uplo))
305+
show(io, _sym_uplo(M.uplo))
306306
print(io, ")")
307307
end
308308

src/bunchkaufman.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,9 @@ function bunchkaufman!(A::StridedMatrix{<:BlasFloat}, rook::Bool = false; check:
130130
end
131131

132132
bkcopy_oftype(A, S) = eigencopy_oftype(A, S)
133-
bkcopy_oftype(A::Symmetric{<:Complex}, S) = Symmetric(copytrito!(similar(parent(A), S, size(A)), A.data, A.uplo), sym_uplo(A.uplo))
133+
function bkcopy_oftype(A::Symmetric{<:Complex}, S)
134+
Symmetric(copytrito!(similar(parent(A), S, size(A)), A.data, A.uplo), _sym_uplo(A.uplo))
135+
end
134136

135137
"""
136138
bunchkaufman(A, rook::Bool=false; check = true) -> S::BunchKaufman

src/special.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -292,19 +292,19 @@ end
292292

293293
for f in (:+, :-)
294294
@eval function $f(D::Diagonal{<:Number}, S::Symmetric)
295-
uplo = sym_uplo(S.uplo)
295+
uplo = _sym_uplo(S.uplo)
296296
return Symmetric(parentof_applytri($f, Symmetric(D, uplo), S), uplo)
297297
end
298298
@eval function $f(S::Symmetric, D::Diagonal{<:Number})
299-
uplo = sym_uplo(S.uplo)
299+
uplo = _sym_uplo(S.uplo)
300300
return Symmetric(parentof_applytri($f, S, Symmetric(D, uplo)), uplo)
301301
end
302302
@eval function $f(D::Diagonal{<:Real}, H::Hermitian)
303-
uplo = sym_uplo(H.uplo)
303+
uplo = _sym_uplo(H.uplo)
304304
return Hermitian(parentof_applytri($f, Hermitian(D, uplo), H), uplo)
305305
end
306306
@eval function $f(H::Hermitian, D::Diagonal{<:Real})
307-
uplo = sym_uplo(H.uplo)
307+
uplo = _sym_uplo(H.uplo)
308308
return Hermitian(parentof_applytri($f, H, Hermitian(D, uplo)), uplo)
309309
end
310310
end
@@ -608,8 +608,8 @@ end
608608
# tridiagonal cholesky factorization
609609
function cholesky(S::RealSymHermitian{<:BiTriSym}, ::NoPivot = NoPivot(); check::Bool = true)
610610
T = choltype(S)
611-
B = Bidiagonal{T}(diag(S, 0), diag(S, S.uplo == 'U' ? 1 : -1), sym_uplo(S.uplo))
612-
cholesky!(Hermitian(B, sym_uplo(S.uplo)), NoPivot(); check = check)
611+
B = Bidiagonal{T}(diag(S, 0), diag(S, S.uplo == 'U' ? 1 : -1), _sym_uplo(S.uplo))
612+
cholesky!(Hermitian(B, _sym_uplo(S.uplo)), NoPivot(); check = check)
613613
end
614614

615615
# istriu/istril for triangular wrappers of structured matrices

src/symmetric.jl

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -206,15 +206,15 @@ for (S, H) in ((:Symmetric, :Hermitian), (:Hermitian, :Symmetric))
206206
throw(ArgumentError("Cannot construct $($S); uplo doesn't match"))
207207
end
208208
end
209-
$S(A::$H) = $S(A, sym_uplo(A.uplo))
209+
$S(A::$H) = $S(A, _sym_uplo(A.uplo))
210210
function $S(A::$H, uplo::Symbol)
211211
if A.uplo == char_uplo(uplo)
212212
if $H === Hermitian && !(eltype(A) <: Real) &&
213213
any(!isreal, A.data[i] for i in diagind(A.data, IndexStyle(A.data)))
214214

215215
throw(ArgumentError("Cannot construct $($S)($($H))); diagonal contains complex values"))
216216
end
217-
return $S(A.data, sym_uplo(A.uplo))
217+
return $S(A.data, _sym_uplo(A.uplo))
218218
else
219219
throw(ArgumentError("Cannot construct $($S); uplo doesn't match"))
220220
end
@@ -286,7 +286,7 @@ end
286286
@inline function getindex(A::Symmetric, i::Int, j::Int)
287287
@boundscheck checkbounds(A, i, j)
288288
@inbounds if i == j
289-
return symmetric(A.data[i, j], sym_uplo(A.uplo))::symmetric_type(eltype(A.data))
289+
return symmetric(A.data[i, j], _sym_uplo(A.uplo))::symmetric_type(eltype(A.data))
290290
elseif (A.uplo == 'U') == (i < j)
291291
return A.data[i, j]
292292
else
@@ -296,7 +296,7 @@ end
296296
@inline function getindex(A::Hermitian, i::Int, j::Int)
297297
@boundscheck checkbounds(A, i, j)
298298
@inbounds if i == j
299-
return hermitian(A.data[i, j], sym_uplo(A.uplo))::hermitian_type(eltype(A.data))
299+
return hermitian(A.data[i, j], _sym_uplo(A.uplo))::hermitian_type(eltype(A.data))
300300
elseif (A.uplo == 'U') == (i < j)
301301
return A.data[i, j]
302302
else
@@ -329,14 +329,14 @@ Base._reverse(A::Hermitian, ::Colon) = Hermitian(reverse(A.data), A.uplo == 'U'
329329
end
330330

331331
Base.dataids(A::HermOrSym) = Base.dataids(parent(A))
332-
Base.unaliascopy(A::Hermitian) = Hermitian(Base.unaliascopy(parent(A)), sym_uplo(A.uplo))
333-
Base.unaliascopy(A::Symmetric) = Symmetric(Base.unaliascopy(parent(A)), sym_uplo(A.uplo))
332+
Base.unaliascopy(A::Hermitian) = Hermitian(Base.unaliascopy(parent(A)), _sym_uplo(A.uplo))
333+
Base.unaliascopy(A::Symmetric) = Symmetric(Base.unaliascopy(parent(A)), _sym_uplo(A.uplo))
334334

335335
_conjugation(::Union{Symmetric, Hermitian{<:Real}}) = transpose
336336
_conjugation(::Hermitian) = adjoint
337337

338-
diag(A::Symmetric) = symmetric.(diag(parent(A)), sym_uplo(A.uplo))
339-
diag(A::Hermitian) = hermitian.(diag(parent(A)), sym_uplo(A.uplo))
338+
diag(A::Symmetric) = symmetric.(diag(parent(A)), _sym_uplo(A.uplo))
339+
diag(A::Hermitian) = hermitian.(diag(parent(A)), _sym_uplo(A.uplo))
340340

341341
function applytri(f, A::HermOrSym)
342342
if A.uplo == 'U'
@@ -374,15 +374,15 @@ similar(A::Union{Symmetric,Hermitian}, ::Type{T}, dims::Dims{N}) where {T,N} = s
374374
parent(A::HermOrSym) = A.data
375375
Symmetric{T,S}(A::Symmetric{T,S}) where {T,S<:AbstractMatrix{T}} = A
376376
Symmetric{T,S}(A::Symmetric) where {T,S<:AbstractMatrix{T}} = Symmetric{T,S}(convert(S,A.data),A.uplo)
377-
AbstractMatrix{T}(A::Symmetric) where {T} = Symmetric(convert(AbstractMatrix{T}, A.data), sym_uplo(A.uplo))
377+
AbstractMatrix{T}(A::Symmetric) where {T} = Symmetric(convert(AbstractMatrix{T}, A.data), _sym_uplo(A.uplo))
378378
AbstractMatrix{T}(A::Symmetric{T}) where {T} = copy(A)
379379
Hermitian{T,S}(A::Hermitian{T,S}) where {T,S<:AbstractMatrix{T}} = A
380380
Hermitian{T,S}(A::Hermitian) where {T,S<:AbstractMatrix{T}} = Hermitian{T,S}(convert(S,A.data),A.uplo)
381-
AbstractMatrix{T}(A::Hermitian) where {T} = Hermitian(convert(AbstractMatrix{T}, A.data), sym_uplo(A.uplo))
381+
AbstractMatrix{T}(A::Hermitian) where {T} = Hermitian(convert(AbstractMatrix{T}, A.data), _sym_uplo(A.uplo))
382382
AbstractMatrix{T}(A::Hermitian{T}) where {T} = copy(A)
383383

384-
copy(A::Symmetric) = (Symmetric(parentof_applytri(copy, A), sym_uplo(A.uplo)))
385-
copy(A::Hermitian) = (Hermitian(parentof_applytri(copy, A), sym_uplo(A.uplo)))
384+
copy(A::Symmetric) = (Symmetric(parentof_applytri(copy, A), _sym_uplo(A.uplo)))
385+
copy(A::Hermitian) = (Hermitian(parentof_applytri(copy, A), _sym_uplo(A.uplo)))
386386

387387
function copyto!(dest::Symmetric, src::Symmetric)
388388
if axes(dest) != axes(src)
@@ -423,13 +423,13 @@ end
423423
end
424424
@inline function _symmetrize_diagonal!(B, A::Symmetric)
425425
for i = 1:size(A, 1)
426-
B[i,i] = symmetric(A[i,i], sym_uplo(A.uplo))::symmetric_type(eltype(A.data))
426+
B[i,i] = symmetric(A[i,i], _sym_uplo(A.uplo))::symmetric_type(eltype(A.data))
427427
end
428428
return B
429429
end
430430
@inline function _symmetrize_diagonal!(B, A::Hermitian)
431431
for i = 1:size(A, 1)
432-
B[i,i] = hermitian(A[i,i], sym_uplo(A.uplo))::hermitian_type(eltype(A.data))
432+
B[i,i] = hermitian(A[i,i], _sym_uplo(A.uplo))::hermitian_type(eltype(A.data))
433433
end
434434
return B
435435
end
@@ -501,9 +501,9 @@ transpose(A::Hermitian{<:Real}) = A
501501

502502
real(A::Symmetric{<:Real}) = A
503503
real(A::Hermitian{<:Real}) = A
504-
real(A::Symmetric) = Symmetric(parentof_applytri(real, A), sym_uplo(A.uplo))
505-
real(A::Hermitian) = Hermitian(parentof_applytri(real, A), sym_uplo(A.uplo))
506-
imag(A::Symmetric) = Symmetric(parentof_applytri(imag, A), sym_uplo(A.uplo))
504+
real(A::Symmetric) = Symmetric(parentof_applytri(real, A), _sym_uplo(A.uplo))
505+
real(A::Hermitian) = Hermitian(parentof_applytri(real, A), _sym_uplo(A.uplo))
506+
imag(A::Symmetric) = Symmetric(parentof_applytri(imag, A), _sym_uplo(A.uplo))
507507

508508
Base.copy(A::Adjoint{<:Any,<:Symmetric}) =
509509
Symmetric(copy(adjoint(A.parent.data)), ifelse(A.parent.uplo == 'U', :L, :U))
@@ -513,8 +513,8 @@ Base.copy(A::Transpose{<:Any,<:Hermitian}) =
513513
tr(A::Symmetric{<:Number}) = tr(A.data) # to avoid AbstractMatrix fallback (incl. allocations)
514514
tr(A::Hermitian{<:Number}) = real(tr(A.data))
515515

516-
Base.conj(A::Symmetric) = Symmetric(parentof_applytri(conj, A), sym_uplo(A.uplo))
517-
Base.conj(A::Hermitian) = Hermitian(parentof_applytri(conj, A), sym_uplo(A.uplo))
516+
Base.conj(A::Symmetric) = Symmetric(parentof_applytri(conj, A), _sym_uplo(A.uplo))
517+
Base.conj(A::Hermitian) = Hermitian(parentof_applytri(conj, A), _sym_uplo(A.uplo))
518518
Base.conj!(A::HermOrSym) = typeof(A)(parentof_applytri(conj!, A), A.uplo)
519519

520520
# tril/triu
@@ -731,25 +731,25 @@ function _hermkron!(C, A, B, conj, real, Auplo, Buplo)
731731
end
732732
end
733733

734-
(-)(A::Symmetric) = Symmetric(parentof_applytri(-, A), sym_uplo(A.uplo))
735-
(-)(A::Hermitian) = Hermitian(parentof_applytri(-, A), sym_uplo(A.uplo))
734+
(-)(A::Symmetric) = Symmetric(parentof_applytri(-, A), _sym_uplo(A.uplo))
735+
(-)(A::Hermitian) = Hermitian(parentof_applytri(-, A), _sym_uplo(A.uplo))
736736

737737
## Addition/subtraction
738738
for f (:+, :-), Wrapper (:Hermitian, :Symmetric)
739739
@eval function $f(A::$Wrapper, B::$Wrapper)
740-
uplo = A.uplo == B.uplo ? sym_uplo(A.uplo) : (:U)
740+
uplo = A.uplo == B.uplo ? _sym_uplo(A.uplo) : (:U)
741741
$Wrapper(parentof_applytri($f, A, B), uplo)
742742
end
743743
end
744744

745745
for f in (:+, :-)
746746
@eval begin
747-
$f(A::Hermitian, B::Symmetric{<:Real}) = $f(A, Hermitian(parent(B), sym_uplo(B.uplo)))
748-
$f(A::Symmetric{<:Real}, B::Hermitian) = $f(Hermitian(parent(A), sym_uplo(A.uplo)), B)
749-
$f(A::SymTridiagonal, B::Symmetric) = $f(Symmetric(A, sym_uplo(B.uplo)), B)
750-
$f(A::Symmetric, B::SymTridiagonal) = $f(A, Symmetric(B, sym_uplo(A.uplo)))
751-
$f(A::SymTridiagonal{<:Real}, B::Hermitian) = $f(Hermitian(A, sym_uplo(B.uplo)), B)
752-
$f(A::Hermitian, B::SymTridiagonal{<:Real}) = $f(A, Hermitian(B, sym_uplo(A.uplo)))
747+
$f(A::Hermitian, B::Symmetric{<:Real}) = $f(A, Hermitian(parent(B), _sym_uplo(B.uplo)))
748+
$f(A::Symmetric{<:Real}, B::Hermitian) = $f(Hermitian(parent(A), _sym_uplo(A.uplo)), B)
749+
$f(A::SymTridiagonal, B::Symmetric) = $f(Symmetric(A, _sym_uplo(B.uplo)), B)
750+
$f(A::Symmetric, B::SymTridiagonal) = $f(A, Symmetric(B, _sym_uplo(A.uplo)))
751+
$f(A::SymTridiagonal{<:Real}, B::Hermitian) = $f(Hermitian(A, _sym_uplo(B.uplo)), B)
752+
$f(A::Hermitian, B::SymTridiagonal{<:Real}) = $f(A, Hermitian(B, _sym_uplo(A.uplo)))
753753
end
754754
end
755755

@@ -799,12 +799,12 @@ function dot(x::AbstractVector, A::HermOrSym, y::AbstractVector)
799799
end
800800

801801
# Scaling with Number
802-
*(A::Symmetric, x::Number) = Symmetric(parentof_applytri(y -> y * x, A), sym_uplo(A.uplo))
803-
*(x::Number, A::Symmetric) = Symmetric(parentof_applytri(y -> x * y, A), sym_uplo(A.uplo))
804-
*(A::Hermitian, x::Real) = Hermitian(parentof_applytri(y -> y * x, A), sym_uplo(A.uplo))
805-
*(x::Real, A::Hermitian) = Hermitian(parentof_applytri(y -> x * y, A), sym_uplo(A.uplo))
806-
/(A::Symmetric, x::Number) = Symmetric(parentof_applytri(y -> y/x, A), sym_uplo(A.uplo))
807-
/(A::Hermitian, x::Real) = Hermitian(parentof_applytri(y -> y/x, A), sym_uplo(A.uplo))
802+
*(A::Symmetric, x::Number) = Symmetric(parentof_applytri(y -> y * x, A), _sym_uplo(A.uplo))
803+
*(x::Number, A::Symmetric) = Symmetric(parentof_applytri(y -> x * y, A), _sym_uplo(A.uplo))
804+
*(A::Hermitian, x::Real) = Hermitian(parentof_applytri(y -> y * x, A), _sym_uplo(A.uplo))
805+
*(x::Real, A::Hermitian) = Hermitian(parentof_applytri(y -> x * y, A), _sym_uplo(A.uplo))
806+
/(A::Symmetric, x::Number) = Symmetric(parentof_applytri(y -> y/x, A), _sym_uplo(A.uplo))
807+
/(A::Hermitian, x::Real) = Hermitian(parentof_applytri(y -> y/x, A), _sym_uplo(A.uplo))
808808

809809
factorize(A::HermOrSym) = _factorize(A)
810810
function _factorize(A::HermOrSym{T}; check::Bool=true) where T
@@ -850,8 +850,8 @@ function _inv(A::HermOrSym)
850850
B
851851
end
852852
# StridedMatrix restriction seems necessary due to inv! call in _inv above
853-
inv(A::Hermitian{<:Any,<:StridedMatrix}) = Hermitian(_inv(A), sym_uplo(A.uplo))
854-
inv(A::Symmetric{<:Any,<:StridedMatrix}) = Symmetric(_inv(A), sym_uplo(A.uplo))
853+
inv(A::Hermitian{<:Any,<:StridedMatrix}) = Hermitian(_inv(A), _sym_uplo(A.uplo))
854+
inv(A::Symmetric{<:Any,<:StridedMatrix}) = Symmetric(_inv(A), _sym_uplo(A.uplo))
855855

856856
function svd(A::RealHermSymComplexHerm; full::Bool=false)
857857
vals, vecs = eigen(A)

src/symmetriceigen.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,12 @@
22

33
# preserve HermOrSym wrapper
44
# Call `copytrito!` instead of `copy_similar` to only copy the matching triangular half
5-
eigencopy_oftype(A::Hermitian, ::Type{S}) where S = Hermitian(copytrito!(similar(parent(A), S, size(A)), A.data, A.uplo), sym_uplo(A.uplo))
6-
eigencopy_oftype(A::Symmetric, ::Type{S}) where S = Symmetric(copytrito!(similar(parent(A), S, size(A)), A.data, A.uplo), sym_uplo(A.uplo))
5+
function eigencopy_oftype(A::Hermitian, ::Type{S}) where S
6+
Hermitian(copytrito!(similar(parent(A), S, size(A)), A.data, A.uplo), _sym_uplo(A.uplo))
7+
end
8+
function eigencopy_oftype(A::Symmetric, ::Type{S}) where S
9+
Symmetric(copytrito!(similar(parent(A), S, size(A)), A.data, A.uplo), _sym_uplo(A.uplo))
10+
end
711
eigencopy_oftype(A::Symmetric{<:Complex}, ::Type{S}) where S = copyto!(similar(parent(A), S), A)
812

913
"""
@@ -314,8 +318,8 @@ end
314318

315319
# Perform U' \ A / U in-place, where U::Union{UpperTriangular,Diagonal}
316320
UtiAUi!(A, U) = _UtiAUi!(A, U)
317-
UtiAUi!(A::Symmetric, U) = Symmetric(_UtiAUi!(copytri!(parent(A), A.uplo), U), sym_uplo(A.uplo))
318-
UtiAUi!(A::Hermitian, U) = Hermitian(_UtiAUi!(copytri!(parent(A), A.uplo, true), U), sym_uplo(A.uplo))
321+
UtiAUi!(A::Symmetric, U) = Symmetric(_UtiAUi!(copytri!(parent(A), A.uplo), U), _sym_uplo(A.uplo))
322+
UtiAUi!(A::Hermitian, U) = Hermitian(_UtiAUi!(copytri!(parent(A), A.uplo, true), U), _sym_uplo(A.uplo))
319323
_UtiAUi!(A, U) = rdiv!(ldiv!(U', A), U)
320324

321325
function eigvals(A::HermOrSym{TA}, B::HermOrSym{TB}; kws...) where {TA,TB}

test/symmetric.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1343,6 +1343,12 @@ end
13431343
@test_throws msg LinearAlgebra.fillband!(Symmetric(A), 2, 0, 1)
13441344
end
13451345

1346+
@testset "sym_uplo" begin
1347+
@test LinearAlgebra.sym_uplo('U') == :U
1348+
@test LinearAlgebra.sym_uplo('L') == :L
1349+
@test_throws ArgumentError LinearAlgebra.sym_uplo('N')
1350+
end
1351+
13461352
@testset "uplo" begin
13471353
S = Symmetric([1 2; 3 4], :U)
13481354
@test LinearAlgebra.uplo(S) == :U

0 commit comments

Comments
 (0)