Skip to content

Commit

Permalink
Use Observer for output of alternating_update (#93)
Browse files Browse the repository at this point in the history
  • Loading branch information
emstoudenmire authored Jul 11, 2023
1 parent 0666c27 commit 1476e5a
Show file tree
Hide file tree
Showing 12 changed files with 268 additions and 173 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ IsApprox = "0.1.7"
IterTools = "1.4.0"
KrylovKit = "0.6.0"
NamedGraphs = "0.1.11"
Observers = "0.0.8"
Observers = "0.2"
Requires = "1.3"
SimpleTraits = "0.9"
SparseArrayKit = "0.2.1"
Expand Down
1 change: 1 addition & 0 deletions src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ using IterTools
using KrylovKit: KrylovKit
using NamedGraphs
using Observers
using Observers.DataFrames: select!
using Printf
using Requires
using SimpleTraits
Expand Down
60 changes: 29 additions & 31 deletions src/treetensornetworks/solvers/alternating_update.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,66 +26,64 @@ function process_sweeps(
return maxdim, mindim, cutoff, noise, kwargs
end

function sweep_printer(; outputlevel, psi, sweep, sw_time)
if outputlevel >= 1
print("After sweep ", sweep, ":")
print(" maxlinkdim=", maxlinkdim(psi))
print(" cpu_time=", round(sw_time; digits=3))
println()
flush(stdout)
end
end

function alternating_update(
solver,
PH,
psi0::AbstractTTN;
checkdone=nothing,
outputlevel=0,
nsweeps=1,
checkdone=(; kws...) -> false,
outputlevel::Integer=0,
nsweeps::Integer=1,
(sweep_observer!)=observer(),
sweep_printer=sweep_printer,
write_when_maxdim_exceeds::Union{Int,Nothing}=nothing,
kwargs...,
)
maxdim, mindim, cutoff, noise, kwargs = process_sweeps(nsweeps; kwargs...)

step_observer = get(kwargs, :step_observer!, nothing)

psi = copy(psi0)

info = nothing
for sw in 1:nsweeps
if !isnothing(write_when_maxdim_exceeds) && maxdim[sw] > write_when_maxdim_exceeds
insert_function!(sweep_observer!, "sweep_printer" => sweep_printer)

for sweep in 1:nsweeps
if !isnothing(write_when_maxdim_exceeds) && maxdim[sweep] > write_when_maxdim_exceeds
if outputlevel >= 2
println(
"write_when_maxdim_exceeds = $write_when_maxdim_exceeds and maxdim(sweeps, sw) = $(maxdim(sweeps, sw)), writing environment tensors to disk",
"write_when_maxdim_exceeds = $write_when_maxdim_exceeds and maxdim[sweep] = $(maxdim[sweep]), writing environment tensors to disk",
)
end
PH = disk(PH)
end

sw_time = @elapsed begin
psi, PH, info = update_step(
psi, PH = update_step(
solver,
PH,
psi;
outputlevel,
sweep=sw,
maxdim=maxdim[sw],
mindim=mindim[sw],
cutoff=cutoff[sw],
noise=noise[sw],
sweep,
maxdim=maxdim[sweep],
mindim=mindim[sweep],
cutoff=cutoff[sweep],
noise=noise[sweep],
kwargs...,
)
end

update!(step_observer; psi, sweep=sw, outputlevel)

if outputlevel >= 1
print("After sweep ", sw, ":")
print(" maxlinkdim=", maxlinkdim(psi))
@printf(" maxerr=%.2E", info.maxtruncerr)
#print(" current_time=", round(current_time; digits=3))
print(" time=", round(sw_time; digits=3))
println()
flush(stdout)
end
update!(sweep_observer!; psi, sweep, sw_time, outputlevel)

isdone = false
if !isnothing(checkdone)
isdone = checkdone(; psi, sweep=sw, outputlevel, kwargs...)
end
isdone && break
checkdone(; psi, sweep, outputlevel, kwargs...) && break
end
select!(sweep_observer!, Observers.DataFrames.Not("sweep_printer")) # remove sweep_printer
return psi
end

Expand Down
2 changes: 1 addition & 1 deletion src/treetensornetworks/solvers/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ function contract_solver(PH, psi; kwargs...)
v *= PH.psi0[j]
end
Hpsi0 = contract(PH, v)
return Hpsi0, nothing
return Hpsi0, NamedTuple()
end

function contract(
Expand Down
2 changes: 1 addition & 1 deletion src/treetensornetworks/solvers/dmrg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ function eigsolve_solver(; solver_which_eigenvalue=:SR, kwargs...)
)
vals, vecs, info = eigsolve(H, init, howmany, which; solver_kwargs...)
psi = vecs[1]
return psi, info
return psi, (; solver_info=info, energies=vals)
end
return solver
end
Expand Down
3 changes: 2 additions & 1 deletion src/treetensornetworks/solvers/dmrg_x.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ function dmrg_x_solver(PH, init; kwargs...)
u = uniqueind(U, H)
max_overlap, max_ind = findmax(abs, array(dag(init) * U))
U_max = U * dag(onehot(u => max_ind))
return U_max, nothing
# TODO: improve this to return the energy estimate too
return U_max, NamedTuple()
end

function dmrg_x(PH, init::AbstractTTN; kwargs...)
Expand Down
2 changes: 1 addition & 1 deletion src/treetensornetworks/solvers/linsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ function linsolve(
)
b = dag(only(proj_mps(P)))
x, info = KrylovKit.linsolve(P, b, x₀, a₀, a₁; solver_kwargs...)
return x, nothing
return x, NamedTuple()
end

error("`linsolve` for TTN not yet implemented.")
Expand Down
71 changes: 48 additions & 23 deletions src/treetensornetworks/solvers/tdvp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ function exponentiate_solver(; kwargs...)
eager=true,
)

psi, info = KrylovKit.exponentiate(H, time_step, init; solver_kwargs...)
return psi, info
psi, exp_info = KrylovKit.exponentiate(H, time_step, init; solver_kwargs...)
return psi, (; info=exp_info)
end
return solver
end
Expand All @@ -45,31 +45,25 @@ function applyexp_solver(; kwargs...)

#applyexp tol is absolute, compute from tol_per_unit_time:
tol = abs(time_step) * tol_per_unit_time
psi, info = applyexp(H, time_step, init; tol, solver_kwargs..., kws...)
return psi, info
psi, exp_info = applyexp(H, time_step, init; tol, solver_kwargs..., kws...)
return psi, (; info=exp_info)
end
return solver
end

function tdvp_solver(; solver_backend="exponentiate", kwargs...)
if solver_backend == "exponentiate"
return exponentiate_solver(; kwargs...)
elseif solver_backend == "applyexp"
return applyexp_solver(; kwargs...)
else
error(
"solver_backend=$solver_backend not recognized (options are \"applyexp\" or \"exponentiate\")",
)
end
end

function _compute_nsweeps(nsteps, t, time_step)
function _compute_nsweeps(nsteps, t, time_step, order)
nsweeps_per_step = order / 2
nsweeps = 1
if !isnothing(nsteps) && time_step != t
error("Cannot specify both nsteps and time_step in tdvp")
elseif isfinite(time_step) && abs(time_step) > 0.0 && isnothing(nsteps)
nsweeps = convert(Int, ceil(abs(t / time_step)))
if !(nsweeps * time_step t)
nsweeps = convert(Int, nsweeps_per_step * ceil(abs(t / time_step)))
if !(nsweeps / nsweeps_per_step * time_step t)
println(
"Time that will be reached = nsweeps/nsweeps_per_step * time_step = ",
nsweeps / nsweeps_per_step * time_step,
)
println("Requested total time t = ", t)
error("Time step $time_step not commensurate with total time t=$t")
end
end
Expand Down Expand Up @@ -118,11 +112,33 @@ function tdvp(
nsite=2,
nsteps=nothing,
order::Integer=2,
(sweep_observer!)=observer(),
kwargs...,
)
nsweeps = _compute_nsweeps(nsteps, t, time_step)
nsweeps = _compute_nsweeps(nsteps, t, time_step, order)
sweep_regions = tdvp_sweep(order, nsite, time_step, init; kwargs...)
return alternating_update(solver, H, init; nsweeps, sweep_regions, nsite, kwargs...)

function sweep_time_printer(; outputlevel, sweep, kwargs...)
if outputlevel >= 1
sweeps_per_step = order ÷ 2
if sweep % sweeps_per_step == 0
current_time = (sweep / sweeps_per_step) * time_step
println("Current time (sweep $sweep) = ", round(current_time; digits=3))
end
end
return nothing
end

insert_function!(sweep_observer!, "sweep_time_printer" => sweep_time_printer)

psi = alternating_update(
solver, H, init; nsweeps, sweep_observer!, sweep_regions, nsite, kwargs...
)

# remove sweep_time_printer from sweep_observer!
select!(sweep_observer!, Observers.DataFrames.Not("sweep_time_printer"))

return psi
end

"""
Expand All @@ -143,6 +159,15 @@ Optional keyword arguments:
* `observer` - object implementing the [Observer](@ref observer) interface which can perform measurements and stop early
* `write_when_maxdim_exceeds::Int` - when the allowed maxdim exceeds this value, begin saving tensors to disk to free memory in large calculations
"""
function tdvp(H, t::Number, init::AbstractTTN; kwargs...)
return tdvp(tdvp_solver(; kwargs...), H, t, init; kwargs...)
function tdvp(H, t::Number, init::AbstractTTN; solver_backend="exponentiate", kwargs...)
if solver_backend == "exponentiate"
solver = exponentiate_solver
elseif solver_backend == "applyexp"
solver = applyexp_solver
else
error(
"solver_backend=$solver_backend not recognized (options are \"applyexp\" or \"exponentiate\")",
)
end
return tdvp(solver(; kwargs...), H, t, init; kwargs...)
end
Loading

0 comments on commit 1476e5a

Please sign in to comment.