Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Belief Propagation Version 3 #85

Merged
merged 27 commits into from
Apr 25, 2023
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
2da1ae2
Added function for building ITensorNetwork from a single ITensor
JoeyT1994 Apr 10, 2023
c998839
Re-wrote BP Functions for having message tensors as general ITensorNe…
JoeyT1994 Apr 10, 2023
1db7fb2
Modified Example to work with new BP structure
JoeyT1994 Apr 10, 2023
1dbf2cd
Added Support for different output types of get_environment()
JoeyT1994 Apr 10, 2023
688d2d4
Updated tests
JoeyT1994 Apr 10, 2023
54b1a9c
Merge remote-tracking branch 'upstream/main' into ITensorNetworkBelie…
JoeyT1994 Apr 10, 2023
2e55db4
Added Test for Advanced BP
JoeyT1994 Apr 10, 2023
e859ffd
Updated Example to Reflect Changes
JoeyT1994 Apr 10, 2023
f1baeb3
Updated Contraction Function to Spit out a Network. Leave the contrac…
JoeyT1994 Apr 10, 2023
fdff9b1
Updated Gauging and Tests to incorporate new Belief Propagation Changes
JoeyT1994 Apr 11, 2023
ba3b2a0
Extended Belief Propagation Example to Include Boundary MPS
JoeyT1994 Apr 11, 2023
7109e0c
Updated Testing for Belief Propagation
JoeyT1994 Apr 11, 2023
417fa7a
Formatting
JoeyT1994 Apr 11, 2023
49792fa
Added better, more streamlined initialisation for message tensors
JoeyT1994 Apr 11, 2023
c98c05c
Formatting
JoeyT1994 Apr 11, 2023
9f02fcf
Edited examples and tests to avoid KaHyPar
JoeyT1994 Apr 12, 2023
18f2f56
Renamed Belief Propagation Functions
JoeyT1994 Apr 13, 2023
ca1c8f6
Formatting and more name changes
JoeyT1994 Apr 13, 2023
9d6bf0d
Changed output of get_environment() to ITensorNetwork only
JoeyT1994 Apr 13, 2023
6301116
Formatting and Added Option to Supply ITensorNetwork to Apply()
JoeyT1994 Apr 13, 2023
7dc783c
Function renaming and formatting
JoeyT1994 Apr 13, 2023
c2ecc7d
Removed Function. Defaulted to MPS Rank 1 as Message Tensor. Added Fl…
JoeyT1994 Apr 14, 2023
81b3aba
Added kwarg
JoeyT1994 Apr 14, 2023
a7baa54
Further Changes. Temp
JoeyT1994 Apr 17, 2023
7450474
Better Initialisation Routine for Message Tensors
JoeyT1994 Apr 17, 2023
1bf77fe
Further Improved Initialisation
JoeyT1994 Apr 17, 2023
4866b6a
Formatting
JoeyT1994 Apr 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 59 additions & 14 deletions examples/belief_propagation/bpexample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@ using Random
using SplitApplyCombine

using ITensorNetworks:
compute_message_tensors, calculate_contraction, contract_inner, nested_graph_leaf_vertices
compute_message_tensors,
calculate_contraction_network,
contract_inner,
nested_graph_leaf_vertices

function main()
n = 4
dims = (n, n)
g = named_grid(dims)
g_dims = (n, n)
g = named_grid(g_dims)
s = siteinds("S=1/2", g)
chi = 2

Expand All @@ -32,10 +35,11 @@ function main()
partition(partition(ψψ, group(v -> v[1], vertices(ψψ))); nvertices_per_partition=nsites)
)
mts = compute_message_tensors(ψψ; vertex_groups=vertex_groups)
sz_bp =
calculate_contraction(
ψψ, mts, [(v, 1)]; verts_tensors=ITensor[apply(op("Sz", s[v]), ψ[v])]
)[] / calculate_contraction(ψψ, mts, [(v, 1)])[]
numerator_network = calculate_contraction_network(
ψψ, mts, [(v, 1)]; verts_tn=ITensorNetwork([apply(op("Sz", s[v]), ψ[v])])
)
denominator_network = calculate_contraction_network(ψψ, mts, [(v, 1)])
sz_bp = contract(numerator_network)[] / contract(denominator_network)[]

