Skip to content

Use Observer for output #93

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ IsApprox = "0.1.7"
IterTools = "1.4.0"
KrylovKit = "0.6.0"
NamedGraphs = "0.1.11"
Observers = "0.0.8"
Observers = "0.2"
Requires = "1.3"
SimpleTraits = "0.9"
SparseArrayKit = "0.2.1"
Expand Down
1 change: 1 addition & 0 deletions src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ using IterTools
using KrylovKit: KrylovKit
using NamedGraphs
using Observers
using Observers.DataFrames: select!
using Printf
using Requires
using SimpleTraits
Expand Down
60 changes: 29 additions & 31 deletions src/treetensornetworks/solvers/alternating_update.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,66 +26,64 @@ function process_sweeps(
return maxdim, mindim, cutoff, noise, kwargs
end

function sweep_printer(; outputlevel, psi, sweep, sw_time)
if outputlevel >= 1
print("After sweep ", sweep, ":")
print(" maxlinkdim=", maxlinkdim(psi))
print(" cpu_time=", round(sw_time; digits=3))
println()
flush(stdout)
end
end

function alternating_update(
solver,
PH,
psi0::AbstractTTN;
checkdone=nothing,
outputlevel=0,
nsweeps=1,
checkdone=(; kws...) -> false,
outputlevel::Integer=0,
nsweeps::Integer=1,
(sweep_observer!)=observer(),
sweep_printer=sweep_printer,
write_when_maxdim_exceeds::Union{Int,Nothing}=nothing,
kwargs...,
)
maxdim, mindim, cutoff, noise, kwargs = process_sweeps(nsweeps; kwargs...)

step_observer = get(kwargs, :step_observer!, nothing)

psi = copy(psi0)

info = nothing
for sw in 1:nsweeps
if !isnothing(write_when_maxdim_exceeds) && maxdim[sw] > write_when_maxdim_exceeds
insert_function!(sweep_observer!, "sweep_printer" => sweep_printer)

for sweep in 1:nsweeps
if !isnothing(write_when_maxdim_exceeds) && maxdim[sweep] > write_when_maxdim_exceeds
if outputlevel >= 2
println(
"write_when_maxdim_exceeds = $write_when_maxdim_exceeds and maxdim(sweeps, sw) = $(maxdim(sweeps, sw)), writing environment tensors to disk",
"write_when_maxdim_exceeds = $write_when_maxdim_exceeds and maxdim[sweep] = $(maxdim[sweep]), writing environment tensors to disk",
)
end
PH = disk(PH)
end

sw_time = @elapsed begin
psi, PH, info = update_step(
psi, PH = update_step(
solver,
PH,
psi;
outputlevel,
sweep=sw,
maxdim=maxdim[sw],
mindim=mindim[sw],
cutoff=cutoff[sw],
noise=noise[sw],
sweep,
maxdim=maxdim[sweep],
mindim=mindim[sweep],
cutoff=cutoff[sweep],
noise=noise[sweep],
kwargs...,
)
end

update!(step_observer; psi, sweep=sw, outputlevel)

if outputlevel >= 1
print("After sweep ", sw, ":")
print(" maxlinkdim=", maxlinkdim(psi))
@printf(" maxerr=%.2E", info.maxtruncerr)
#print(" current_time=", round(current_time; digits=3))
print(" time=", round(sw_time; digits=3))
println()
flush(stdout)
end
update!(sweep_observer!; psi, sweep, sw_time, outputlevel)

isdone = false
if !isnothing(checkdone)
isdone = checkdone(; psi, sweep=sw, outputlevel, kwargs...)
end
isdone && break
checkdone(; psi, sweep, outputlevel, kwargs...) && break
end
select!(sweep_observer!, Observers.DataFrames.Not("sweep_printer")) # remove sweep_printer
return psi
end

Expand Down
2 changes: 1 addition & 1 deletion src/treetensornetworks/solvers/contract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ function contract_solver(PH, psi; kwargs...)
v *= PH.psi0[j]
end
Hpsi0 = contract(PH, v)
return Hpsi0, nothing
return Hpsi0, NamedTuple()
end

function contract(
Expand Down
2 changes: 1 addition & 1 deletion src/treetensornetworks/solvers/dmrg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ function eigsolve_solver(; solver_which_eigenvalue=:SR, kwargs...)
)
vals, vecs, info = eigsolve(H, init, howmany, which; solver_kwargs...)
psi = vecs[1]
return psi, info
return psi, (; solver_info=info, energies=vals)
end
return solver
end
Expand Down
3 changes: 2 additions & 1 deletion src/treetensornetworks/solvers/dmrg_x.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ function dmrg_x_solver(PH, init; kwargs...)
u = uniqueind(U, H)
max_overlap, max_ind = findmax(abs, array(dag(init) * U))
U_max = U * dag(onehot(u => max_ind))
return U_max, nothing
# TODO: improve this to return the energy estimate too
return U_max, NamedTuple()
end

function dmrg_x(PH, init::AbstractTTN; kwargs...)
Expand Down
2 changes: 1 addition & 1 deletion src/treetensornetworks/solvers/linsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ function linsolve(
)
b = dag(only(proj_mps(P)))
x, info = KrylovKit.linsolve(P, b, x₀, a₀, a₁; solver_kwargs...)
return x, nothing
return x, NamedTuple()
end

error("`linsolve` for TTN not yet implemented.")
Expand Down
71 changes: 48 additions & 23 deletions src/treetensornetworks/solvers/tdvp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ function exponentiate_solver(; kwargs...)
eager=true,
)

psi, info = KrylovKit.exponentiate(H, time_step, init; solver_kwargs...)
return psi, info
psi, exp_info = KrylovKit.exponentiate(H, time_step, init; solver_kwargs...)
return psi, (; info=exp_info)
end
return solver
end
Expand All @@ -45,31 +45,25 @@ function applyexp_solver(; kwargs...)

#applyexp tol is absolute, compute from tol_per_unit_time:
tol = abs(time_step) * tol_per_unit_time
psi, info = applyexp(H, time_step, init; tol, solver_kwargs..., kws...)
return psi, info
psi, exp_info = applyexp(H, time_step, init; tol, solver_kwargs..., kws...)
return psi, (; info=exp_info)
end
return solver
end

function tdvp_solver(; solver_backend="exponentiate", kwargs...)
if solver_backend == "exponentiate"
return exponentiate_solver(; kwargs...)
elseif solver_backend == "applyexp"
return applyexp_solver(; kwargs...)
else
error(
"solver_backend=$solver_backend not recognized (options are \"applyexp\" or \"exponentiate\")",
)
end
end

function _compute_nsweeps(nsteps, t, time_step)
function _compute_nsweeps(nsteps, t, time_step, order)
nsweeps_per_step = order / 2
nsweeps = 1
if !isnothing(nsteps) && time_step != t
error("Cannot specify both nsteps and time_step in tdvp")
elseif isfinite(time_step) && abs(time_step) > 0.0 && isnothing(nsteps)
nsweeps = convert(Int, ceil(abs(t / time_step)))
if !(nsweeps * time_step ≈ t)
nsweeps = convert(Int, nsweeps_per_step * ceil(abs(t / time_step)))
if !(nsweeps / nsweeps_per_step * time_step ≈ t)
println(
"Time that will be reached = nsweeps/nsweeps_per_step * time_step = ",
nsweeps / nsweeps_per_step * time_step,
)
println("Requested total time t = ", t)
error("Time step $time_step not commensurate with total time t=$t")
end
end
Expand Down Expand Up @@ -118,11 +112,33 @@ function tdvp(
nsite=2,
nsteps=nothing,
order::Integer=2,
(sweep_observer!)=observer(),
kwargs...,
)
nsweeps = _compute_nsweeps(nsteps, t, time_step)
nsweeps = _compute_nsweeps(nsteps, t, time_step, order)
sweep_regions = tdvp_sweep(order, nsite, time_step, init; kwargs...)
return alternating_update(solver, H, init; nsweeps, sweep_regions, nsite, kwargs...)

function sweep_time_printer(; outputlevel, sweep, kwargs...)
if outputlevel >= 1
sweeps_per_step = order ÷ 2
if sweep % sweeps_per_step == 0
current_time = (sweep / sweeps_per_step) * time_step
println("Current time (sweep $sweep) = ", round(current_time; digits=3))
end
end
return nothing
end

insert_function!(sweep_observer!, "sweep_time_printer" => sweep_time_printer)

psi = alternating_update(
solver, H, init; nsweeps, sweep_observer!, sweep_regions, nsite, kwargs...
)

# remove sweep_time_printer from sweep_observer!
select!(sweep_observer!, Observers.DataFrames.Not("sweep_time_printer"))

return psi
end

"""
Expand All @@ -143,6 +159,15 @@ Optional keyword arguments:
* `observer` - object implementing the [Observer](@ref observer) interface which can perform measurements and stop early
* `write_when_maxdim_exceeds::Int` - when the allowed maxdim exceeds this value, begin saving tensors to disk to free memory in large calculations
"""
function tdvp(H, t::Number, init::AbstractTTN; kwargs...)
return tdvp(tdvp_solver(; kwargs...), H, t, init; kwargs...)
function tdvp(H, t::Number, init::AbstractTTN; solver_backend="exponentiate", kwargs...)
if solver_backend == "exponentiate"
solver = exponentiate_solver
elseif solver_backend == "applyexp"
solver = applyexp_solver
else
error(
"solver_backend=$solver_backend not recognized (options are \"applyexp\" or \"exponentiate\")",
)
end
return tdvp(solver(; kwargs...), H, t, init; kwargs...)
end
Loading