Skip to content

Commit ea7e9b8

Browse files
Remove t argument from alternating_update and improve keyword argument handling
1 parent 536faaa commit ea7e9b8

File tree

10 files changed

+226
-210
lines changed

10 files changed

+226
-210
lines changed

src/treetensornetworks/solvers/alternating_update.jl

Lines changed: 28 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,3 @@
1-
function _compute_nsweeps(t; kwargs...)
2-
time_step::Number = get(kwargs, :time_step, t)
3-
nsweeps::Union{Int,Nothing} = get(kwargs, :nsweeps, nothing)
4-
if isinf(t) && isnothing(nsweeps)
5-
nsweeps = 1
6-
elseif !isnothing(nsweeps) && time_step != t
7-
error("Cannot specify both time_step and nsweeps in alternating_update")
8-
elseif isfinite(time_step) && abs(time_step) > 0.0 && isnothing(nsweeps)
9-
nsweeps = convert(Int, ceil(abs(t / time_step)))
10-
if !(nsweeps * time_step t)
11-
error("Time step $time_step not commensurate with total time t=$t")
12-
end
13-
end
14-
15-
return nsweeps
16-
end
171

182
function _extend_sweeps_param(param, nsweeps)
193
if param isa Number
@@ -27,39 +11,37 @@ function _extend_sweeps_param(param, nsweeps)
2711
return eparam
2812
end
2913

30-
function process_sweeps(; kwargs...)
31-
nsweeps = get(kwargs, :nsweeps, 1)
32-
maxdim = get(kwargs, :maxdim, fill(typemax(Int), nsweeps))
33-
mindim = get(kwargs, :mindim, fill(1, nsweeps))
34-
cutoff = get(kwargs, :cutoff, fill(1E-16, nsweeps))
35-
noise = get(kwargs, :noise, fill(0.0, nsweeps))
36-
14+
function process_sweeps(
15+
nsweeps;
16+
cutoff=fill(1E-16, nsweeps),
17+
maxdim=fill(typemax(Int), nsweeps),
18+
mindim=fill(1, nsweeps),
19+
noise=fill(0.0, nsweeps),
20+
kwargs...,
21+
)
3722
maxdim = _extend_sweeps_param(maxdim, nsweeps)
3823
mindim = _extend_sweeps_param(mindim, nsweeps)
3924
cutoff = _extend_sweeps_param(cutoff, nsweeps)
4025
noise = _extend_sweeps_param(noise, nsweeps)
41-
42-
return (; maxdim, mindim, cutoff, noise)
26+
return maxdim, mindim, cutoff, noise
4327
end
4428

45-
function alternating_update(solver, PH, t::Number, psi0::AbstractTTN; kwargs...)
46-
reverse_step = get(kwargs, :reverse_step, true)
47-
48-
nsweeps = _compute_nsweeps(t; kwargs...)
49-
maxdim, mindim, cutoff, noise = process_sweeps(; nsweeps, kwargs...)
50-
51-
time_start::Number = get(kwargs, :time_start, 0.0)
52-
time_step::Number = get(kwargs, :time_step, t)
53-
order = get(kwargs, :order, 2)
54-
tdvp_order = TDVPOrder(order, Base.Forward)
29+
function alternating_update(
30+
solver,
31+
PH,
32+
psi0::AbstractTTN;
33+
checkdone=nothing,
34+
tdvp_order=TDVPOrder(2, Base.Forward),
35+
outputlevel=0,
36+
time_start=0.0,
37+
time_step=0.0,
38+
nsweeps=1,
39+
write_when_maxdim_exceeds::Union{Int,Nothing}=nothing,
40+
kwargs...,
41+
)
42+
maxdim, mindim, cutoff, noise = process_sweeps(nsweeps; kwargs...)
5543

56-
checkdone = get(kwargs, :checkdone, nothing)
57-
write_when_maxdim_exceeds::Union{Int,Nothing} = get(
58-
kwargs, :write_when_maxdim_exceeds, nothing
59-
)
60-
observer = get(kwargs, :observer!, nothing)
6144
step_observer = get(kwargs, :step_observer!, nothing)
62-
outputlevel::Int = get(kwargs, :outputlevel, 0)
6345

6446
psi = copy(psi0)
6547