println(
"Simple Belief Propagation Gives Sz on Site " * string(v) * " as " * string(sz_bp)
Expand All @@ -47,22 +51,63 @@ function main()
partition(partition(ψψ, group(v -> v[1], vertices(ψψ))); nvertices_per_partition=nsites)
)
mts = compute_message_tensors(ψψ; vertex_groups=vertex_groups)
sz_bp =
calculate_contraction(
ψψ, mts, [(v, 1)]; verts_tensors=ITensor[apply(op("Sz", s[v]), ψ[v])]
)[] / calculate_contraction(ψψ, mts, [(v, 1)])[]
numerator_network = calculate_contraction_network(
ψψ, mts, [(v, 1)]; verts_tn=ITensorNetwork([apply(op("Sz", s[v]), ψ[v])])
)
denominator_network = calculate_contraction_network(ψψ, mts, [(v, 1)])
sz_bp = contract(numerator_network)[] / contract(denominator_network)[]

println(
"General Belief Propagation (2-site subgraphs) Gives Sz on Site " *
"General Belief Propagation (4-site subgraphs) Gives Sz on Site " *
string(v) *
" as " *
string(sz_bp),
)

#Now do it exactly
#Now do General Belief Propagation with Matrix Product State Message Tensors Measure Sz on Site v
ψψ = flatten_networks(ψ, dag(ψ); combine_linkinds=false, map_bra_linkinds=prime)
Oψ = copy(ψ)
Oψ[v] = apply(op("Sz", s[v]), ψ[v])
sz_exact = contract_inner(Oψ, ψ) / contract_inner(ψ, ψ)
ψOψ = flatten_networks(ψ, dag(Oψ); combine_linkinds=false, map_bra_linkinds=prime)

combiners = linkinds_combiners(ψψ)
ψψ = combine_linkinds(ψψ, combiners)
ψOψ = combine_linkinds(ψOψ, combiners)

vertex_groups = nested_graph_leaf_vertices(partition(ψψ, group(v -> v[1], vertices(ψψ))))
maxdim = 8

mts = compute_message_tensors(
ψψ;
vertex_groups=vertex_groups,
contract_kwargs=(;
alg="density_matrix",
output_structure=path_graph_structure,
maxdim,
contraction_sequence_alg="optimal",
),
init_contract_kwargs=(;
alg="density_matrix",
output_structure=path_graph_structure,
cutoff=1e-16,
contraction_sequence_alg="optimal",
),
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved
)
numerator_network = calculate_contraction_network(
ψψ, mts, [v]; verts_tn=ITensorNetwork(ψOψ[v])
)
denominator_network = calculate_contraction_network(ψψ, mts, [v])
sz_bp = contract(numerator_network)[] / contract(denominator_network)[]

println(
"General Belief Propagation with Column Partitioning and MPS Message Tensors (Max dim 8) Gives Sz on Site " *
string(v) *
" as " *
string(sz_bp),
)

#Now do it exactly
sz_exact = contract(ψOψ)[] / contract(ψψ)[]

