Skip to content

Rename tdvp Functions #59

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ include(joinpath("treetensornetworks", "solvers", "solver_utils.jl"))
include(joinpath("treetensornetworks", "solvers", "applyexp.jl"))
include(joinpath("treetensornetworks", "solvers", "tdvporder.jl"))
include(joinpath("treetensornetworks", "solvers", "tdvpinfo.jl"))
include(joinpath("treetensornetworks", "solvers", "tdvp_step.jl"))
include(joinpath("treetensornetworks", "solvers", "tdvp_generic.jl"))
include(joinpath("treetensornetworks", "solvers", "update_step.jl"))
include(joinpath("treetensornetworks", "solvers", "alternating_update.jl"))
include(joinpath("treetensornetworks", "solvers", "tdvp.jl"))
include(joinpath("treetensornetworks", "solvers", "dmrg.jl"))
include(joinpath("treetensornetworks", "solvers", "dmrg_x.jl"))
Expand Down
2 changes: 2 additions & 0 deletions src/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ export AbstractITensorNetwork,
group_terms,
# tebd.jl
tebd,
# treetensornetwork/opsum_to_ttn.jl
mpo,
# treetensornetwork/solvers.jl
TimeDependentSum,
dmrg_x,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
function _tdvp_compute_nsweeps(t; kwargs...)
function _compute_nsweeps(t; kwargs...)
time_step::Number = get(kwargs, :time_step, t)
nsweeps::Union{Int,Nothing} = get(kwargs, :nsweeps, nothing)
if isinf(t) && isnothing(nsweeps)
nsweeps = 1
elseif !isnothing(nsweeps) && time_step != t
error("Cannot specify both time_step and nsweeps in tdvp")
error("Cannot specify both time_step and nsweeps in alternating_update")
elseif isfinite(time_step) && abs(time_step) > 0.0 && isnothing(nsweeps)
nsweeps = convert(Int, ceil(abs(t / time_step)))
if !(nsweeps * time_step ≈ t)
Expand Down Expand Up @@ -42,10 +42,10 @@ function process_sweeps(; kwargs...)
return (; maxdim, mindim, cutoff, noise)
end

function tdvp(solver, PH, t::Number, psi0::AbstractTTN; kwargs...)
function alternating_update(solver, PH, t::Number, psi0::AbstractTTN; kwargs...)
reverse_step = get(kwargs, :reverse_step, true)

nsweeps = _tdvp_compute_nsweeps(t; kwargs...)
nsweeps = _compute_nsweeps(t; kwargs...)
maxdim, mindim, cutoff, noise = process_sweeps(; nsweeps, kwargs...)

time_start::Number = get(kwargs, :time_start, 0.0)
Expand Down Expand Up @@ -81,7 +81,7 @@ function tdvp(solver, PH, t::Number, psi0::AbstractTTN; kwargs...)
end

