Skip to content

Remove t argument from alternating_update and improve keyword argument handling #67

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Mar 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
65ede25
Remove t argument from alternating_update and improve keyword argumen…
emstoudenmire Feb 9, 2023
5dff411
Move _compute_nsweeps out of alternating_update
emstoudenmire Feb 15, 2023
0a77f86
Improve process_sweeps
emstoudenmire Feb 15, 2023
b31e7e9
Move TDVPOrder object construction to tdvp.jl
emstoudenmire Feb 16, 2023
c509e74
Inline _get_sweep_generator
emstoudenmire Feb 16, 2023
a5bde68
Rename _extract_tensor to make_local_tensor
emstoudenmire Mar 6, 2023
9a3fc4f
Partial work of extracting tdvp specific code
emstoudenmire Mar 7, 2023
1a2bd65
Remove t argument from solvers
emstoudenmire Mar 7, 2023
954f1ec
Note some possible issues
emstoudenmire Mar 7, 2023
0723a17
Further cleanup of update_step and related
emstoudenmire Mar 7, 2023
a17944e
Simplify contract_solver code
emstoudenmire Mar 7, 2023
28c7749
Remove t arg from custom solvers in tests
emstoudenmire Mar 7, 2023
263a646
Handle kwargs differently
emstoudenmire Mar 7, 2023
a495169
Formatting
emstoudenmire Mar 8, 2023
b3c86fe
Remove time_direction from solver arguments
emstoudenmire Mar 8, 2023
c9a3441
Change TDVP parameter nsweeps to nsteps
emstoudenmire Mar 8, 2023
15aa945
Update time dependent tests
emstoudenmire Mar 10, 2023
c8aefaf
Fix test
emstoudenmire Mar 10, 2023
3669a4a
Remove ITensorNetworks. prefix
emstoudenmire Mar 11, 2023
fbd0f9e
Minor change to restart tests
emstoudenmire Mar 11, 2023
b09c262
Merge branch 'main' into remove_t
emstoudenmire Mar 11, 2023
3960493
Formatting
emstoudenmire Mar 11, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 28 additions & 56 deletions src/treetensornetworks/solvers/alternating_update.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,3 @@
function _compute_nsweeps(t; kwargs...)
time_step::Number = get(kwargs, :time_step, t)
nsweeps::Union{Int,Nothing} = get(kwargs, :nsweeps, nothing)
if isinf(t) && isnothing(nsweeps)
nsweeps = 1
elseif !isnothing(nsweeps) && time_step != t
error("Cannot specify both time_step and nsweeps in alternating_update")
elseif isfinite(time_step) && abs(time_step) > 0.0 && isnothing(nsweeps)
nsweeps = convert(Int, ceil(abs(t / time_step)))
if !(nsweeps * time_step ≈ t)
error("Time step $time_step not commensurate with total time t=$t")
end
end

return nsweeps
end

function _extend_sweeps_param(param, nsweeps)
if param isa Number
Expand All @@ -27,39 +11,37 @@ function _extend_sweeps_param(param, nsweeps)
return eparam
end

function process_sweeps(; kwargs...)
nsweeps = get(kwargs, :nsweeps, 1)
maxdim = get(kwargs, :maxdim, fill(typemax(Int), nsweeps))
mindim = get(kwargs, :mindim, fill(1, nsweeps))
cutoff = get(kwargs, :cutoff, fill(1E-16, nsweeps))
noise = get(kwargs, :noise, fill(0.0, nsweeps))

function process_sweeps(
nsweeps;
cutoff=fill(1E-16, nsweeps),
maxdim=fill(typemax(Int), nsweeps),
mindim=fill(1, nsweeps),
noise=fill(0.0, nsweeps),
kwargs...,
)
maxdim = _extend_sweeps_param(maxdim, nsweeps)
mindim = _extend_sweeps_param(mindim, nsweeps)
cutoff = _extend_sweeps_param(cutoff, nsweeps)
noise = _extend_sweeps_param(noise, nsweeps)

return (; maxdim, mindim, cutoff, noise)
return maxdim, mindim, cutoff, noise
end

function alternating_update(solver, PH, t::Number, psi0::AbstractTTN; kwargs...)
reverse_step = get(kwargs, :reverse_step, true)

nsweeps = _compute_nsweeps(t; kwargs...)
maxdim, mindim, cutoff, noise = process_sweeps(; nsweeps, kwargs...)

