Skip to content

MaximiseBilinearForm #219

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ include("caches/abstractbeliefpropagationcache.jl")
include("caches/beliefpropagationcache.jl")
include("formnetworks/abstractformnetwork.jl")
include("formnetworks/bilinearformnetwork.jl")
include("formnetworks/maximise_bilinearformnetwork.jl")
include("formnetworks/quadraticformnetwork.jl")
include("contraction_tree_to_graph.jl")
include("gauging.jl")
Expand Down
133 changes: 133 additions & 0 deletions src/formnetworks/maximise_bilinearformnetwork.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
using NamedGraphs.NamedGraphGenerators: named_grid
using NamedGraphs: NamedEdge
using ITensors: ITensors, ITensor, contract, dag
using Graphs: is_tree
using NamedGraphs.PartitionedGraphs: partitioned_graph, partitionedges, partitionvertex, partitionvertices
using NamedGraphs.GraphsExtensions: bfs_tree, leaf_vertices, post_order_dfs_edges, src, dst, vertices
using NDTensors: Algorithm
using Dictionaries
using LinearAlgebra: norm_sqr

default_solver_algorithm() = "orthogonalize"
default_solver_kwargs() = (; niters = 25, nsites = 1, tolerance = 1e-10, normalize = true, maxdim = nothing, cutoff = nothing)

#TODO: Come up with reasonable sequence for non-trees
function blf_update_sequence(g::AbstractGraph; nsites::Int64=1)
@assert is_tree(g)
if nsites == 1 || nsites == 2
es = post_order_dfs_edges(g, first(leaf_vertices(g)))
vs = [[src(e), dst(e)] for e in es]
regions = nsites == 2 ? vs : [[v] for v in unique(reduce(vcat, vs))]
return vcat(regions, reverse(reverse.(regions)))
else
error("Nsites > 2 sequences not currently supported")
end
end

#TODO: biorthogonal updater and gauging
function blf_updater(alg::Algorithm"orthogonalize", xAy_bpc::AbstractBeliefPropagationCache, y::AbstractITensorNetwork, prev_region::Vector, region::Vector)
path = gauge_path(y, prev_region, region)
y = gauge_walk(alg, y, path)
verts = unique(vcat(src.(path), dst.(path)))
factors = [dag(y[v]) for v in verts]
xAy_bpc = update_factors(xAy_bpc, Dictionary([(v, "ket") for v in verts], factors))
pe_path = partitionedges(partitioned_tensornetwork(xAy_bpc), [NamedEdge((src(e), "ket") => (dst(e), "ket")) for e in path])
xAy_bpc = update(Algorithm("bp"), xAy_bpc, pe_path; message_update_function_kwargs = (; normalize = false))
return xAy_bpc, y
end

function blf_extracter(xAy_bpc::AbstractBeliefPropagationCache, region::Vector)
return environment(xAy_bpc, [(v, "ket") for v in region])
end

function blf_inserter(∂xAy_bpc_∂r::Vector{ITensor}, xAy_bpc::AbstractBeliefPropagationCache, y::AbstractITensorNetwork, region::Vector; normalize, maxdim, cutoff)
yr = contract(∂xAy_bpc_∂r; sequence = "automatic")
if length(region) == 1
v = only(region)
if normalize
yr /= sqrt(norm_sqr(yr))
end
y[v] = yr
elseif length(region) == 2
v1, v2 = first(region), last(region)
linds, cind = uniqueinds(y[v1], y[v2]), commonind(y[v1], y[v2])
yv1, yv2 = factorize(yr, linds; ortho = "left", tags=tags(cind), cutoff, maxdim)
if normalize
yv2 /= sqrt(norm_sqr(yv2))
end
y[v1], y[v2] = yv1, yv2
else
error("Updates with regions bigger than 2 not currently supported")
end
vertices = [(v, "ket") for v in region]
factors = [y[v] for v in region]
xAy_bpc = update_factors(xAy_bpc, Dictionary(vertices, factors))
return y, xAy_bpc
end

function blf_costfunction(xAy::AbstractBeliefPropagationCache, region)
verts = [(v, "ket") for v in region]
return contract([environment(xAy, verts); factors(xAy, verts)]; sequence = "automatic")[]
end

#Optimize over y to maximize <x|A|y> * <y|dag(A)|x> / <y|y> based on a designated partitioning of the bilinearform
#For now, y should be a tree tensor network and <x|A|y> should be a tree under the partitioning
function maximize_bilinearform(
alg::Algorithm"orthogonalize",
xAy::BilinearFormNetwork,
y::ITensorNetwork = dag(ket_network(xAy)),
partition_verts = group(v -> first(v), vertices(xAy));
updater = blf_updater,
extracter = blf_extracter,
inserter = blf_inserter,
costfunction = blf_costfunction,
sequence = blf_update_sequence,
normalize::Bool = true,
niters::Int64 = 25,
nsites::Int64 = 1,
tolerance = nothing,
maxdim = nothing,
cutoff = nothing)

#These assertions can easily be lessened in the future
@assert is_tree(y)
xAy_bpc = BeliefPropagationCache(xAy, partition_verts)
@assert is_tree(partitioned_graph(xAy_bpc))
seq = sequence(y; nsites)