return println("The exact value of Sz on Site " * string(v) * " is " * string(sz_exact))
end
Expand Down
152 changes: 100 additions & 52 deletions src/beliefpropagation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,35 +7,60 @@ function construct_initial_mts(
end

function construct_initial_mts(
tn::ITensorNetwork, subgraphs::DataGraph; init=(I...) -> @compat allequal(I) ? 1 : 0
tn::ITensorNetwork,
subgraphs::DataGraph;
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved
contract_kwargs=(;),
init=(I...) -> allequal(I) ? 1 : 0,
)
# TODO: This is dropping the vertex data for some reason.
# mts = DataGraph{vertextype(subgraphs),vertex_data_type(subgraphs),ITensor}(subgraphs)
mts = DataGraph{vertextype(subgraphs),vertex_data_type(subgraphs),ITensor}(
mts = DataGraph{vertextype(subgraphs),vertex_data_type(subgraphs),ITensorNetwork}(
directed_graph(underlying_graph(subgraphs))
)
for v in vertices(mts)
mts[v] = subgraphs[v]
end
for subgraph in vertices(subgraphs)
tns_to_contract = ITensor[]
for subgraph_neighbor in neighbors(subgraphs, subgraph)
edge_inds = Index[]
for vertex in vertices(subgraphs[subgraph])
psiv = tn[vertex]
for e in [edgetype(tn)(vertex => neighbor) for neighbor in neighbors(tn, vertex)]
if (find_subgraph(dst(e), subgraphs) == subgraph_neighbor)
append!(edge_inds, commoninds(tn, e))
end
end
end
mt = normalize!(
itensor(
[init(Tuple(I)...) for I in CartesianIndices(tuple(dim.(edge_inds)...))],
edge_inds,
boundary_vertices_s = setdiff(
vertices(subgraphs[subgraph]),
unique(
flatten([
neighbors(subgraphs[subgraph_neighbor], vn) for
vn in vertices(subgraphs[subgraph_neighbor])
]),
),
)
boundary_vertices_sn = setdiff(
vertices(subgraphs[subgraph_neighbor]),
unique(
flatten([
neighbors(subgraphs[subgraph], vn) for vn in vertices(subgraphs[subgraph])
]),
),
)
mts[subgraph => subgraph_neighbor] = mt
boundary_tensors = ITensor[]
for v in boundary_vertices_s
inds_boundary = flatten(unique([inds(tn[vn]) for vn in boundary_vertices_sn]))
inds_internal = flatten(
unique([inds(tn[v]) for v in setdiff(boundary_vertices_s, [v])])
)
new_inds = intersect(inds(tn[v]), vcat(inds_boundary, inds_internal))
push!(
boundary_tensors,
itensor(
[init(Tuple(I)...) for I in CartesianIndices(tuple(dim.(new_inds)...))],
new_inds,
),
)
end

mt = ITensorNetwork(Dictionaries.Dictionary(boundary_vertices_s, boundary_tensors))

contract_output = contract(mt; contract_kwargs...)
mts[subgraph => subgraph_neighbor] = if typeof(contract_output) == ITensor
ITensorNetwork(contract_output)
else
first(contract_output)
end
end
end
return mts
Expand All @@ -47,53 +72,56 @@ DO a single update of a message tensor using the current subgraph and the incomi
function update_mt(
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved
tn::ITensorNetwork,
subgraph_vertices::Vector,
mts::Vector{ITensor};
contraction_sequence::Function=tn -> contraction_sequence(tn; alg="optimal"),
mts::Vector{ITensorNetwork};
contract_kwargs=(;),
)
contract_list = [mts; [tn[v] for v in subgraph_vertices]]
contract_list = ITensorNetwork[mts; ITensorNetwork([tn[v] for v in subgraph_vertices])]

new_mt = if isone(length(contract_list))
tn = if isone(length(contract_list))
copy(only(contract_list))
else
contract(contract_list; sequence=contraction_sequence(contract_list))
reduce(⊗, contract_list)
end

contract_output = contract(tn; contract_kwargs...)
itn = if typeof(contract_output) == ITensor
ITensorNetwork(contract_output)
else
first(contract_output)
end
return normalize!(new_mt)
normalize!.(vertex_data(itn))

return itn
end

function update_mt(
tn::ITensorNetwork, subgraph::ITensorNetwork, mts::Vector{ITensor}; kwargs...
tn::ITensorNetwork, subgraph::ITensorNetwork, mts::Vector{ITensorNetwork}; kwargs...
)
return update_mt(tn, vertices(subgraph), mts; kwargs...)
end

"""
Do an update of all message tensors for a given ITensornetwork and its partition into sub graphs
"""
function update_all_mts(
tn::ITensorNetwork,
mts::DataGraph;
contraction_sequence::Function=tn -> contraction_sequence(tn; alg="optimal"),
)
function update_all_mts(tn::ITensorNetwork, mts::DataGraph; contract_kwargs=(;))
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved
update_mts = copy(mts)
for e in edges(mts)
environment_tensors = ITensor[
environment_tensornetworks = ITensorNetwork[
mts[e_in] for e_in in setdiff(boundary_edges(mts, src(e); dir=:in), [reverse(e)])
]

update_mts[src(e) => dst(e)] = update_mt(
tn, mts[src(e)], environment_tensors; contraction_sequence
tn, mts[src(e)], environment_tensornetworks; contract_kwargs
)
end
return update_mts
end

function update_all_mts(
tn::ITensorNetwork,
mts::DataGraph,
niters::Int;
contraction_sequence::Function=tn -> contraction_sequence(tn; alg="optimal"),
tn::ITensorNetwork, mts::DataGraph, niters::Int; contract_kwargs=(;)
)
for i in 1:niters
mts = update_all_mts(tn, mts; contraction_sequence)
mts = update_all_mts(tn, mts; contract_kwargs)
end
return mts
end
Expand All @@ -109,27 +137,45 @@ function get_environment(tn::ITensorNetwork, mts::DataGraph, verts::Vector; dir=
return get_environment(tn, mts, setdiff(vertices(tn), verts))
end

env_tensors = ITensor[mts[e] for e in boundary_edges(mts, subgraphs; dir=:in)]
return vcat(
env_tensors,
ITensor[tn[v] for v in setdiff(flatten([vertices(mts[s]) for s in subgraphs]), verts)],
)
env_tns = ITensorNetwork[mts[e] for e in boundary_edges(mts, subgraphs; dir=:in)]
central_tn = ITensorNetwork([
tn[v] for v in setdiff(flatten([vertices(mts[s]) for s in subgraphs]), verts)
])
return vcat(env_tns, ITensorNetwork[central_tn])
end

function get_environment(
output_type::Type, tn::ITensorNetwork, mts::DataGraph, verts::Vector; kwargs...
)
itns = get_environment(tn::ITensorNetwork, mts::DataGraph, verts::Vector; kwargs...)

if output_type == Vector{ITensorNetwork}
return itns
else
itn = reduce(⊗, itns)
if output_type == ITensorNetwork
return itn
elseif output_type == Vector{ITensor}
return ITensor[itn[v] for v in vertices(itn)]
else
error("Output Type for get_environment not Supported!")
end
end
end
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved

"""
Calculate the contraction of a tensor network centred on the vertices verts. Using message tensors.
Defaults to using tn[verts] as the local network but can be overriden
"""
function calculate_contraction(
function calculate_contraction_network(
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved
tn::ITensorNetwork,
mts::DataGraph,
verts::Vector;
verts_tensors=ITensor[tn[v] for v in verts],
contraction_sequence::Function=tn -> contraction_sequence(tn; alg="optimal"),
verts_tn=ITensorNetwork([tn[v] for v in verts]),
)
environment_tensors = get_environment(tn, mts, verts)
tensors_to_contract = vcat(environment_tensors, verts_tensors)
return contract(tensors_to_contract; sequence=contraction_sequence(tensors_to_contract))
environment_tns = get_environment(tn, mts, verts)

return reduce(⊗, vcat(environment_tns, ITensorNetwork[verts_tn]))
end

"""
Expand All @@ -141,11 +187,13 @@ function compute_message_tensors(
nvertices_per_partition=nothing,
npartitions=nothing,
vertex_groups=nothing,
kwargs...,
contract_kwargs=(;),
init_contract_kwargs=(;),
init_kwargs...,
)
Z = partition(tn; nvertices_per_partition, npartitions, subgraph_vertices=vertex_groups)

mts = construct_initial_mts(tn, Z; kwargs...)
mts = update_all_mts(tn, mts, niters)
mts = construct_initial_mts(tn, Z; contract_kwargs=init_contract_kwargs, init_kwargs...)
mts = update_all_mts(tn, mts, niters; contract_kwargs)
return mts
end
12 changes: 8 additions & 4 deletions src/gauging.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@ function symmetric_gauge(
vsrc, vdst = src(e), dst(e)

s1, s2 = find_subgraph((vsrc, 1), mts), find_subgraph((vdst, 1), mts)
edge_ind = commoninds(mts[s1 => s2], ψsymm[vsrc])

forward_mt = mts[s1 => s2][first(vertices(mts[s1 => s2]))]
backward_mt = mts[s2 => s1][first(vertices(mts[s2 => s1]))]

edge_ind = commoninds(forward_mt, ψsymm[vsrc])
edge_ind_sim = sim(edge_ind)

X_D, X_U = eigen(mts[s1 => s2]; ishermitian=true, cutoff=eigen_message_tensor_cutoff)
Y_D, Y_U = eigen(mts[s2 => s1]; ishermitian=true, cutoff=eigen_message_tensor_cutoff)
X_D, X_U = eigen(forward_mt; ishermitian=true, cutoff=eigen_message_tensor_cutoff)
Y_D, Y_U = eigen(backward_mt; ishermitian=true, cutoff=eigen_message_tensor_cutoff)
X_D, Y_D = map_diag(x -> x + regularization, X_D),
map_diag(x -> x + regularization, Y_D)

Expand All @@ -46,7 +50,7 @@ function symmetric_gauge(
S = replaceinds(
S, [commoninds(S, U)..., commoninds(S, V)...] => [edge_ind..., prime(edge_ind)...]
)
symm_mts[s1 => s2], symm_mts[s2 => s1] = S, S
symm_mts[s1 => s2], symm_mts[s2 => s1] = ITensorNetwork(S), ITensorNetwork(S)
end

return ψsymm, symm_mts
Expand Down
5 changes: 5 additions & 0 deletions src/itensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ function ITensorNetwork{V}(ts::ITensorCollection) where {V}
return tn
end

function ITensorNetwork(t::ITensor)
ts = ITensor[t]
return ITensorNetwork{keytype(ts)}(ts)
end

#
# Construction from underyling named graph
#
Expand Down
Loading