@@ -89,7 +71,7 @@ function alternating_update(solver, PH, t::Number, psi0::AbstractTTN; kwargs...)
8971
psi;
9072
kwargs...,
9173
current_time,
92-
reverse_step,
74+
outputlevel,
9375
sweep=sw,
9476
maxdim=maxdim[sw],
9577
mindim=mindim[sw],
@@ -121,22 +103,14 @@ function alternating_update(solver, PH, t::Number, psi0::AbstractTTN; kwargs...)
121103
return psi
122104
end
123105

124-
function alternating_update(solver, H::AbstractTTN, t::Number, psi0::AbstractTTN; kwargs...)
106+
function alternating_update(solver, H::AbstractTTN, psi0::AbstractTTN; kwargs...)
125107
check_hascommoninds(siteinds, H, psi0)
126108
check_hascommoninds(siteinds, H, psi0')
127109
# Permute the indices to have a better memory layout
128110
# and minimize permutations
129111
H = ITensors.permute(H, (linkind, siteinds, linkind))
130112
PH = ProjTTN(H)
131-
return alternating_update(solver, PH, t, psi0; kwargs...)
132-
end
133-
134-
function alternating_update(solver, t::Number, H, psi0::AbstractTTN; kwargs...)
135-
return alternating_update(solver, H, t, psi0; kwargs...)
136-
end
137-
138-
function alternating_update(solver, H, psi0::AbstractTTN, t::Number; kwargs...)
139-
return alternating_update(solver, H, t, psi0; kwargs...)
113+
return alternating_update(solver, PH, psi0; kwargs...)
140114
end
141115

142116
"""
@@ -158,14 +132,12 @@ each step of the algorithm when optimizing the MPS.
158132
Returns:
159133
* `psi::MPS` - time-evolved MPS
160134
"""
161-
function alternating_update(
162-
solver, Hs::Vector{<:AbstractTTN}, t::Number, psi0::AbstractTTN; kwargs...
163-
)
135+
function alternating_update(solver, Hs::Vector{<:AbstractTTN}, psi0::AbstractTTN; kwargs...)
164136
for H in Hs
165137
check_hascommoninds(siteinds, H, psi0)
166138
check_hascommoninds(siteinds, H, psi0')
167139
end
168140
Hs .= ITensors.permute.(Hs, Ref((linkind, siteinds, linkind)))
169141
PHs = ProjTTNSum(Hs)
170-
return alternating_update(solver, PHs, t, psi0; kwargs...)
142+
return alternating_update(solver, PHs, psi0; kwargs...)
171143
end

src/treetensornetworks/solvers/contract.jl

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
1-
function contract_solver(; kwargs...)
2-
function solver(PH, t, psi; kws...)
3-
v = ITensor(1.0)
4-
for j in sites(PH)
5-
v *= PH.psi0[j]
6-
end
7-
Hpsi0 = contract(PH, v)
8-
return Hpsi0, nothing
1+
function contract_solver(PH, psi; kwargs...)
2+
v = ITensor(1.0)
3+
for j in sites(PH)
4+
v *= PH.psi0[j]
95
end
10-
return solver
6+
Hpsi0 = contract(PH, v)
7+
return Hpsi0, nothing
118
end
129

1310
function contract(
@@ -44,12 +41,9 @@ function contract(
4441
## )
4542
## end
4643

47-
t = Inf
4844
reverse_step = false
4945
PH = ProjTTNApply(tn2, tn1)
50-
psi = alternating_update(
51-
contract_solver(; kwargs...), PH, t, init; nsweeps, reverse_step, kwargs...
52-
)
46+
psi = alternating_update(contract_solver, PH, init; nsweeps, reverse_step, kwargs...)
5347

5448
return psi
5549
end

src/treetensornetworks/solvers/dmrg.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
function eigsolve_solver(; kwargs...)
2-
function solver(H, t, init; kws...)
1+
function eigsolve_solver(; solver_which_eigenvalue=:SR, kwargs...)
2+
function solver(H, init; kws...)
33
howmany = 1
4-
which = get(kwargs, :solver_which_eigenvalue, :SR)
4+
which = solver_which_eigenvalue
55
solver_kwargs = (;
66
ishermitian=get(kwargs, :ishermitian, true),
77
tol=get(kwargs, :solver_tol, 1E-14),
@@ -20,11 +20,8 @@ end
2020
Overload of `ITensors.dmrg`.
2121
"""
2222
function dmrg(H, init::AbstractTTN; kwargs...)
23-
t = Inf # DMRG is TDVP with an infinite timestep and no reverse step
2423
reverse_step = false
25-
psi = alternating_update(
26-
eigsolve_solver(; kwargs...), H, t, init; reverse_step, kwargs...
27-
)
24+
psi = alternating_update(eigsolve_solver(; kwargs...), H, init; reverse_step, kwargs...)
2825
return psi
2926
end
3027

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
function dmrg_x_solver(PH, t, init; kwargs...)
1+
function dmrg_x_solver(PH, init; kwargs...)
22
H = contract(PH, ITensor(1.0))
33
D, U = eigen(H; ishermitian=true)
44
u = uniqueind(U, H)
@@ -8,7 +8,6 @@ function dmrg_x_solver(PH, t, init; kwargs...)
88
end
99

1010
function dmrg_x(PH, init::AbstractTTN; reverse_step=false, kwargs...)
11-
t = Inf
12-
psi = alternating_update(dmrg_x_solver, PH, t, init; reverse_step, kwargs...)
11+
psi = alternating_update(dmrg_x_solver, PH, init; reverse_step, kwargs...)
1312
return psi
1413
end

src/treetensornetworks/solvers/linsolve.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ function linsolve(
2727
)
2828
function linsolve_solver(
2929
P,
30-
t,
3130
x₀;
3231
ishermitian=false,
3332
solver_tol=1E-14,
@@ -50,9 +49,8 @@ function linsolve(
5049

5150
error("`linsolve` for TTN not yet implemented.")
5251

53-
t = Inf
5452
# TODO: Define `itensornetwork_cache`
5553
# TODO: Define `linsolve_cache`
5654
P = linsolve_cache(itensornetwork_cache(x₀', A, x₀), itensornetwork_cache(x₀', b))
57-
return alternating_update(linsolve_solver, P, t, x₀; reverse_step=false, kwargs...)
55+
return alternating_update(linsolve_solver, P, x₀; reverse_step=false, kwargs...)
5856
end

src/treetensornetworks/solvers/tdvp.jl

Lines changed: 66 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,50 @@
11
function exponentiate_solver(; kwargs...)
2-
function solver(H, t, init; kws...)
2+
function solver(
3+
H,
4+
init;
5+
ishermitian=true,
6+
issymmetric=true,
7+
solver_krylovdim=30,
8+
solver_maxiter=100,
9+
solver_outputlevel=0,
10+
solver_tol=1E-12,
11+
substep,
12+
time_step,
13+
kws...,
14+
)
315
solver_kwargs = (;
4-
ishermitian=get(kwargs, :ishermitian, true),
5-
issymmetric=get(kwargs, :issymmetric, true),
6-
tol=get(kwargs, :solver_tol, 1E-12),
7-
krylovdim=get(kwargs, :solver_krylovdim, 30),
8-
maxiter=get(kwargs, :solver_maxiter, 100),
9-
verbosity=get(kwargs, :solver_outputlevel, 0),
16+
ishermitian,
17+
issymmetric,
18+
tol=solver_tol,
19+
krylovdim=solver_krylovdim,
20+
maxiter=solver_maxiter,
21+
verbosity=solver_outputlevel,
1022
eager=true,
1123
)
12-
psi, info = exponentiate(H, t, init; solver_kwargs...)
24+
25+
psi, info = KrylovKit.exponentiate(H, time_step, init; solver_kwargs...)
1326
return psi, info
1427
end
1528
return solver
1629
end
1730

1831
function applyexp_solver(; kwargs...)
19-
function solver(H, t, init; kws...)
20-
tol_per_unit_time = get(kwargs, :solver_tol, 1E-8)
21-
solver_kwargs = (;
22-
maxiter=get(kwargs, :solver_krylovdim, 30),
23-
outputlevel=get(kwargs, :solver_outputlevel, 0),
24-
)
32+
function solver(
33+
H,
34+
init;
35+
tdvp_order,
36+
solver_krylovdim=30,
37+
solver_outputlevel=0,
38+
solver_tol=1E-8,
39+
substep,
40+
time_step,
41+
kws...,
42+
)
43+
solver_kwargs = (; maxiter=solver_krylovdim, outputlevel=solver_outputlevel)
44+
2545
#applyexp tol is absolute, compute from tol_per_unit_time:
26-
tol = abs(t) * tol_per_unit_time
27-
psi, info = applyexp(H, t, init; tol, solver_kwargs..., kws...)
46+
tol = abs(time_step) * tol_per_unit_time
47+
psi, info = applyexp(H, time_step, init; tol, solver_kwargs..., kws...)
2848
return psi, info
2949
end
3050
return solver
@@ -42,22 +62,48 @@ function tdvp_solver(; solver_backend="exponentiate", kwargs...)
4262
end
4363
end
4464

45-
function tdvp(solver, H, t::Number, init::AbstractTTN; kwargs...)
46-
return alternating_update(solver, H, t, init; kwargs...)
65+
function _compute_nsweeps(nsteps, t, time_step)
66+
nsweeps = 1
67+
if !isnothing(nsteps) && time_step != t
68+
error("Cannot specify both nsteps and time_step in tdvp")
69+
elseif isfinite(time_step) && abs(time_step) > 0.0 && isnothing(nsteps)
70+
nsweeps = convert(Int, ceil(abs(t / time_step)))
71+
if !(nsweeps * time_step t)
72+
error("Time step $time_step not commensurate with total time t=$t")
73+
end
74+
end
75+
return nsweeps
76+
end
77+
78+
function tdvp(
79+
solver,
80+
H,
81+
t::Number,
82+
init::AbstractTTN;
83+
time_step::Number=t,
84+
nsteps=nothing,
85+
order::Integer=2,
86+
kwargs...,
87+
)
88+
nsweeps = _compute_nsweeps(nsteps, t, time_step)
89+
tdvp_order = TDVPOrder(order, Base.Forward)
90+
return alternating_update(solver, H, init; nsweeps, tdvp_order, time_step, kwargs...)
4791
end
4892

4993
"""
5094
tdvp(H::TTN, t::Number, psi0::TTN; kwargs...)
5195
5296
Use the time dependent variational principle (TDVP) algorithm
53-
to compute `exp(H*t)*psi0` using an efficient algorithm based
97+
to approximately compute `exp(H*t)*psi0` using an efficient algorithm based
5498
on alternating optimization of the state tensors and local Krylov
55-
exponentiation of H.
99+
exponentiation of H. The time parameter `t` can be a real or complex number.
56100
57101
Returns:
58102
* `psi` - time-evolved state
59103
60104
Optional keyword arguments:
105+
* `time_step::Number = t` - time step to use when evolving the state. Smaller time steps generally give more accurate results but can make the algorithm take more computational time to run.
106+
* `nsteps::Integer` - evolve by the requested total time `t` by performing `nsteps` of the TDVP algorithm. More steps can result in more accurate results but require more computational time to run. (Note that only one of the `time_step` or `nsteps` parameters can be provided, not both.)
61107
* `outputlevel::Int = 1` - larger outputlevel values resulting in printing more information and 0 means no output
62108
* `observer` - object implementing the [Observer](@ref observer) interface which can perform measurements and stop early
63109
* `write_when_maxdim_exceeds::Int` - when the allowed maxdim exceeds this value, begin saving tensors to disk to free memory in large calculations

src/treetensornetworks/solvers/tdvporder.jl

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,26 @@
1+
12
struct TDVPOrder{order,direction} end
23

34
TDVPOrder(order::Int, direction::Base.Ordering) = TDVPOrder{order,direction}()
45

5-
orderings(::TDVPOrder) = error("Not implemented")
6+
directions(::TDVPOrder) = error("Not implemented")
67
sub_time_steps(::TDVPOrder) = error("Not implemented")
78

8-
function orderings(::TDVPOrder{1,direction}) where {direction}
9+
function directions(::TDVPOrder{1,direction}) where {direction}
910
return [direction, Base.ReverseOrdering(direction)]
1011
end
1112
sub_time_steps(::TDVPOrder{1}) = [1.0, 0.0]
1213

13-
function orderings(::TDVPOrder{2,direction}) where {direction}
14+
function directions(::TDVPOrder{2,direction}) where {direction}
1415
return [direction, Base.ReverseOrdering(direction)]
1516
end
1617
sub_time_steps(::TDVPOrder{2}) = [1.0 / 2.0, 1.0 / 2.0]
1718

18-
function orderings(::TDVPOrder{4,direction}) where {direction}
19+
#
20+
# TODO: possible bug, shouldn't length(directions) here equal
21+
# length(sub_time_steps) below? (I.e. both return a length 6 vector?)
22+
#
23+
function directions(::TDVPOrder{4,direction}) where {direction}
1924
return [direction, Base.ReverseOrdering(direction)]
2025
end
2126
function sub_time_steps(::TDVPOrder{4})

0 commit comments

Comments
 (0)