Skip to content

Sweeping algorithms for tree tensor networks #44

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5"
IsApprox = "28f27b66-4bd8-47e7-9110-e2746eb8bed7"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19"
Expand Down
6 changes: 5 additions & 1 deletion src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ using IsApprox
using ITensors
using ITensors.ContractionSequenceOptimization
using ITensors.ITensorVisualizationCore
using IterTools
using KrylovKit: KrylovKit
using NamedGraphs
using Observers
Expand Down Expand Up @@ -87,6 +88,8 @@ include(joinpath("treetensornetworks", "abstractprojttno.jl"))
include(joinpath("treetensornetworks", "projttno.jl"))
include(joinpath("treetensornetworks", "projttnosum.jl"))
include(joinpath("treetensornetworks", "projttno_apply.jl"))
# Compatibility of ITensors.MPS/MPO with tree sweeping routines
include(joinpath("treetensornetworks", "solvers", "tree_patch.jl"))
# Compatibility of ITensor observer and Observers
# TODO: Delete this
include(joinpath("treetensornetworks", "solvers", "update_observer.jl"))
Expand All @@ -103,10 +106,11 @@ include(joinpath("treetensornetworks", "solvers", "tdvp.jl"))
include(joinpath("treetensornetworks", "solvers", "dmrg.jl"))
include(joinpath("treetensornetworks", "solvers", "dmrg_x.jl"))
include(joinpath("treetensornetworks", "solvers", "projmpo_apply.jl"))
include(joinpath("treetensornetworks", "solvers", "contract_mpo_mps.jl"))
include(joinpath("treetensornetworks", "solvers", "contract_operator_state.jl"))
include(joinpath("treetensornetworks", "solvers", "projmps2.jl"))
include(joinpath("treetensornetworks", "solvers", "projmpo_mps2.jl"))
include(joinpath("treetensornetworks", "solvers", "linsolve.jl"))
include(joinpath("treetensornetworks", "solvers", "tree_sweeping.jl"))

include("exports.jl")

Expand Down
52 changes: 0 additions & 52 deletions src/treetensornetworks/solvers/contract_mpo_mps.jl

This file was deleted.

62 changes: 62 additions & 0 deletions src/treetensornetworks/solvers/contract_operator_state.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
function contract_solver(; kwargs...)
function solver(PH, t, psi; kws...)
v = ITensor(1.0)
for j in sites(PH)
v *= PH.psi0[j]
end
Hpsi0 = contract(PH, v)
return Hpsi0, nothing
end
return solver
end

function ITensors.contract(
::ITensors.Algorithm"fit",
A::IsTreeOperator,
psi0::ST;
init_state=psi0,
nsweeps=1,
kwargs...,
)::ST where {ST<:IsTreeState}
n = nv(A)
n != nv(psi0) && throw(
DimensionMismatch("Number of sites operator ($n) and state ($(nv(psi0))) do not match"),
)
if n == 1
v = only(vertices(psi0))
return ST([A[v] * psi0[v]])
end

check_hascommoninds(siteinds, A, psi0)

# In case A and psi0 have the same link indices
A = sim(linkinds, A)

# Fix site and link inds of init_state
init_state = deepcopy(init_state)
init_state = sim(linkinds, init_state)
for v in vertices(psi0)
replaceinds!(
init_state[v], siteinds(init_state, v), uniqueinds(siteinds(A, v), siteinds(psi0, v))
)
end

t = Inf
reverse_step = false
PH = proj_operator_apply(psi0, A)
psi = tdvp(
contract_solver(; kwargs...), PH, t, init_state; nsweeps, reverse_step, kwargs...
)

return psi
end

# extra ITensors overloads for tree tensor networks
function ITensors.contract(A::TTNO, ψ::TTNS; alg="fit", kwargs...)
return contract(ITensors.Algorithm(alg), A, ψ; kwargs...)
end

function ITensors.apply(A::TTNO, ψ::TTNS; kwargs...)
Aψ = contract(A, ψ; kwargs...)
return replaceprime(Aψ, 1 => 0)
end
4 changes: 2 additions & 2 deletions src/treetensornetworks/solvers/dmrg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ function eigsolve_solver(; kwargs...)
return solver
end

function dmrg(H, psi0::MPS; kwargs...)
function dmrg(H, psi0::IsTreeState; kwargs...)
t = Inf # DMRG is TDVP with an infinite timestep and no reverse step
reverse_step = false
psi = tdvp(eigsolve_solver(; kwargs...), H, t, psi0; reverse_step, kwargs...)
return psi
end

# Alias for DMRG
function eigsolve(H, psi0::MPS; kwargs...)
function eigsolve(H, psi0::IsTreeState; kwargs...)
return dmrg(H, psi0; kwargs...)
end
2 changes: 1 addition & 1 deletion src/treetensornetworks/solvers/dmrg_x.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ function dmrg_x_solver(PH, t, psi0; kwargs...)
return U_max, nothing
end

function dmrg_x(PH, psi0::MPS; reverse_step=false, kwargs...)
function dmrg_x(PH, psi0::IsTreeState; reverse_step=false, kwargs...)
t = Inf
psi = tdvp(dmrg_x_solver, PH, t, psi0; reverse_step, kwargs...)
return psi
Expand Down
2 changes: 2 additions & 0 deletions src/treetensornetworks/solvers/projmpo_mps2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,5 @@ end
contract(P::ProjMPO_MPS2, v::ITensor) = contract(P.PH, v)

