Skip to content

Contraction tree to graph #53

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jan 25, 2023
40 changes: 40 additions & 0 deletions src/Graphs/abstractgraph.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Determine if an edge involves a leaf (at src or dst)"""
function is_leaf_edge(g::AbstractGraph, e)
return is_leaf(g, src(e)) || is_leaf(g, dst(e))
end

"""Determine if a node has any neighbors which are leaves"""
function has_leaf_neighbor(g::AbstractGraph, v)
for vn in neighbors(g, v)
if(is_leaf(g, vn))
return true
end
end
return false
end

"""Get all edges which do not involve a leaf"""
function internal_edges(g::AbstractGraph)
return filter(e -> !is_leaf_edge(g, e), edges(g))
end

"""Get all vertices which are leaves of a graph"""
function leaf_vertices(g::AbstractGraph)
return vertices(g)[findall(==(1), [is_leaf(g,v) for v in vertices(g)])]
end

"""Get distance of a vertex from a leaf"""
function distance_to_leaf(g::AbstractGraph, v)
leaves = leaf_vertices(g)
if(isempty(leaves))
println("ERROR: GRAPH DOES NTO CONTAIN LEAVES")
return NaN
end

return minimum([length(a_star(g, v, leaf)) for leaf in leaves])
end

"""Return all vertices which are within a certain pathlength `dist` of the leaves of the graph"""
function distance_from_roots(g::AbstractGraph, dist::Int64)
return vertices(g)[findall(<=(dist), [distance_to_leaf(g, v) for v in vertices(g)])]
end
2 changes: 2 additions & 0 deletions src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ include("specialitensornetworks.jl")
include("renameitensornetwork.jl")
include("boundarymps.jl")
include("beliefpropagation.jl")
include("contraction_tree_to_graph.jl")
include(joinpath("Graphs", "abstractgraph.jl"))
include(joinpath("treetensornetworks", "abstracttreetensornetwork.jl"))
include(joinpath("treetensornetworks", "ttn.jl"))
include(joinpath("treetensornetworks", "opsum_to_ttn.jl"))
Expand Down
70 changes: 70 additions & 0 deletions src/contraction_tree_to_graph.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@

"""
Take a contraction_sequence and return a graphical representation of it. The leaves of the graph represent the leaves of the sequence whilst the internal_nodes of the graph
define a tripartition of the graph and thus are named as an n = 3 element tuples, which each element specifying the keys involved.
Edges connect parents/children within the contraction sequence.
"""
function contraction_sequence_to_graph(contract_sequence)

g = fill_contraction_sequence_graph_vertices(contract_sequence)

#Now we have the vertices we need to figure out the edges
for v in vertices(g)
#Only add edges from a parent (which defines a tripartition and thus has length 3) to its children
if(length(v) == 3)
#Work out which vertices it connects to
concat1, concat2, concat3 =[v[1]..., v[2]...], [v[2]..., v[3]...], [v[1]..., v[3]...]
for vn in setdiff(vertices(g), [v])
vn_set = [Set(vni) for vni in vn]
if(Set(concat1) ∈ vn_set || Set(concat2) ∈ vn_set || Set(concat3) ∈ vn_set)
add_edge!(g, v => vn)
end
end
end
end


return g
end


function fill_contraction_sequence_graph_vertices(contract_sequence)
g = NamedGraph()
leaves = collect(Leaves(contract_sequence))
fill_contraction_sequence_graph_vertices!(g, contract_sequence[1], leaves)
fill_contraction_sequence_graph_vertices!(g, contract_sequence[2], leaves)
return g
end

"""Given a contraction sequence which is a subsequence of some larger sequence (with leaves `leaves`) which is being built on g
Spawn `contract sequence' as a vertex on `current_g' and continue on with its children """
function fill_contraction_sequence_graph_vertices!(g, contract_sequence, leaves)
if(isa(contract_sequence, Array))
group1 = collect(Leaves(contract_sequence[1]))
group2 = collect(Leaves(contract_sequence[2]))
remaining_verts = setdiff(leaves, vcat(group1, group2))
add_vertex!(g, (group1, group2, remaining_verts))
fill_contraction_sequence_graph_vertices!(g, contract_sequence[1], leaves)
fill_contraction_sequence_graph_vertices!(g, contract_sequence[2], leaves)
else
add_vertex!(g, ([contract_sequence], setdiff(leaves, [contract_sequence])))
end
end

"""Get the vertex bi-partition that a given edge represents"""
function contraction_tree_leaf_bipartition(g::AbstractGraph, e)

if(!is_leaf_edge(g, e))
vsrc_set, vdst_set = [Set(vni) for vni in src(e)], [Set(vni) for vni in dst(e)]
c1, c2, c3 = [src(e)[1]..., src(e)[2]...], [src(e)[2]..., src(e)[3]...], [src(e)[1]..., src(e)[3]...]
left_bipartition = Set(c1) ∈ vdst_set ? c1 : Set(c2) ∈ vdst_set ? c2 : c3

c1, c2, c3 = [dst(e)[1]..., dst(e)[2]...], [dst(e)[2]..., dst(e)[3]...], [dst(e)[1]..., dst(e)[3]...]
right_bipartition = Set(c1) ∈ vsrc_set ? c1 : Set(c2) ∈ vsrc_set ? c2 : c3
else
left_bipartition = filter(vs -> Set(vs) ∈ [Set(vni) for vni in dst(e)], src(e))[1]
right_bipartition = setdiff(src(e), left_bipartition)
end

return left_bipartition, right_bipartition
end
44 changes: 44 additions & 0 deletions test/test_contraction_sequence_to_graph.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
using ITensorNetworks
using ITensorNetworks: contraction_sequence_to_graph, internal_edges, contraction_tree_leaf_bipartition, distance_to_leaf, leaf_vertices
using Test
using ITensors
using NamedGraphs

@testset "contraction_sequence_to_graph" begin

n =3
dims = (n,n)
g = named_grid(dims)
s = siteinds("S=1/2", g)

ψ = randomITensorNetwork(s; link_space=2)
ψψ = flatten_networks(ψ,ψ)

seq = contraction_sequence(ψψ);

g_seq = contraction_sequence_to_graph(seq)

#Get all leaf nodes (should match number of tensors in original network)
g_seq_leaves = leaf_vertices(g_seq)

@test length(g_seq_leaves) == n*n

for eb in internal_edges(g_seq)
vs = contraction_tree_leaf_bipartition(g_seq, eb)
@test length(vs) == 2
@test Set([v.I for v in vcat(vs[1],vs[2])]) == Set(vertices(ψψ))

end
#Check all internal vertices define a correct tripartition and all leaf vertices define a bipartition (tensor on that leafs vs tensor on rest of tree)
for v in vertices(g_seq)
if(!is_leaf(g_seq, v))
@test length(v) == 3
@test Set([vsi.I for vsi in vcat(v[1], v[2], v[3])]) == Set(vertices(ψψ))
else
@test length(v) == 2
@test Set([vsi.I for vsi in vcat(v[1], v[2])]) == Set(vertices(ψψ))
end

end

end