Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Contraction tree to graph #53

Merged
merged 11 commits into from
Jan 25, 2023
40 changes: 40 additions & 0 deletions src/Graphs/abstractgraph.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Determine if an edge involves a leaf (at src or dst)"""
function is_leaf_edge(g::AbstractGraph, e)
return is_leaf(g, src(e)) || is_leaf(g, dst(e))
end

"""Determine if a node has any neighbors which are leaves"""
function has_leaf_neighbor(g::AbstractGraph, v)
for vn in neighbors(g, v)
if(is_leaf(g, vn))
return true
end
end
return false
end

"""Get all edges which do not involve a leaf"""
function internal_edges(g::AbstractGraph)
return filter(e -> !is_leaf_edge(g, e), edges(g))
end

"""Get all vertices which are leaves of a graph"""
function leaf_vertices(g::AbstractGraph)
return vertices(g)[findall(==(1), [is_leaf(g,v) for v in vertices(g)])]
end

"""Get distance of a vertex from a leaf"""
function distance_to_leaf(g::AbstractGraph, v)
leaves = leaf_vertices(g)
if(isempty(leaves))
println("ERROR: GRAPH DOES NTO CONTAIN LEAVES")
return NaN
end

return minimum([length(a_star(g, v, leaf)) for leaf in leaves])
end

"""Return all vertices which are within a certain pathlength `dist` of the leaves of the graph"""
function distance_from_roots(g::AbstractGraph, dist::Int64)
return vertices(g)[findall(<=(dist), [distance_to_leaf(g, v) for v in vertices(g)])]
end
2 changes: 2 additions & 0 deletions src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ include("specialitensornetworks.jl")
include("renameitensornetwork.jl")
include("boundarymps.jl")
include("beliefpropagation.jl")
include("contraction_tree_to_graph.jl")
include(joinpath("Graphs", "abstractgraph.jl"))
include(joinpath("treetensornetworks", "abstracttreetensornetwork.jl"))
include(joinpath("treetensornetworks", "ttn.jl"))
include(joinpath("treetensornetworks", "opsum_to_ttn.jl"))
Expand Down
83 changes: 83 additions & 0 deletions src/contraction_tree_to_graph.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@

"""The main object here is `g' a NamedGraph which represents a graphical version of a contraction sequence.
It's vertices describe a partition between the leaves of the sequence (will be labelled with an n = 1 or n = 3 element tuple, where each element of the tuple describes the leaves in one of those partition)
n = 1 implies the vertex is actually a leaf.
Edges connect vertices which are child/ parent and also define a bi-partition"""


"""Function to take a sequence (returned by ITensorNetworks.contraction_sequence) and construct a graph g which represents it (see above)"""
function contraction_sequence_to_graph(contract_sequence)

g = fill_contraction_sequence_graph_vertices(contract_sequence)

#Now we have the vertices we need to figure out the edges
for v in vertices(g)
#Only add edges from a parent (which defines a tripartition and thus has length 3) to its children
if(length(v) == 3)
#Work out which vertices it connects to
concat1, concat2, concat3 =[v[1]..., v[2]...], [v[2]..., v[3]...], [v[1]..., v[3]...]
for vn in setdiff(vertices(g), [v])
vn_set = [Set(vni) for vni in vn]
if(Set(concat1) ∈ vn_set || Set(concat2) ∈ vn_set || Set(concat3) ∈ vn_set)
add_edge!(g, v => vn)
end
end
end
end


return g
end


function fill_contraction_sequence_graph_vertices(contract_sequence)
g = NamedGraph()
leaves = collect(Leaves(contract_sequence))
fill_contraction_sequence_graph_vertices!(g, contract_sequence[1], leaves)
fill_contraction_sequence_graph_vertices!(g, contract_sequence[2], leaves)
return g
end

"""Given a contraction sequence which is a subsequence of some larger sequence which is being built on current_g and has leaves `leaves`
Spawn `contract sequence' as a vertex on `current_g' and continue on with its children """
function fill_contraction_sequence_graph_vertices!(g, contract_sequence, leaves)
if(isa(contract_sequence, Array))
group1 = collect(Leaves(contract_sequence[1]))
group2 = collect(Leaves(contract_sequence[2]))
remaining_verts = setdiff(leaves, vcat(group1, group2))
add_vertex!(g, (group1, group2, remaining_verts))
fill_contraction_sequence_graph_vertices!(g, contract_sequence[1], leaves)
fill_contraction_sequence_graph_vertices!(g, contract_sequence[2], leaves)
else
add_vertex!(g, ([contract_sequence], setdiff(leaves, [contract_sequence])))
end
end

"""Utility functions for the graphical representation of a contraction sequence"""

"""Get the vertex bi-partition that a given edge between non-leaf nodes represents"""
function contraction_tree_leaf_bipartition(g::AbstractGraph, e)

if(is_leaf_edge(g, e))
println("ERROR: EITHER THE SOURCE OR THE VERTEX IS A LEAF SO EDGE DOESN'T REALLY REPRESENT A BI-PARTITION")
JoeyT1994 marked this conversation as resolved.
Show resolved Hide resolved
end

vsrc_set, vdst_set = [Set(vni) for vni in src(e)], [Set(vni) for vni in dst(e)]
c1, c2, c3 = [src(e)[1]..., src(e)[2]...], [src(e)[2]..., src(e)[3]...], [src(e)[1]..., src(e)[3]...]
left_bipartition = Set(c1) ∈ vdst_set ? c1 : Set(c2) ∈ vdst_set ? c2 : c3

c1, c2, c3 = [dst(e)[1]..., dst(e)[2]...], [dst(e)[2]..., dst(e)[3]...], [dst(e)[1]..., dst(e)[3]...]
right_bipartition = Set(c1) ∈ vsrc_set ? c1 : Set(c2) ∈ vsrc_set ? c2 : c3

return left_bipartition, right_bipartition
end

"""Given a contraction node, get the keys living on all its neighbouring leaves"""
function external_node_keys(g::AbstractGraph, v)
return [Base.Iterators.flatten(v[findall(==(1), [length(vi) == 1 for vi in v])])...]
end

"""Given a contraction node, get all keys which are not living on a neighbouring leaf"""
function external_contraction_node_ext_keys(g::AbstractGraph, v)
return [Base.Iterators.flatten(v[findall(==(1), [length(vi) != 1 for vi in v])])...]
end
44 changes: 44 additions & 0 deletions test/test_contraction_sequence_to_graph.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
using ITensorNetworks
using ITensorNetworks: contraction_sequence_to_graph, internal_edges, contraction_tree_leaf_bipartition, distance_to_leaf, leaf_vertices
using Test
using ITensors
using NamedGraphs

@testset "contraction_sequence_to_graph" begin

n =3
dims = (n,n)
g = named_grid(dims)
s = siteinds("S=1/2", g)

ψ = randomITensorNetwork(s; link_space=2)
ψψ = flatten_networks(ψ,ψ)

seq = contraction_sequence(ψψ);

g_seq = contraction_sequence_to_graph(seq)

#Get all leaf nodes (should match number of tensors in original network)
g_seq_leaves = leaf_vertices(g_seq)

@test length(g_seq_leaves) == n*n

for eb in internal_edges(g_seq)
vs = contraction_tree_leaf_bipartition(g_seq, eb)
@test length(vs) == 2
@test Set([v.I for v in vcat(vs[1],vs[2])]) == Set(vertices(ψψ))

end
#Check all internal vertices define a correct tripartition and all leaf vertices define a bipartition (tensor on that leafs vs tensor on rest of tree)
for v in vertices(g_seq)
if(!is_leaf(g_seq, v))
@test length(v) == 3
@test Set([vsi.I for vsi in vcat(v[1], v[2], v[3])]) == Set(vertices(ψψ))
else
@test length(v) == 2
@test Set([vsi.I for vsi in vcat(v[1], v[2])]) == Set(vertices(ψψ))
end

end

end