Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Belief Propagation Version 3 #85

Merged
merged 27 commits into from
Apr 25, 2023
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
2da1ae2
Added function for building ITensorNetwork from a single ITensor
JoeyT1994 Apr 10, 2023
c998839
Re-wrote BP Functions for having message tensors as general ITensorNe…
JoeyT1994 Apr 10, 2023
1db7fb2
Modified Example to work with new BP structure
JoeyT1994 Apr 10, 2023
1dbf2cd
Added Support for different output types of get_environment()
JoeyT1994 Apr 10, 2023
688d2d4
Updated tests
JoeyT1994 Apr 10, 2023
54b1a9c
Merge remote-tracking branch 'upstream/main' into ITensorNetworkBelie…
JoeyT1994 Apr 10, 2023
2e55db4
Added Test for Advanced BP
JoeyT1994 Apr 10, 2023
e859ffd
Updated Example to Reflect Changes
JoeyT1994 Apr 10, 2023
f1baeb3
Updated Contraction Function to Spit out a Network. Leave the contrac…
JoeyT1994 Apr 10, 2023
fdff9b1
Updated Gauging and Tests to incorporate new Belief Propagation Changes
JoeyT1994 Apr 11, 2023
ba3b2a0
Extended Belief Propagation Example to Include Boundary MPS
JoeyT1994 Apr 11, 2023
7109e0c
Updated Testing for Belief Propagation
JoeyT1994 Apr 11, 2023
417fa7a
Formatting
JoeyT1994 Apr 11, 2023
49792fa
Added better, more streamlined initialisation for message tensors
JoeyT1994 Apr 11, 2023
c98c05c
Formatting
JoeyT1994 Apr 11, 2023
9f02fcf
Edited examples and tests to avoid KaHyPar
JoeyT1994 Apr 12, 2023
18f2f56
Renamed Belief Propagation Functions
JoeyT1994 Apr 13, 2023
ca1c8f6
Formatting and more name changes
JoeyT1994 Apr 13, 2023
9d6bf0d
Changed output of get_environment() to ITensorNetwork only
JoeyT1994 Apr 13, 2023
6301116
Formatting and Added Option to Supply ITensorNetwork to Apply()
JoeyT1994 Apr 13, 2023
7dc783c
Function renaming and formatting
JoeyT1994 Apr 13, 2023
c2ecc7d
Removed Function. Defaulted to MPS Rank 1 as Message Tensor. Added Fl…
JoeyT1994 Apr 14, 2023
81b3aba
Added kwarg
JoeyT1994 Apr 14, 2023
a7baa54
Further Changes. Temp
JoeyT1994 Apr 17, 2023
7450474
Better Initialisation Routine for Message Tensors
JoeyT1994 Apr 17, 2023
1bf77fe
Further Improved Initialisation
JoeyT1994 Apr 17, 2023
4866b6a
Formatting
JoeyT1994 Apr 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 56 additions & 16 deletions examples/belief_propagation/bpexample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ using Random
using SplitApplyCombine

using ITensorNetworks:
compute_message_tensors, calculate_contraction, contract_inner, nested_graph_leaf_vertices
belief_propagation, approx_network_region, contract_inner, nested_graph_leaf_vertices

function main()
n = 4
dims = (n, n)
g = named_grid(dims)
g_dims = (n, n)
g = named_grid(g_dims)
s = siteinds("S=1/2", g)
chi = 2

Expand All @@ -31,11 +31,12 @@ function main()
vertex_groups = nested_graph_leaf_vertices(
partition(partition(ψψ, group(v -> v[1], vertices(ψψ))); nvertices_per_partition=nsites)
)
mts = compute_message_tensors(ψψ; vertex_groups=vertex_groups)
sz_bp =
calculate_contraction(
ψψ, mts, [(v, 1)]; verts_tensors=ITensor[apply(op("Sz", s[v]), ψ[v])]
)[] / calculate_contraction(ψψ, mts, [(v, 1)])[]
mts = belief_propagation(ψψ; vertex_groups=vertex_groups)
numerator_network = approx_network_region(
ψψ, mts, [(v, 1)]; verts_tn=ITensorNetwork([apply(op("Sz", s[v]), ψ[v])])
)
denominator_network = approx_network_region(ψψ, mts, [(v, 1)])
sz_bp = contract(numerator_network)[] / contract(denominator_network)[]