time_start::Number = get(kwargs, :time_start, 0.0)
time_step::Number = get(kwargs, :time_step, t)
order = get(kwargs, :order, 2)
tdvp_order = TDVPOrder(order, Base.Forward)
function alternating_update(
solver,
PH,
psi0::AbstractTTN;
checkdone=nothing,
tdvp_order=TDVPOrder(2, Base.Forward),
outputlevel=0,
time_start=0.0,
time_step=0.0,
nsweeps=1,
write_when_maxdim_exceeds::Union{Int,Nothing}=nothing,
kwargs...,
)
maxdim, mindim, cutoff, noise = process_sweeps(nsweeps; kwargs...)

checkdone = get(kwargs, :checkdone, nothing)
write_when_maxdim_exceeds::Union{Int,Nothing} = get(
kwargs, :write_when_maxdim_exceeds, nothing
)
observer = get(kwargs, :observer!, nothing)
step_observer = get(kwargs, :step_observer!, nothing)
outputlevel::Int = get(kwargs, :outputlevel, 0)

psi = copy(psi0)

Expand Down Expand Up @@ -89,7 +71,7 @@ function alternating_update(solver, PH, t::Number, psi0::AbstractTTN; kwargs...)
psi;
kwargs...,
current_time,
reverse_step,
outputlevel,
sweep=sw,
maxdim=maxdim[sw],
mindim=mindim[sw],
Expand Down Expand Up @@ -121,22 +103,14 @@ function alternating_update(solver, PH, t::Number, psi0::AbstractTTN; kwargs...)
return psi
end