sw_time = @elapsed begin
psi, PH, info = tdvp_step(
psi, PH, info = update_step(
tdvp_order,
solver,
PH,
Expand Down Expand Up @@ -121,39 +121,22 @@ function tdvp(solver, PH, t::Number, psi0::AbstractTTN; kwargs...)
return psi
end

"""
tdvp(H::MPS,psi0::MPO,t::Number; kwargs...)
tdvp(H::TTN,psi0::TTN,t::Number; kwargs...)

Use the time dependent variational principle (TDVP) algorithm
to compute `exp(t*H)*psi0` using an efficient algorithm based
on alternating optimization of the state tensors and local Krylov
exponentiation of H.

Returns:
* `psi` - time-evolved state

Optional keyword arguments:
* `outputlevel::Int = 1` - larger outputlevel values resulting in printing more information and 0 means no output
* `observer` - object implementing the [Observer](@ref observer) interface which can perform measurements and stop early
* `write_when_maxdim_exceeds::Int` - when the allowed maxdim exceeds this value, begin saving tensors to disk to free memory in large calculations
"""
function tdvp(solver, H::AbstractTTN, t::Number, psi0::AbstractTTN; kwargs...)
function alternating_update(solver, H::AbstractTTN, t::Number, psi0::AbstractTTN; kwargs...)
check_hascommoninds(siteinds, H, psi0)
check_hascommoninds(siteinds, H, psi0')
# Permute the indices to have a better memory layout
# and minimize permutations
H = ITensors.permute(H, (linkind, siteinds, linkind))
PH = ProjTTN(H)
return tdvp(solver, PH, t, psi0; kwargs...)
return alternating_update(solver, PH, t, psi0; kwargs...)
end

function tdvp(solver, t::Number, H, psi0::AbstractTTN; kwargs...)
return tdvp(solver, H, t, psi0; kwargs...)
function alternating_update(solver, t::Number, H, psi0::AbstractTTN; kwargs...)
return alternating_update(solver, H, t, psi0; kwargs...)
end

function tdvp(solver, H, psi0::AbstractTTN, t::Number; kwargs...)
return tdvp(solver, H, t, psi0; kwargs...)
function alternating_update(solver, H, psi0::AbstractTTN, t::Number; kwargs...)
return alternating_update(solver, H, t, psi0; kwargs...)
end

"""
Expand All @@ -175,12 +158,14 @@ each step of the algorithm when optimizing the MPS.
Returns:
* `psi::MPS` - time-evolved MPS
"""
function tdvp(solver, Hs::Vector{<:AbstractTTN}, t::Number, psi0::AbstractTTN; kwargs...)
function alternating_update(
solver, Hs::Vector{<:AbstractTTN}, t::Number, psi0::AbstractTTN; kwargs...
)
for H in Hs
check_hascommoninds(siteinds, H, psi0)
check_hascommoninds(siteinds, H, psi0')
end
Hs .= ITensors.permute.(Hs, Ref((linkind, siteinds, linkind)))
PHs = ProjTTNSum(Hs)
return tdvp(solver, PHs, t, psi0; kwargs...)
return alternating_update(solver, PHs, t, psi0; kwargs...)
end
4 changes: 3 additions & 1 deletion src/treetensornetworks/solvers/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ function contract(
t = Inf
reverse_step = false
PH = ProjTTNApply(tn2, tn1)
psi = tdvp(contract_solver(; kwargs...), PH, t, init; nsweeps, reverse_step, kwargs...)
psi = alternating_update(
contract_solver(; kwargs...), PH, t, init; nsweeps, reverse_step, kwargs...
)

return psi
end
Expand Down
4 changes: 3 additions & 1 deletion src/treetensornetworks/solvers/dmrg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ Overload of `ITensors.dmrg`.
function dmrg(H, init::AbstractTTN; kwargs...)
t = Inf # DMRG is TDVP with an infinite timestep and no reverse step
reverse_step = false
psi = tdvp(eigsolve_solver(; kwargs...), H, t, init; reverse_step, kwargs...)
psi = alternating_update(
eigsolve_solver(; kwargs...), H, t, init; reverse_step, kwargs...
)
return psi
end

Expand Down
2 changes: 1 addition & 1 deletion src/treetensornetworks/solvers/dmrg_x.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ end

function dmrg_x(PH, init::AbstractTTN; reverse_step=false, kwargs...)
t = Inf
psi = tdvp(dmrg_x_solver, PH, t, init; reverse_step, kwargs...)
psi = alternating_update(dmrg_x_solver, PH, t, init; reverse_step, kwargs...)
return psi
end
2 changes: 1 addition & 1 deletion src/treetensornetworks/solvers/linsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,5 +54,5 @@ function linsolve(
# TODO: Define `itensornetwork_cache`
# TODO: Define `linsolve_cache`
P = linsolve_cache(itensornetwork_cache(x₀', A, x₀), itensornetwork_cache(x₀', b))
return tdvp(linsolve_solver, P, t, x₀; reverse_step=false, kwargs...)
return alternating_update(linsolve_solver, P, t, x₀; reverse_step=false, kwargs...)
end
35 changes: 23 additions & 12 deletions src/treetensornetworks/solvers/tdvp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,27 +30,38 @@ function applyexp_solver(; kwargs...)
return solver
end

function tdvp_solver(; kwargs...)
solver_backend = get(kwargs, :solver_backend, "applyexp")
if solver_backend == "applyexp"
return applyexp_solver(; kwargs...)
elseif solver_backend == "exponentiate"
function tdvp_solver(; solver_backend="exponentiate", kwargs...)
if solver_backend == "exponentiate"
return exponentiate_solver(; kwargs...)
elseif solver_backend == "applyexp"
return applyexp_solver(; kwargs...)
else
error(
"solver_backend=$solver_backend not recognized (options are \"applyexp\" or \"exponentiate\")",
)
end
end

function tdvp(H, t::Number, init::AbstractTTN; kwargs...)
return tdvp(tdvp_solver(; kwargs...), H, t, init; kwargs...)
function tdvp(solver, H, t::Number, init::AbstractTTN; kwargs...)
return alternating_update(solver, H, t, init; kwargs...)
end

function tdvp(t::Number, H, init::AbstractTTN; kwargs...)
return tdvp(H, t, init; kwargs...)
end
"""
tdvp(H::TTN, t::Number, psi0::TTN; kwargs...)

Use the time dependent variational principle (TDVP) algorithm
to compute `exp(H*t)*psi0` using an efficient algorithm based
on alternating optimization of the state tensors and local Krylov
exponentiation of H.

Returns:
* `psi` - time-evolved state

function tdvp(H, init::AbstractTTN, t::Number; kwargs...)
return tdvp(H, t, init; kwargs...)
Optional keyword arguments:
* `outputlevel::Int = 1` - larger outputlevel values resulting in printing more information and 0 means no output
* `observer` - object implementing the [Observer](@ref observer) interface which can perform measurements and stop early
* `write_when_maxdim_exceeds::Int` - when the allowed maxdim exceeds this value, begin saving tensors to disk to free memory in large calculations
"""
function tdvp(H, t::Number, init::AbstractTTN; kwargs...)
return tdvp(tdvp_solver(; kwargs...), H, t, init; kwargs...)
end
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function tdvp_step(
function update_step(
order::TDVPOrder,
solver,
PH,
Expand All @@ -12,7 +12,7 @@ function tdvp_step(
sub_time_steps *= time_step
info = nothing
for substep in 1:length(sub_time_steps)
psi, PH, info = tdvp_sweep(
psi, PH, info = update_sweep(
orderings[substep], solver, PH, sub_time_steps[substep], psi; current_time, kwargs...
)
current_time += sub_time_steps[substep]
Expand All @@ -34,14 +34,14 @@ function _get_sweep_generator(kwargs)
return error("Unsupported value $nsite for nsite keyword argument.")
end

function tdvp_sweep(
function update_sweep(
direction::Base.Ordering, solver, PH, time_step::Number, psi::AbstractTTN; kwargs...
)
PH = copy(PH)
psi = copy(psi)
if nv(psi) == 1
error(
"`tdvp` currently does not support system sizes of 1. You can diagonalize the MPO tensor directly with tools like `LinearAlgebra.eigen`, `KrylovKit.exponentiate`, etc.",
"`alternating_update` currently does not support system sizes of 1. You can diagonalize the MPO tensor directly with tools like `LinearAlgebra.eigen`, `KrylovKit.exponentiate`, etc.",
)
end
sweep_generator = _get_sweep_generator(kwargs)
Expand All @@ -64,7 +64,7 @@ function tdvp_sweep(
for sweep_step in sweep_generator(
direction, underlying_graph(PH), root_vertex, reverse_step; state=psi, kwargs...
)
psi, PH, current_time, maxtruncerr, spec, info = tdvp_local_update(
psi, PH, current_time, maxtruncerr, spec, info = local_update(
solver,
PH,
psi,
Expand Down Expand Up @@ -160,7 +160,7 @@ function _insert_tensor(psi::AbstractTTN, phi::ITensor, e::NamedEdge; kwargs...)
return psi, nothing
end

function tdvp_local_update(
function local_update(
solver,
PH,
psi,
Expand Down
19 changes: 4 additions & 15 deletions test/test_treetensornetworks/test_solvers/test_tdvp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,9 @@ using Test
# Time evolve forward:
ψ1 = tdvp(H, -0.1im, ψ0; nsweeps=1, cutoff, nsite=1)

@test ψ1 ≈ tdvp(-0.1im, H, ψ0; nsweeps=1, cutoff, nsite=1)
@test ψ1 ≈ tdvp(H, ψ0, -0.1im; nsweeps=1, cutoff, nsite=1)
#Different backend solvers, default solver_backend = "applyexp"
ψ1_exponentiate_backend = tdvp(
H, ψ0, -0.1im; nsweeps=1, cutoff, nsite=1, solver_backend="exponentiate"
H, -0.1im, ψ0; nsweeps=1, cutoff, nsite=1, solver_backend="exponentiate"
)
@test ψ1 ≈ ψ1_exponentiate_backend rtol = 1e-7

Expand Down Expand Up @@ -81,12 +79,9 @@ using Test

ψ1 = tdvp(Hs, -0.1im, ψ0; nsweeps=1, cutoff, nsite=1)

@test ψ1 ≈ tdvp(-0.1im, Hs, ψ0; nsweeps=1, cutoff, nsite=1)
@test ψ1 ≈ tdvp(Hs, ψ0, -0.1im; nsweeps=1, cutoff, nsite=1)

#Different backend solvers, default solver_backend = "applyexp"
ψ1_exponentiate_backend = tdvp(
Hs, ψ0, -0.1im; nsweeps=1, cutoff, nsite=1, solver_backend="exponentiate"
Hs, -0.1im, ψ0; nsweeps=1, cutoff, nsite=1, solver_backend="exponentiate"
)
@test ψ1 ≈ ψ1_exponentiate_backend rtol = 1e-7

Expand Down Expand Up @@ -477,9 +472,6 @@ end
# Time evolve forward:
ψ1 = tdvp(H, -0.1im, ψ0; nsweeps=1, cutoff, nsite=1)

@test ψ1 ≈ tdvp(-0.1im, H, ψ0; nsweeps=1, cutoff, nsite=1)
@test ψ1 ≈ tdvp(H, ψ0, -0.1im; nsweeps=1, cutoff, nsite=1)

@test norm(ψ1) ≈ 1.0

## Should lose fidelity:
Expand Down Expand Up @@ -522,9 +514,6 @@ end

ψ1 = tdvp(Hs, -0.1im, ψ0; nsweeps=1, cutoff, nsite=1)

@test ψ1 ≈ tdvp(-0.1im, Hs, ψ0; nsweeps=1, cutoff, nsite=1)
@test ψ1 ≈ tdvp(Hs, ψ0, -0.1im; nsweeps=1, cutoff, nsite=1)

@test norm(ψ1) ≈ 1.0

## Should lose fidelity:
Expand Down Expand Up @@ -566,8 +555,8 @@ end

ψ1 = tdvp(solver, H, -0.1im, ψ0; cutoff, nsite=1)

@test ψ1 ≈ tdvp(solver, -0.1im, H, ψ0; cutoff, nsite=1)
@test ψ1 ≈ tdvp(solver, H, ψ0, -0.1im; cutoff, nsite=1)
#@test ψ1 ≈ tdvp(solver, -0.1im, H, ψ0; cutoff, nsite=1)
#@test ψ1 ≈ tdvp(solver, H, ψ0, -0.1im; cutoff, nsite=1)

@test norm(ψ1) ≈ 1.0

Expand Down