proj_mps(P::ProjMPO_MPS2) = [proj_mps(m) for m in P.Ms]

underlying_graph(P::ProjMPO_MPS2) = chain_lattice_graph(length(P.PH.H)) # tree patch
2 changes: 1 addition & 1 deletion src/treetensornetworks/solvers/solver_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ struct TimeDependentSum{S,T}
f::Vector{S}
H0::T
end
TimeDependentSum(f::Vector, H0::ProjMPOSum) = TimeDependentSum(f, H0.pm)
TimeDependentSum(f::Vector, H0::IsTreeProjOperatorSum) = TimeDependentSum(f, H0.pm)
Base.length(H::TimeDependentSum) = length(H.f)

function Base.:*(c::Number, H::TimeDependentSum)
Expand Down
6 changes: 3 additions & 3 deletions src/treetensornetworks/solvers/tdvp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@ function tdvp_solver(; kwargs...)
end
end

function tdvp(H, t::Number, psi0::MPS; kwargs...)
function tdvp(H, t::Number, psi0::IsTreeState; kwargs...)
return tdvp(tdvp_solver(; kwargs...), H, t, psi0; kwargs...)
end

function tdvp(t::Number, H, psi0::MPS; kwargs...)
function tdvp(t::Number, H, psi0::IsTreeState; kwargs...)
return tdvp(H, t, psi0; kwargs...)
end

function tdvp(H, psi0::MPS, t::Number; kwargs...)
function tdvp(H, psi0::IsTreeState, t::Number; kwargs...)
return tdvp(H, t, psi0; kwargs...)
end
24 changes: 13 additions & 11 deletions src/treetensornetworks/solvers/tdvp_generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ function process_sweeps(; kwargs...)
return (; maxdim, mindim, cutoff, noise)
end

function tdvp(solver, PH, t::Number, psi0::MPS; kwargs...)
function tdvp(solver, PH, t::Number, psi0::IsTreeState; kwargs...)
reverse_step = get(kwargs, :reverse_step, true)

nsweeps = _tdvp_compute_nsweeps(t; kwargs...)
Expand Down Expand Up @@ -124,37 +124,37 @@ function tdvp(solver, PH, t::Number, psi0::MPS; kwargs...)
end

"""
tdvp(H::MPO,psi0::MPS,t::Number; kwargs...)
tdvp(H::MPO,psi0::MPS,t::Number; kwargs...)
tdvp(H::MPS,psi0::MPO,t::Number; kwargs...)
tdvp(H::TTNS,psi0::TTNO,t::Number; kwargs...)

Use the time dependent variational principle (TDVP) algorithm
to compute `exp(t*H)*psi0` using an efficient algorithm based
on alternating optimization of the MPS tensors and local Krylov
on alternating optimization of the state tensors and local Krylov
exponentiation of H.

Returns:
* `psi::MPS` - time-evolved MPS
* `psi` - time-evolved state

Optional keyword arguments:
* `outputlevel::Int = 1` - larger outputlevel values resulting in printing more information and 0 means no output
* `observer` - object implementing the [Observer](@ref observer) interface which can perform measurements and stop early
* `write_when_maxdim_exceeds::Int` - when the allowed maxdim exceeds this value, begin saving tensors to disk to free memory in large calculations
"""
function tdvp(solver, H::MPO, t::Number, psi0::MPS; kwargs...)
function tdvp(solver, H::IsTreeOperator, t::Number, psi0::IsTreeState; kwargs...)
check_hascommoninds(siteinds, H, psi0)
check_hascommoninds(siteinds, H, psi0')
# Permute the indices to have a better memory layout
# and minimize permutations
H = ITensors.permute(H, (linkind, siteinds, linkind))
PH = ProjMPO(H)
PH = proj_operator(H)
return tdvp(solver, PH, t, psi0; kwargs...)
end

function tdvp(solver, t::Number, H, psi0::MPS; kwargs...)
function tdvp(solver, t::Number, H, psi0::IsTreeState; kwargs...)
return tdvp(solver, H, t, psi0; kwargs...)
end

function tdvp(solver, H, psi0::MPS, t::Number; kwargs...)
function tdvp(solver, H, psi0::IsTreeState, t::Number; kwargs...)
return tdvp(solver, H, t, psi0; kwargs...)
end

Expand All @@ -177,12 +177,14 @@ each step of the algorithm when optimizing the MPS.
Returns:
* `psi::MPS` - time-evolved MPS
"""
function tdvp(solver, Hs::Vector{MPO}, t::Number, psi0::MPS; kwargs...)
function tdvp(
solver, Hs::Vector{<:IsTreeOperator}, t::Number, psi0::IsTreeState; kwargs...
)
for H in Hs
check_hascommoninds(siteinds, H, psi0)
check_hascommoninds(siteinds, H, psi0')
end
Hs .= ITensors.permute.(Hs, Ref((linkind, siteinds, linkind)))
PHs = ProjMPOSum(Hs)
PHs = proj_operator_sum(Hs)
return tdvp(solver, PHs, t, psi0; kwargs...)
end
Loading