Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 136 additions & 70 deletions src/tensors/factorizations.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
# Tensor factorization
#----------------------
function factorisation_scalartype(t::AbstractTensorMap)
T = scalartype(t)
return promote_type(Float32, typeof(zero(T) / sqrt(abs2(one(T)))))
end
factorisation_scalartype(f, t) = factorisation_scalartype(t)

function permutedcopy_oftype(t::AbstractTensorMap, T::Type{<:Number}, p::Index2Tuple)
return permute!(similar(t, T, permute(space(t), p)), t, p)
end
function copy_oftype(t::AbstractTensorMap, T::Type{<:Number})
return copy!(similar(t, T), t)
end

"""
tsvd(t::AbstractTensorMap, (leftind, rightind)::Index2Tuple;
trunc::TruncationScheme = notrunc(), p::Real = 2, alg::Union{SVD, SDD} = SDD())
Expand Down Expand Up @@ -36,13 +49,14 @@
Orthogonality requires `InnerProductStyle(t) <: HasInnerProduct`, and `tsvd(!)`
is currently only implemented for `InnerProductStyle(t) === EuclideanInnerProduct()`.
"""
function tsvd(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple; kwargs...)
return tsvd!(permute(t, (p₁, p₂); copy=true); kwargs...)
function tsvd(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
tcopy = permutedcopy_oftype(t, factorisation_scalartype(tsvd, t), p)
return tsvd!(tcopy; kwargs...)
end

LinearAlgebra.svdvals(t::AbstractTensorMap) = LinearAlgebra.svdvals!(copy(t))
function LinearAlgebra.svdvals!(t::AbstractTensorMap)
return SectorDict(c => LinearAlgebra.svdvals!(b) for (c, b) in blocks(t))
function LinearAlgebra.svdvals(t::AbstractTensorMap)
tcopy = copy_oftype(t, factorisation_scalartype(tsvd, t))
return LinearAlgebra.svdvals!(tcopy)
end

"""
Expand All @@ -67,8 +81,9 @@
`leftorth(!)` is currently only implemented for
`InnerProductStyle(t) === EuclideanInnerProduct()`.
"""
function leftorth(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple; kwargs...)
return leftorth!(permute(t, (p₁, p₂); copy=true); kwargs...)
function leftorth(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
tcopy = permutedcopy_oftype(t, factorisation_scalartype(leftorth, t), p)
return leftorth!(tcopy; kwargs...)
end

"""
Expand All @@ -95,8 +110,9 @@
`rightorth(!)` is currently only implemented for
`InnerProductStyle(t) === EuclideanInnerProduct()`.
"""
function rightorth(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple; kwargs...)
return rightorth!(permute(t, (p₁, p₂); copy=true); kwargs...)
function rightorth(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
tcopy = permutedcopy_oftype(t, factorisation_scalartype(rightorth, t), p)
return rightorth!(tcopy; kwargs...)
end

"""
Expand All @@ -121,8 +137,9 @@
`leftnull(!)` is currently only implemented for
`InnerProductStyle(t) === EuclideanInnerProduct()`.
"""
function leftnull(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple; kwargs...)
return leftnull!(permute(t, (p₁, p₂); copy=true); kwargs...)
function leftnull(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
tcopy = permutedcopy_oftype(t, factorisation_scalartype(leftnull, t), p)
return leftnull!(tcopy; kwargs...)
end

"""
Expand All @@ -149,8 +166,9 @@
`rightnull(!)` is currently only implemented for
`InnerProductStyle(t) === EuclideanInnerProduct()`.
"""
function rightnull(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple; kwargs...)
return rightnull!(permute(t, (p₁, p₂); copy=true); kwargs...)
function rightnull(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
tcopy = permutedcopy_oftype(t, factorisation_scalartype(rightnull, t), p)
return rightnull!(tcopy; kwargs...)
end

"""
Expand All @@ -172,17 +190,14 @@

See also `eig` and `eigh`
"""
function LinearAlgebra.eigen(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple;
kwargs...)
return eigen!(permute(t, (p₁, p₂); copy=true); kwargs...)
function LinearAlgebra.eigen(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
tcopy = permutedcopy_oftype(t, factorisation_scalartype(eigen, t), p)
return eigen!(tcopy; kwargs...)
end

function LinearAlgebra.eigvals(t::AbstractTensorMap; kwargs...)
return LinearAlgebra.eigvals!(copy(t); kwargs...)
end
function LinearAlgebra.eigvals!(t::AbstractTensorMap; kwargs...)
return SectorDict(c => complex(LinearAlgebra.eigvals!(b; kwargs...))
for (c, b) in blocks(t))
tcopy = copy_oftype(t, factorisation_scalartype(eigen, t))
return LinearAlgebra.eigvals!(tcopy; kwargs...)
end

"""
Expand All @@ -207,8 +222,9 @@

See also `eigen` and `eigh`.
"""
function eig(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple; kwargs...)
return eig!(permute(t, (p₁, p₂); copy=true); kwargs...)
function eig(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
tcopy = permutedcopy_oftype(t, factorisation_scalartype(eig, t), p)
return eig!(tcopy; kwargs...)

Check warning on line 227 in src/tensors/factorizations.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/factorizations.jl#L225-L227

Added lines #L225 - L227 were not covered by tests
end

"""
Expand All @@ -231,8 +247,9 @@

See also `eigen` and `eig`.
"""
function eigh(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple)
return eigh!(permute(t, (p₁, p₂); copy=true))
function eigh(t::AbstractTensorMap, p::Index2Tuple; kwargs...)
tcopy = permutedcopy_oftype(t, factorisation_scalartype(eigh, t), p)
return eigh!(tcopy; kwargs...)

Check warning on line 252 in src/tensors/factorizations.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/factorizations.jl#L250-L252

Added lines #L250 - L252 were not covered by tests
end

"""
Expand All @@ -247,31 +264,54 @@
meaningless.
"""
function LinearAlgebra.isposdef(t::AbstractTensorMap, (p₁, p₂)::Index2Tuple)
return isposdef!(permute(t, (p₁, p₂); copy=true))
tcopy = permutedcopy_oftype(t, factorisation_scalartype(isposdef, t), p)
return isposdef!(tcopy)

Check warning on line 268 in src/tensors/factorizations.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/factorizations.jl#L267-L268

Added lines #L267 - L268 were not covered by tests
end

tsvd(t::AbstractTensorMap; kwargs...) = tsvd!(copy(t); kwargs...)
function tsvd(t::AbstractTensorMap; kwargs...)
tcopy = copy!(similar(t, float(scalartype(t))), t)
return tsvd!(tcopy; kwargs...)
end
function leftorth(t::AbstractTensorMap; alg::OFA=QRpos(), kwargs...)
return leftorth!(copy(t); alg=alg, kwargs...)
tcopy = copy!(similar(t, float(scalartype(t))), t)
return leftorth!(tcopy; alg=alg, kwargs...)
end
function rightorth(t::AbstractTensorMap; alg::OFA=LQpos(), kwargs...)
return rightorth!(copy(t); alg=alg, kwargs...)
tcopy = copy!(similar(t, float(scalartype(t))), t)
return rightorth!(tcopy; alg=alg, kwargs...)
end
function leftnull(t::AbstractTensorMap; alg::OFA=QR(), kwargs...)
return leftnull!(copy(t); alg=alg, kwargs...)
tcopy = copy!(similar(t, float(scalartype(t))), t)
return leftnull!(tcopy; alg=alg, kwargs...)
end
function rightnull(t::AbstractTensorMap; alg::OFA=LQ(), kwargs...)
return rightnull!(copy(t); alg=alg, kwargs...)
tcopy = copy!(similar(t, float(scalartype(t))), t)
return rightnull!(tcopy; alg=alg, kwargs...)
end
function LinearAlgebra.eigen(t::AbstractTensorMap; kwargs...)
tcopy = copy!(similar(t, float(scalartype(t))), t)
return eigen!(tcopy; kwargs...)
end
function eig(t::AbstractTensorMap; kwargs...)
tcopy = copy!(similar(t, float(scalartype(t))), t)
return eig!(tcopy; kwargs...)
end
function eigh(t::AbstractTensorMap; kwargs...)
tcopy = copy!(similar(t, float(scalartype(t))), t)
return eigh!(tcopy; kwargs...)
end
function LinearAlgebra.isposdef(t::AbstractTensorMap)
tcopy = copy!(similar(t, float(scalartype(t))), t)
return isposdef!(tcopy)
end
LinearAlgebra.eigen(t::AbstractTensorMap; kwargs...) = eigen!(copy(t); kwargs...)
eig(t::AbstractTensorMap; kwargs...) = eig!(copy(t); kwargs...)
eigh(t::AbstractTensorMap; kwargs...) = eigh!(copy(t); kwargs...)
LinearAlgebra.isposdef(t::AbstractTensorMap) = isposdef!(copy(t))

# Orthogonal factorizations (mutation for recycling memory):
# only possible if scalar type is floating point
# only correct if Euclidean inner product
#------------------------------------------------------------------------------------------
function leftorth!(t::TensorMap;
const RealOrComplexFloat = Union{AbstractFloat,Complex{<:AbstractFloat}}

function leftorth!(t::TensorMap{<:RealOrComplexFloat};
alg::Union{QR,QRpos,QL,QLpos,SVD,SDD,Polar}=QRpos(),
atol::Real=zero(float(real(scalartype(t)))),
rtol::Real=(alg ∉ (SVD(), SDD())) ? zero(float(real(scalartype(t)))) :
Expand Down Expand Up @@ -321,7 +361,7 @@
return Q, R
end

function leftnull!(t::TensorMap;
function leftnull!(t::TensorMap{<:RealOrComplexFloat};
alg::Union{QR,QRpos,SVD,SDD}=QRpos(),
atol::Real=zero(float(real(scalartype(t)))),
rtol::Real=(alg ∉ (SVD(), SDD())) ? zero(float(real(scalartype(t)))) :
Expand Down Expand Up @@ -360,7 +400,7 @@
return N
end

function rightorth!(t::TensorMap;
function rightorth!(t::TensorMap{<:RealOrComplexFloat};
alg::Union{LQ,LQpos,RQ,RQpos,SVD,SDD,Polar}=LQpos(),
atol::Real=zero(float(real(scalartype(t)))),
rtol::Real=(alg ∉ (SVD(), SDD())) ? zero(float(real(scalartype(t)))) :
Expand Down Expand Up @@ -410,7 +450,7 @@
return L, Q
end

function rightnull!(t::TensorMap;
function rightnull!(t::TensorMap{<:RealOrComplexFloat};
alg::Union{LQ,LQpos,SVD,SDD}=LQpos(),
atol::Real=zero(float(real(scalartype(t)))),
rtol::Real=(alg ∉ (SVD(), SDD())) ? zero(float(real(scalartype(t)))) :
Expand Down Expand Up @@ -476,7 +516,13 @@
#------------------------------#
# Singular value decomposition #
#------------------------------#
function tsvd!(t::TensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD())
function LinearAlgebra.svdvals!(t::TensorMap{<:RealOrComplexFloat})
return SectorDict(c => LinearAlgebra.svdvals!(b) for (c, b) in blocks(t))
end
LinearAlgebra.svdvals!(t::AdjointTensorMap) = svdvals!(adjoint(t))

Check warning on line 522 in src/tensors/factorizations.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/factorizations.jl#L522

Added line #L522 was not covered by tests

function tsvd!(t::TensorMap{<:RealOrComplexFloat};
trunc=NoTruncation(), p::Real=2, alg=SDD())
return _tsvd!(t, alg, trunc, p)
end
function tsvd!(t::AdjointTensorMap; trunc=NoTruncation(), p::Real=2, alg=SDD())
Expand All @@ -485,7 +531,8 @@
end

# implementation dispatches on algorithm
function _tsvd!(t, alg::Union{SVD,SDD}, trunc::TruncationScheme, p::Real=2)
function _tsvd!(t::TensorMap{<:RealOrComplexFloat}, alg::Union{SVD,SDD},
trunc::TruncationScheme, p::Real=2)
# early return
if isempty(blocksectors(t))
truncerr = zero(real(scalartype(t)))
Expand Down Expand Up @@ -518,13 +565,17 @@
return SVDdata, dims
end

function _create_svdtensors(t, SVDdata, dims)
function _create_svdtensors(t::TensorMap{<:RealOrComplexFloat}, SVDdata, dims)
T = scalartype(t)
S = spacetype(t)
W = S(dims)
T = float(scalartype(t))
U = similar(t, T, codomain(t) ← W)
Σ = similar(t, real(T), W ← W)
V⁺ = similar(t, T, W ← domain(t))

Tr = real(T)
A = similarstoragetype(t, Tr)
Σ = DiagonalTensorMap{Tr,S,A}(undef, W)

U = similar(t, codomain(t) ← W)
V⁺ = similar(t, W ← domain(t))
for (c, (Uc, Σc, V⁺c)) in SVDdata
r = Base.OneTo(dims[c])
copy!(block(U, c), view(Uc, :, r))
Expand All @@ -534,38 +585,53 @@
return U, Σ, V⁺
end

function _empty_svdtensors(t)
function _empty_svdtensors(t::TensorMap{<:RealOrComplexFloat})
T = scalartype(t)
S = spacetype(t)
I = sectortype(t)
dims = SectorDict{I,Int}()
S = spacetype(t)
W = S(dims)

Tr = real(T)
A = similarstoragetype(t, Tr)
Σ = DiagonalTensorMap{Tr,S,A}(undef, W)

U = similar(t, codomain(t) ← W)
Σ = similar(t, real(scalartype(t)), W ← W)
V⁺ = similar(t, W ← domain(t))
return U, Σ, V⁺
end

#--------------------------#
# Eigenvalue decomposition #
#--------------------------#
LinearAlgebra.eigen!(t::TensorMap) = ishermitian(t) ? eigh!(t) : eig!(t)
function LinearAlgebra.eigen!(t::TensorMap{<:RealOrComplexFloat})
return ishermitian(t) ? eigh!(t) : eig!(t)
end

function LinearAlgebra.eigvals!(t::TensorMap{<:RealOrComplexFloat}; kwargs...)
return SectorDict(c => complex(LinearAlgebra.eigvals!(b; kwargs...))
for (c, b) in blocks(t))
end
function LinearAlgebra.eigvals!(t::AdjointTensorMap{<:RealOrComplexFloat}; kwargs...)
return SectorDict(c => conj!(complex(LinearAlgebra.eigvals!(b; kwargs...)))

Check warning on line 616 in src/tensors/factorizations.jl

View check run for this annotation

Codecov / codecov/patch

src/tensors/factorizations.jl#L615-L616

Added lines #L615 - L616 were not covered by tests
for (c, b) in blocks(t))
end

function eigh!(t::TensorMap)
function eigh!(t::TensorMap{<:RealOrComplexFloat})
InnerProductStyle(t) === EuclideanInnerProduct() || throw_invalid_innerproduct(:eigh!)
domain(t) == codomain(t) ||
throw(SpaceMismatch("`eigh!` requires domain and codomain to be the same"))

T = scalartype(t)
I = sectortype(t)
S = spacetype(t)
dims = SectorDict{I,Int}(c => size(b, 1) for (c, b) in blocks(t))
if length(domain(t)) == 1
W = domain(t)[1]
else
S = spacetype(t)
W = S(dims)
end
T = float(scalartype(t))
V = similar(t, T, domain(t) ← W)
D = similar(t, real(T), W ← W)
W = S(dims)

Tr = real(T)
A = similarstoragetype(t, Tr)
D = DiagonalTensorMap{Tr,S,A}(undef, W)
V = similar(t, domain(t) ← W)
for (c, b) in blocks(t)
values, vectors = MatrixAlgebra.eigh!(b)
copy!(block(D, c), Diagonal(values))
Expand All @@ -574,20 +640,20 @@
return D, V
end

function eig!(t::TensorMap; kwargs...)
function eig!(t::TensorMap{<:RealOrComplexFloat}; kwargs...)
domain(t) == codomain(t) ||
throw(SpaceMismatch("`eig!` requires domain and codomain to be the same"))

T = scalartype(t)
I = sectortype(t)
S = spacetype(t)
dims = SectorDict{I,Int}(c => size(b, 1) for (c, b) in blocks(t))
if length(domain(t)) == 1
W = domain(t)[1]
else
S = spacetype(t)
W = S(dims)
end
T = complex(float(scalartype(t)))
V = similar(t, T, domain(t) ← W)
D = similar(t, T, W ← W)
W = S(dims)

Tc = complex(T)
A = similarstoragetype(t, Tc)
D = DiagonalTensorMap{Tc,S,A}(undef, W)
V = similar(t, Tc, domain(t) ← W)
for (c, b) in blocks(t)
values, vectors = MatrixAlgebra.eig!(b; kwargs...)
copy!(block(D, c), Diagonal(values))
Expand Down
Loading