prev_region = collect(vertices(y))
cs = zeros(ComplexF64, (niters, length(seq)))
for i in 1:niters
for (j, region) in enumerate(seq)
xAy_bpc, y = updater(alg, xAy_bpc, y, prev_region, region)
∂xAy_bpc_∂r = extracter(xAy_bpc, region)
y, xAy_bpc = inserter(∂xAy_bpc_∂r, xAy_bpc, y, region; normalize, maxdim, cutoff)
cs[i, j] = costfunction(xAy_bpc, region)
prev_region = region
end
if i >= 2 && (abs(sum(cs[i, :]) - sum(cs[i-1, :]))) / length(seq) <= tolerance
return xAy_bpc, dag(y)
end
end

return xAy_bpc, dag(y)
end

function Base.truncate(x::AbstractITensorNetwork; maxdim_init::Int64, kwargs...)
y = ITensorNetwork(v -> inds -> delta(inds), siteinds(x); link_space = maxdim_init)
xIy = BilinearFormNetwork(x, y)
xIy_bpc, y_out = maximize_bilinearform(xIy, y; kwargs...)
return y_out
end

function ITensors.apply(A::AbstractITensorNetwork, x::AbstractITensorNetwork; maxdim_init::Int64, kwargs...)
y = ITensorNetwork(v -> inds -> delta(inds), siteinds(x); link_space = maxdim_init)
xAy = BilinearFormNetwork(A, x, y)
xAy_bpc, y_out = maximize_bilinearform(xAy, y; kwargs...)
return y_out
end

function maximize_bilinearform(xAy::BilinearFormNetwork, args...; alg = default_solver_algorithm(), solver_kwargs = default_solver_kwargs())
return maximize_bilinearform(Algorithm(alg), xAy, args...; solver_kwargs...)
end

23 changes: 23 additions & 0 deletions src/formnetworks/maximise_bilinearformnetwork_V2.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
@kwdef mutable struct FittingProblem{State, OverlapNetwork}
state::State
overlapnetwork::OverlapNetwork
squared_scalar::Number = 0
end

squared_scalar(F::FittingProblem) = F.squared_scalar
state(F::FittingProblem) = F.state
overlapnetwork(F::FittingProblem) = F.overlapnetwork

function set(F::FittingProblem; state = state(F), overlapnetwork = overlapnetwork(F), squared_scalar = squared_scalar(F))
return FittingProblem(; state, linearformnetwork, squared_scalar)
end

function fit_tensornetwork(tn::AbstractITensorNetwork, init_state::AbstractITensorNetwork, vertex_partitioning)
overlap_bpc = BeliefPropagationCache(inner_network(tn, init_state), vertex_partitioning)
init_prob = FittingProblem(; state = copy(init_state), overlapnetwork = overlap_bpc)
common_sweep_kwargs = (; nsites, outputlevel, updater_kwargs, inserter_kwargs)
kwargs_array = [(; common_sweep_kwargs..., sweep = s) for s in 1:nsweeps]
sweep_iter = sweep_iterator(init_prob, kwargs_array)
converged_prob = alternating_update(sweep_iter; outputlevel, kws...)
return squared_scalar(converged_prob), state(converged_prob)
end
50 changes: 50 additions & 0 deletions test/test_maximisebilinearform.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
@eval module $(gensym())
using ITensorNetworks: BilinearFormNetwork, ITensorNetwork, random_tensornetwork, siteinds, subgraph, ttn, inner, truncate, maximize_bilinearform, union_all_inds
using ITensorNetworks.ModelHamiltonians: heisenberg
using Graphs: vertices
using NamedGraphs.NamedGraphGenerators: named_grid, named_comb_tree
using SplitApplyCombine: group
using StableRNGs: StableRNG
using TensorOperations: TensorOperations
using Test: @test, @test_broken, @testset
using ITensors: apply, dag, delta, prime


@testset "Maximise BilinearForm" for elt in (
Float32, Float64, Complex{Float32}, Complex{Float64}
)
begin

rng = StableRNG(1234)

g = named_comb_tree((3,2))
s = siteinds("S=1/2", g)

#One-site truncation
a = random_tensornetwork(rng, elt, s; link_space = 3)
b = truncate(a; maxdim_init = 3)
f = inner(a, b; alg = "exact") / sqrt(inner(a, a; alg = "exact") * inner(b, b; alg = "exact"))
@test f * conj(f) ≈ 1.0 atol = 10*eps(real(elt))

#Two-site truncation
a = random_tensornetwork(rng, elt, s; link_space = 3)
b = truncate(a; maxdim_init = 1, solver_kwargs= (; maxdim = 3, cutoff = 1e-16, nsites = 2, tolerance = 1e-8))
f = inner(a, b; alg = "exact") / sqrt(inner(a, a; alg = "exact") * inner(b, b; alg = "exact"))
@test f * conj(f) ≈ 1.0 atol = 10*eps(real(elt))

#One-site apply (no normalization)
a = random_tensornetwork(rng, elt, s; link_space = 2)
H = ITensorNetwork(ttn(heisenberg(g), s))
Ha = apply(H, a; maxdim_init = 4, solver_kwargs = (; niters = 20, nsites = 1, tolerance = 1e-8, normalize = false))
@test inner(Ha, a; alg = "exact") / inner(a, H, a; alg = "exact") ≈ 1.0 atol = 10*eps(real(elt))

#Two-site apply (no normalization)
a = random_tensornetwork(rng, elt, s; link_space = 2)
H = ITensorNetwork(ttn(heisenberg(g), s))
Ha = apply(H, a; maxdim_init = 1, solver_kwargs= (; maxdim = 4, cutoff = 1e-16, nsites = 2, tolerance = 1e-8, normalize = false))
@test inner(Ha, a; alg = "exact") / inner(a, H, a; alg = "exact") ≈ 1.0 atol = 10*eps(real(elt))

end
end

end
Loading