function alternating_update(solver, H::AbstractTTN, t::Number, psi0::AbstractTTN; kwargs...)
function alternating_update(solver, H::AbstractTTN, psi0::AbstractTTN; kwargs...)
check_hascommoninds(siteinds, H, psi0)
check_hascommoninds(siteinds, H, psi0')
# Permute the indices to have a better memory layout
# and minimize permutations
H = ITensors.permute(H, (linkind, siteinds, linkind))
PH = ProjTTN(H)
return alternating_update(solver, PH, t, psi0; kwargs...)
end

function alternating_update(solver, t::Number, H, psi0::AbstractTTN; kwargs...)
return alternating_update(solver, H, t, psi0; kwargs...)
end

function alternating_update(solver, H, psi0::AbstractTTN, t::Number; kwargs...)
return alternating_update(solver, H, t, psi0; kwargs...)
return alternating_update(solver, PH, psi0; kwargs...)
end

"""
Expand All @@ -158,14 +132,12 @@ each step of the algorithm when optimizing the MPS.
Returns:
* `psi::MPS` - time-evolved MPS
"""
function alternating_update(
solver, Hs::Vector{<:AbstractTTN}, t::Number, psi0::AbstractTTN; kwargs...
)
function alternating_update(solver, Hs::Vector{<:AbstractTTN}, psi0::AbstractTTN; kwargs...)
for H in Hs
check_hascommoninds(siteinds, H, psi0)
check_hascommoninds(siteinds, H, psi0')
end
Hs .= ITensors.permute.(Hs, Ref((linkind, siteinds, linkind)))
PHs = ProjTTNSum(Hs)
return alternating_update(solver, PHs, t, psi0; kwargs...)
return alternating_update(solver, PHs, psi0; kwargs...)
end
20 changes: 7 additions & 13 deletions src/treetensornetworks/solvers/contract.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
function contract_solver(; kwargs...)
function solver(PH, t, psi; kws...)
v = ITensor(1.0)
for j in sites(PH)
v *= PH.psi0[j]
end
Hpsi0 = contract(PH, v)
return Hpsi0, nothing
function contract_solver(PH, psi; kwargs...)
v = ITensor(1.0)
for j in sites(PH)
v *= PH.psi0[j]
end
return solver
Hpsi0 = contract(PH, v)
return Hpsi0, nothing
end

function contract(
Expand Down Expand Up @@ -44,12 +41,9 @@ function contract(
## )
## end

t = Inf
reverse_step = false
PH = ProjTTNApply(tn2, tn1)
psi = alternating_update(
contract_solver(; kwargs...), PH, t, init; nsweeps, reverse_step, kwargs...
)
psi = alternating_update(contract_solver, PH, init; nsweeps, reverse_step, kwargs...)

return psi
end
Expand Down
11 changes: 4 additions & 7 deletions src/treetensornetworks/solvers/dmrg.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
function eigsolve_solver(; kwargs...)
function solver(H, t, init; kws...)
function eigsolve_solver(; solver_which_eigenvalue=:SR, kwargs...)
function solver(H, init; kws...)
howmany = 1
which = get(kwargs, :solver_which_eigenvalue, :SR)
which = solver_which_eigenvalue
solver_kwargs = (;
ishermitian=get(kwargs, :ishermitian, true),
tol=get(kwargs, :solver_tol, 1E-14),
Expand All @@ -20,11 +20,8 @@ end
Overload of `ITensors.dmrg`.
"""
function dmrg(H, init::AbstractTTN; kwargs...)
t = Inf # DMRG is TDVP with an infinite timestep and no reverse step
reverse_step = false
psi = alternating_update(
eigsolve_solver(; kwargs...), H, t, init; reverse_step, kwargs...
)
psi = alternating_update(eigsolve_solver(; kwargs...), H, init; reverse_step, kwargs...)
return psi
end

Expand Down
5 changes: 2 additions & 3 deletions src/treetensornetworks/solvers/dmrg_x.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function dmrg_x_solver(PH, t, init; kwargs...)
function dmrg_x_solver(PH, init; kwargs...)
H = contract(PH, ITensor(1.0))
D, U = eigen(H; ishermitian=true)
u = uniqueind(U, H)
Expand All @@ -8,7 +8,6 @@ function dmrg_x_solver(PH, t, init; kwargs...)
end

function dmrg_x(PH, init::AbstractTTN; reverse_step=false, kwargs...)
t = Inf
psi = alternating_update(dmrg_x_solver, PH, t, init; reverse_step, kwargs...)
psi = alternating_update(dmrg_x_solver, PH, init; reverse_step, kwargs...)
return psi
end
4 changes: 1 addition & 3 deletions src/treetensornetworks/solvers/linsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ function linsolve(
)
function linsolve_solver(
P,
t,
x₀;
ishermitian=false,
solver_tol=1E-14,
Expand All @@ -50,9 +49,8 @@ function linsolve(

error("`linsolve` for TTN not yet implemented.")

t = Inf
# TODO: Define `itensornetwork_cache`
# TODO: Define `linsolve_cache`
P = linsolve_cache(itensornetwork_cache(x₀', A, x₀), itensornetwork_cache(x₀', b))
return alternating_update(linsolve_solver, P, t, x₀; reverse_step=false, kwargs...)
return alternating_update(linsolve_solver, P, x₀; reverse_step=false, kwargs...)
end
86 changes: 66 additions & 20 deletions src/treetensornetworks/solvers/tdvp.jl
Original file line number Diff line number Diff line change
@@ -1,30 +1,50 @@
function exponentiate_solver(; kwargs...)
function solver(H, t, init; kws...)
function solver(
H,
init;
ishermitian=true,
issymmetric=true,
solver_krylovdim=30,
solver_maxiter=100,
solver_outputlevel=0,
solver_tol=1E-12,
substep,
time_step,
kws...,
)
solver_kwargs = (;
ishermitian=get(kwargs, :ishermitian, true),
issymmetric=get(kwargs, :issymmetric, true),
tol=get(kwargs, :solver_tol, 1E-12),
krylovdim=get(kwargs, :solver_krylovdim, 30),
maxiter=get(kwargs, :solver_maxiter, 100),
verbosity=get(kwargs, :solver_outputlevel, 0),
ishermitian,
issymmetric,
tol=solver_tol,
krylovdim=solver_krylovdim,
maxiter=solver_maxiter,
verbosity=solver_outputlevel,
eager=true,
)
psi, info = exponentiate(H, t, init; solver_kwargs...)

psi, info = KrylovKit.exponentiate(H, time_step, init; solver_kwargs...)
return psi, info
end
return solver
end

function applyexp_solver(; kwargs...)
function solver(H, t, init; kws...)
tol_per_unit_time = get(kwargs, :solver_tol, 1E-8)
solver_kwargs = (;
maxiter=get(kwargs, :solver_krylovdim, 30),
outputlevel=get(kwargs, :solver_outputlevel, 0),
)
function solver(
H,
init;
tdvp_order,
solver_krylovdim=30,
solver_outputlevel=0,
solver_tol=1E-8,
substep,
time_step,
kws...,
)
solver_kwargs = (; maxiter=solver_krylovdim, outputlevel=solver_outputlevel)

#applyexp tol is absolute, compute from tol_per_unit_time:
tol = abs(t) * tol_per_unit_time
psi, info = applyexp(H, t, init; tol, solver_kwargs..., kws...)
tol = abs(time_step) * tol_per_unit_time
psi, info = applyexp(H, time_step, init; tol, solver_kwargs..., kws...)
return psi, info
end
return solver
Expand All @@ -42,22 +62,48 @@ function tdvp_solver(; solver_backend="exponentiate", kwargs...)
end
end

function tdvp(solver, H, t::Number, init::AbstractTTN; kwargs...)
return alternating_update(solver, H, t, init; kwargs...)
function _compute_nsweeps(nsteps, t, time_step)
nsweeps = 1
if !isnothing(nsteps) && time_step != t
error("Cannot specify both nsteps and time_step in tdvp")
elseif isfinite(time_step) && abs(time_step) > 0.0 && isnothing(nsteps)
nsweeps = convert(Int, ceil(abs(t / time_step)))
if !(nsweeps * time_step ≈ t)
error("Time step $time_step not commensurate with total time t=$t")
end
end
return nsweeps
end

function tdvp(
solver,
H,
t::Number,
init::AbstractTTN;
time_step::Number=t,
nsteps=nothing,
order::Integer=2,
kwargs...,
)
nsweeps = _compute_nsweeps(nsteps, t, time_step)
tdvp_order = TDVPOrder(order, Base.Forward)
return alternating_update(solver, H, init; nsweeps, tdvp_order, time_step, kwargs...)
end

"""
tdvp(H::TTN, t::Number, psi0::TTN; kwargs...)

Use the time dependent variational principle (TDVP) algorithm
to compute `exp(H*t)*psi0` using an efficient algorithm based
to approximately compute `exp(H*t)*psi0` using an efficient algorithm based
on alternating optimization of the state tensors and local Krylov
exponentiation of H.
exponentiation of H. The time parameter `t` can be a real or complex number.

Returns:
* `psi` - time-evolved state

Optional keyword arguments:
* `time_step::Number = t` - time step to use when evolving the state. Smaller time steps generally give more accurate results but can make the algorithm take more computational time to run.
* `nsteps::Integer` - evolve by the requested total time `t` by performing `nsteps` of the TDVP algorithm. More steps can result in more accurate results but require more computational time to run. (Note that only one of the `time_step` or `nsteps` parameters can be provided, not both.)
* `outputlevel::Int = 1` - larger outputlevel values resulting in printing more information and 0 means no output
* `observer` - object implementing the [Observer](@ref observer) interface which can perform measurements and stop early
* `write_when_maxdim_exceeds::Int` - when the allowed maxdim exceeds this value, begin saving tensors to disk to free memory in large calculations
Expand Down
13 changes: 9 additions & 4 deletions src/treetensornetworks/solvers/tdvporder.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@

struct TDVPOrder{order,direction} end

TDVPOrder(order::Int, direction::Base.Ordering) = TDVPOrder{order,direction}()

orderings(::TDVPOrder) = error("Not implemented")
directions(::TDVPOrder) = error("Not implemented")
sub_time_steps(::TDVPOrder) = error("Not implemented")

function orderings(::TDVPOrder{1,direction}) where {direction}
function directions(::TDVPOrder{1,direction}) where {direction}
return [direction, Base.ReverseOrdering(direction)]
end
sub_time_steps(::TDVPOrder{1}) = [1.0, 0.0]

function orderings(::TDVPOrder{2,direction}) where {direction}
function directions(::TDVPOrder{2,direction}) where {direction}
return [direction, Base.ReverseOrdering(direction)]
end
sub_time_steps(::TDVPOrder{2}) = [1.0 / 2.0, 1.0 / 2.0]

function orderings(::TDVPOrder{4,direction}) where {direction}
#
# TODO: possible bug, shouldn't length(directions) here equal
# length(sub_time_steps) below? (I.e. both return a length 6 vector?)
#
function directions(::TDVPOrder{4,direction}) where {direction}
return [direction, Base.ReverseOrdering(direction)]
end
function sub_time_steps(::TDVPOrder{4})
Expand Down
Loading