println(
"Simple Belief Propagation Gives Sz on Site " * string(v) * " as " * string(sz_bp)
Expand All @@ -46,23 +47,62 @@ function main()
vertex_groups = nested_graph_leaf_vertices(
partition(partition(ψψ, group(v -> v[1], vertices(ψψ))); nvertices_per_partition=nsites)
)
mts = compute_message_tensors(ψψ; vertex_groups=vertex_groups)
sz_bp =
calculate_contraction(
ψψ, mts, [(v, 1)]; verts_tensors=ITensor[apply(op("Sz", s[v]), ψ[v])]
)[] / calculate_contraction(ψψ, mts, [(v, 1)])[]
mts = belief_propagation(ψψ; vertex_groups=vertex_groups)
numerator_network = approx_network_region(
ψψ, mts, [(v, 1)]; verts_tn=ITensorNetwork([apply(op("Sz", s[v]), ψ[v])])
)
denominator_network = approx_network_region(ψψ, mts, [(v, 1)])
sz_bp = contract(numerator_network)[] / contract(denominator_network)[]

println(
"General Belief Propagation (2-site subgraphs) Gives Sz on Site " *
"General Belief Propagation (4-site subgraphs) Gives Sz on Site " *
string(v) *
" as " *
string(sz_bp),
)

#Now do it exactly
#Now do General Belief Propagation with Matrix Product State Message Tensors Measure Sz on Site v
ψψ = flatten_networks(ψ, dag(ψ); combine_linkinds=false, map_bra_linkinds=prime)
Oψ = copy(ψ)
Oψ[v] = apply(op("Sz", s[v]), ψ[v])
sz_exact = contract_inner(Oψ, ψ) / contract_inner(ψ, ψ)
ψOψ = flatten_networks(ψ, dag(Oψ); combine_linkinds=false, map_bra_linkinds=prime)

combiners = linkinds_combiners(ψψ)
ψψ = combine_linkinds(ψψ, combiners)
ψOψ = combine_linkinds(ψOψ, combiners)

vertex_groups = nested_graph_leaf_vertices(partition(ψψ, group(v -> v[1], vertices(ψψ))))
maxdim = 8

mts = belief_propagation(
ψψ;
vertex_groups=vertex_groups,
contract_kwargs=(;
alg="density_matrix",
output_structure=path_graph_structure,
maxdim,
contraction_sequence_alg="optimal",
),
init_contract_kwargs=(;
alg="density_matrix",
output_structure=path_graph_structure,
cutoff=1e-16,
contraction_sequence_alg="optimal",
),
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved
)
numerator_network = approx_network_region(ψψ, mts, [v]; verts_tn=ITensorNetwork(ψOψ[v]))
denominator_network = approx_network_region(ψψ, mts, [v])
sz_bp = contract(numerator_network)[] / contract(denominator_network)[]

println(
"General Belief Propagation with Column Partitioning and MPS Message Tensors (Max dim 8) Gives Sz on Site " *
string(v) *
" as " *
string(sz_bp),
)

#Now do it exactly
sz_exact = contract(ψOψ)[] / contract(ψψ)[]

return println("The exact value of Sz on Site " * string(v) * " is " * string(sz_exact))
end
Expand Down
1 change: 1 addition & 0 deletions src/apply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ function ITensors.apply(
)
end

envs = Vector{ITensor}(envs)
if !isempty(envs)
extended_envs = vcat(envs, Qᵥ₁, prime(dag(Qᵥ₁)), Qᵥ₂, prime(dag(Qᵥ₂)))
Rᵥ₁, Rᵥ₂ = optimise_p_q(
Expand Down
184 changes: 108 additions & 76 deletions src/beliefpropagation.jl
Original file line number Diff line number Diff line change
@@ -1,41 +1,67 @@
function construct_initial_mts(
function message_tensors(
tn::ITensorNetwork, nvertices_per_partition::Integer; partition_kwargs=(;), kwargs...
)
return construct_initial_mts(
tn, partition(tn; nvertices_per_partition, partition_kwargs...); kwargs...
return message_tensors(
partition(tn; nvertices_per_partition, partition_kwargs...); kwargs...
)
end

function construct_initial_mts(
tn::ITensorNetwork, subgraphs::DataGraph; init=(I...) -> @compat allequal(I) ? 1 : 0
function message_tensors(
subgraphs::DataGraph; contract_kwargs=(;), init=(I...) -> allequal(I) ? 1 : 0
)
# TODO: This is dropping the vertex data for some reason.
# mts = DataGraph{vertextype(subgraphs),vertex_data_type(subgraphs),ITensor}(subgraphs)
mts = DataGraph{vertextype(subgraphs),vertex_data_type(subgraphs),ITensor}(
mts = DataGraph{vertextype(subgraphs),vertex_data_type(subgraphs),ITensorNetwork}(
directed_graph(underlying_graph(subgraphs))
)
for v in vertices(mts)
mts[v] = subgraphs[v]
end
for subgraph in vertices(subgraphs)
tns_to_contract = ITensor[]
for subgraph_neighbor in neighbors(subgraphs, subgraph)
edge_inds = Index[]
for vertex in vertices(subgraphs[subgraph])
psiv = tn[vertex]
for e in [edgetype(tn)(vertex => neighbor) for neighbor in neighbors(tn, vertex)]
if (find_subgraph(dst(e), subgraphs) == subgraph_neighbor)
append!(edge_inds, commoninds(tn, e))
end
end
end
mt = normalize!(
itensor(
[init(Tuple(I)...) for I in CartesianIndices(tuple(dim.(edge_inds)...))],
edge_inds,
boundary_vertices_s = setdiff(
vertices(subgraphs[subgraph]),
unique(
flatten([
neighbors(subgraphs[subgraph_neighbor], vn) for
vn in vertices(subgraphs[subgraph_neighbor])
]),
),
)
boundary_vertices_sn = setdiff(
vertices(subgraphs[subgraph_neighbor]),
unique(
flatten([
neighbors(subgraphs[subgraph], vn) for vn in vertices(subgraphs[subgraph])
]),
),
)
mts[subgraph => subgraph_neighbor] = mt
boundary_tensors = ITensor[]
for v in boundary_vertices_s
inds_boundary = flatten(
unique([inds(subgraphs[subgraph_neighbor][vn]) for vn in boundary_vertices_sn])
)
inds_internal = flatten(
unique([inds(subgraphs[subgraph][v]) for v in setdiff(boundary_vertices_s, [v])])
)
new_inds = intersect(
inds(subgraphs[subgraph][v]), vcat(inds_boundary, inds_internal)
)
push!(
boundary_tensors,
itensor(
[init(Tuple(I)...) for I in CartesianIndices(tuple(dim.(new_inds)...))],
new_inds,
),
)
end
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved

mt = ITensorNetwork(Dictionaries.Dictionary(boundary_vertices_s, boundary_tensors))

contract_output = contract(mt; contract_kwargs...)
mts[subgraph => subgraph_neighbor] = if typeof(contract_output) == ITensor
ITensorNetwork(contract_output)
else
first(contract_output)
end
end
end
return mts
Expand All @@ -44,60 +70,85 @@ end
"""
DO a single update of a message tensor using the current subgraph and the incoming mts
"""
function update_mt(
function update_message_tensor(
tn::ITensorNetwork,
subgraph_vertices::Vector,
mts::Vector{ITensor};
contraction_sequence::Function=tn -> contraction_sequence(tn; alg="optimal"),
mts::Vector{ITensorNetwork};
contract_kwargs=(;),
)
contract_list = [mts; [tn[v] for v in subgraph_vertices]]
contract_list = ITensorNetwork[mts; ITensorNetwork([tn[v] for v in subgraph_vertices])]

new_mt = if isone(length(contract_list))
tn = if isone(length(contract_list))
copy(only(contract_list))
else
contract(contract_list; sequence=contraction_sequence(contract_list))
reduce(⊗, contract_list)
end
return normalize!(new_mt)

contract_output = contract(tn; contract_kwargs...)
itn = if typeof(contract_output) == ITensor
ITensorNetwork(contract_output)
else
first(contract_output)
end
normalize!.(vertex_data(itn))

return itn
end

function update_mt(
tn::ITensorNetwork, subgraph::ITensorNetwork, mts::Vector{ITensor}; kwargs...
function update_message_tensor(
tn::ITensorNetwork, subgraph::ITensorNetwork, mts::Vector{ITensorNetwork}; kwargs...
)
return update_mt(tn, vertices(subgraph), mts; kwargs...)
return update_message_tensor(tn, vertices(subgraph), mts; kwargs...)
end

"""
Do an update of all message tensors for a given ITensornetwork and its partition into sub graphs
"""
function update_all_mts(
tn::ITensorNetwork,
mts::DataGraph;
contraction_sequence::Function=tn -> contraction_sequence(tn; alg="optimal"),
function belief_propagation_iteration(
tn::ITensorNetwork, mts::DataGraph; contract_kwargs=(;)
)
update_mts = copy(mts)
new_mts = copy(mts)
for e in edges(mts)
environment_tensors = ITensor[
environment_tensornetworks = ITensorNetwork[
mts[e_in] for e_in in setdiff(boundary_edges(mts, src(e); dir=:in), [reverse(e)])
]
update_mts[src(e) => dst(e)] = update_mt(
tn, mts[src(e)], environment_tensors; contraction_sequence

new_mts[src(e) => dst(e)] = update_message_tensor(
tn, mts[src(e)], environment_tensornetworks; contract_kwargs
)
end
return update_mts
return new_mts
end

function update_all_mts(
tn::ITensorNetwork,
mts::DataGraph,
niters::Int;
contraction_sequence::Function=tn -> contraction_sequence(tn; alg="optimal"),
function belief_propagation(
tn::ITensorNetwork, mts::DataGraph, niters::Int; contract_kwargs=(;)
)
for i in 1:niters
mts = update_all_mts(tn, mts; contraction_sequence)
mts = belief_propagation_iteration(tn, mts; contract_kwargs)
end
return mts
end

"""
Simulaneously initialise and update message tensors of a tensornetwork
"""
function belief_propagation(
tn::ITensorNetwork;
niters=10,
nvertices_per_partition=nothing,
npartitions=nothing,
vertex_groups=nothing,
contract_kwargs=(;),
init_contract_kwargs=(;),
init_kwargs...,
)
Z = partition(tn; nvertices_per_partition, npartitions, subgraph_vertices=vertex_groups)

mts = message_tensors(Z; contract_kwargs=init_contract_kwargs, init_kwargs...)
mts = belief_propagation(tn, mts, niters; contract_kwargs)
return mts
end

"""
Given a subet of vertices of a given Tensor Network and the Message Tensors for that network, return a Dictionary with the involved subgraphs as keys and the vector of tensors associated with that subgraph as values
Specifically, the contraction of the environment tensors and tn[vertices] will be a scalar.
Expand All @@ -109,43 +160,24 @@ function get_environment(tn::ITensorNetwork, mts::DataGraph, verts::Vector; dir=
return get_environment(tn, mts, setdiff(vertices(tn), verts))
end

env_tensors = ITensor[mts[e] for e in boundary_edges(mts, subgraphs; dir=:in)]
return vcat(
env_tensors,
ITensor[tn[v] for v in setdiff(flatten([vertices(mts[s]) for s in subgraphs]), verts)],
)
env_tns = ITensorNetwork[mts[e] for e in boundary_edges(mts, subgraphs; dir=:in)]
central_tn = ITensorNetwork([
tn[v] for v in setdiff(flatten([vertices(mts[s]) for s in subgraphs]), verts)
])
return ITensorNetwork(vcat(env_tns, ITensorNetwork[central_tn]))
end

"""
Calculate the contraction of a tensor network centred on the vertices verts. Using message tensors.
Defaults to using tn[verts] as the local network but can be overriden
"""
function calculate_contraction(
function approx_network_region(
tn::ITensorNetwork,
mts::DataGraph,
verts::Vector;
verts_tensors=ITensor[tn[v] for v in verts],
contraction_sequence::Function=tn -> contraction_sequence(tn; alg="optimal"),
verts_tn=ITensorNetwork([tn[v] for v in verts]),
)
environment_tensors = get_environment(tn, mts, verts)
tensors_to_contract = vcat(environment_tensors, verts_tensors)
return contract(tensors_to_contract; sequence=contraction_sequence(tensors_to_contract))
end
environment_tn = get_environment(tn, mts, verts)

"""
Simulaneously initialise and update message tensors of a tensornetwork
"""
function compute_message_tensors(
tn::ITensorNetwork;
niters=10,
nvertices_per_partition=nothing,
npartitions=nothing,
vertex_groups=nothing,
kwargs...,
)
Z = partition(tn; nvertices_per_partition, npartitions, subgraph_vertices=vertex_groups)

mts = construct_initial_mts(tn, Z; kwargs...)
mts = update_all_mts(tn, mts, niters)
return mts
return environment_tn ⊗ verts_tn
end
Loading