Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 52 additions & 9 deletions src/binary_tree_partition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,33 @@ function _root_union!(s::DisjointSets, x, y; left_root=true)
return s.revmap[_introot_union!(s.internal, s.intmap[x], s.intmap[y]; left_root=true)]
end

"""
Return the root vertex of a directed tree data graph
"""
@traitfn function _root(graph::AbstractDataGraph::IsDirected)
@assert is_tree(undirected_graph(underlying_graph(graph)))
v = vertices(graph)[1]
while parent_vertex(graph, v) != nothing
v = parent_vertex(graph, v)
end
return v
end

"""
Check if a data graph is a directed binary tree
"""
@traitfn function _is_directed_binary_tree(graph::AbstractDataGraph::IsDirected)
if !is_tree(undirected_graph(underlying_graph(graph)))
return false
end
for v in vertices(graph)
if !is_leaf(graph, v) && length(child_vertices(graph, v)) != 2
return false
end
end
return true
end

"""
Partition the input network containing both `tn` and `deltas` (a vector of delta tensors)
into two partitions, one adjacent to source_inds and the other adjacent to other external
Expand Down Expand Up @@ -134,31 +161,43 @@ Note: in the output partition, tensor vertex names will be changed. For a given
Note: for a given binary tree with n indices, the output partition will contain 2n-1 vertices,
with each leaf vertex corresponding to a sub tn adjacent to one output index. Keeping these
leaf vertices in the partition makes later `approx_itensornetwork` algorithms more efficient.
Note: name of vertices in the output partition can be different from the name of vertices
in `inds_btree`.
"""
function binary_tree_partition(tn::ITensorNetwork, inds_btree::Vector)
function partition(
::Algorithm"mincut_recursive_bisection", tn::ITensorNetwork, inds_btree::DataGraph
)
@assert _is_directed_binary_tree(inds_btree)
output_tns = Vector{ITensorNetwork}()
output_deltas_vector = Vector{Vector{ITensor}}()
# Mapping each vertex of the binary tree to a tn and a vector of deltas
# representing the partition of the subtree containing this vertex and
# its descendant vertices.
v_to_subtree_tn_deltas = Dict{Union{Vector,Index},Tuple}()
v_to_subtree_tn_deltas[inds_btree] = (tn, Vector{ITensor}())
for v in PreOrderDFS(inds_btree)
leaves = leaf_vertices(inds_btree)
root = _root(inds_btree)
v_to_subtree_tn_deltas = Dict{vertextype(inds_btree),Tuple}()
v_to_subtree_tn_deltas[root] = (tn, Vector{ITensor}())
for v in pre_order_dfs_vertices(inds_btree, root)
@assert haskey(v_to_subtree_tn_deltas, v)
input_tn, input_deltas = v_to_subtree_tn_deltas[v]
if v isa Index
if is_leaf(inds_btree, v)
push!(output_tns, input_tn)
push!(output_deltas_vector, input_deltas)
continue
end
c1, c2 = child_vertices(inds_btree, v)
descendant_c1 = pre_order_dfs_vertices(inds_btree, c1)
indices = [inds_btree[l] for l in intersect(descendant_c1, leaves)]
tn1, deltas1, input_tn, input_deltas = _binary_partition(
input_tn, input_deltas, collect(Leaves(v[1]))
input_tn, input_deltas, indices
)
v_to_subtree_tn_deltas[v[1]] = (tn1, deltas1)
v_to_subtree_tn_deltas[c1] = (tn1, deltas1)
descendant_c2 = pre_order_dfs_vertices(inds_btree, c2)
indices = [inds_btree[l] for l in intersect(descendant_c2, leaves)]
tn1, deltas1, input_tn, input_deltas = _binary_partition(
input_tn, input_deltas, collect(Leaves(v[2]))
input_tn, input_deltas, indices
)
v_to_subtree_tn_deltas[v[2]] = (tn1, deltas1)
v_to_subtree_tn_deltas[c2] = (tn1, deltas1)
push!(output_tns, input_tn)
push!(output_deltas_vector, input_deltas)
end
Expand All @@ -182,3 +221,7 @@ function binary_tree_partition(tn::ITensorNetwork, inds_btree::Vector)
tn_deltas = ITensorNetwork(vcat(output_deltas_vector...))
return partition(ITensorNetwork{Any}(disjoint_union(out_tn, tn_deltas)), subgraph_vs)
end

function partition(tn::ITensorNetwork, inds_btree::DataGraph; alg::String)
return partition(Algorithm(alg), tn, inds_btree)
end
4 changes: 2 additions & 2 deletions src/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ export AbstractITensorNetwork,
tdvp,
to_vec

# ITensorNetworks: binary_tree_partition.jl
export binary_tree_partition
# ITensorNetworks: mincut.jl
export path_graph_structure, binary_tree_structure

# ITensorNetworks: lattices.jl
# TODO: DELETE
Expand Down
85 changes: 67 additions & 18 deletions src/mincut.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,36 @@
# a large number to prevent this edge being a cut
MAX_WEIGHT = 1e32

"""
Outputs a maximimally unbalanced directed binary tree DataGraph defining the desired graph structure
"""
function path_graph_structure(tn::ITensorNetwork)
return path_graph_structure(tn, noncommoninds(Vector{ITensor}(tn)...))
end

"""
Given a `tn` and `outinds` (a subset of noncommoninds of `tn`), outputs a maximimally unbalanced
directed binary tree DataGraph of `outinds` defining the desired graph structure
"""
function path_graph_structure(tn::ITensorNetwork, outinds::Vector{<:Index})
return _binary_tree_structure(tn, outinds; maximally_unbalanced=true)
end

"""
Outputs a directed binary tree DataGraph defining the desired graph structure
"""
function binary_tree_structure(tn::ITensorNetwork)
return binary_tree_structure(tn, noncommoninds(Vector{ITensor}(tn)...))
end

"""
Given a `tn` and `outinds` (a subset of noncommoninds of `tn`), outputs a
directed binary tree DataGraph of `outinds` defining the desired graph structure
"""
function binary_tree_structure(tn::ITensorNetwork, outinds::Vector{<:Index})
return _binary_tree_structure(tn, outinds; maximally_unbalanced=false)
end

"""
Calculate the mincut between two subsets of the uncontracted inds
(source_inds and terminal_inds) of the input tn.
Expand Down Expand Up @@ -87,36 +117,31 @@ function _maxweightoutinds_tn(tn::ITensorNetwork, outinds::Union{Nothing,Vector{
end

"""
Given a tn and outinds (a subset of noncommoninds of tn),
get a binary tree structure of outinds that will be used in the binary tree partition.
Given a tn and outinds (a subset of noncommoninds of tn), get a `DataGraph`
with binary tree structure of outinds that will be used in the binary tree partition.
If maximally_unbalanced=true, the binary tree will have a line/mps structure.
The binary tree is recursively constructed from leaves to the root.

Example:
# TODO
"""
function _binary_tree_structure(
tn::ITensorNetwork, outinds::Vector{<:Index}; maximally_unbalanced::Bool=false
)
inds_tree_vector = _binary_tree_partition_inds(
tn, outinds; maximally_unbalanced=maximally_unbalanced
)
return _nested_vector_to_directed_tree(inds_tree_vector)
end

function _binary_tree_partition_inds(
tn::ITensorNetwork,
outinds::Union{Nothing,Vector{<:Index}};
maximally_unbalanced::Bool=false,
tn::ITensorNetwork, outinds::Vector{<:Index}; maximally_unbalanced::Bool=false
)
if outinds == nothing
outinds = noncommoninds(Vector{ITensor}(tn)...)
end
if length(outinds) == 1
return outinds
end
maxweight_tn, out_to_maxweight_ind = _maxweightoutinds_tn(tn, outinds)
return __binary_tree_partition_inds(
tn => maxweight_tn, out_to_maxweight_ind; maximally_unbalanced=maximally_unbalanced
)
end

function __binary_tree_partition_inds(
tn_pair::Pair{<:ITensorNetwork,<:ITensorNetwork},
out_to_maxweight_ind::Dict{Index,Index};
maximally_unbalanced::Bool=false,
)
tn_pair = tn => maxweight_tn
if maximally_unbalanced == false
return _binary_tree_partition_inds_mincut(tn_pair, out_to_maxweight_ind)
else
Expand All @@ -126,6 +151,30 @@ function __binary_tree_partition_inds(
end
end

function _nested_vector_to_directed_tree(inds_tree_vector::Vector)
if length(inds_tree_vector) == 1 && inds_tree_vector[1] isa Index
inds_btree = DataGraph(NamedDiGraph([1]), Index)
inds_btree[1] = inds_tree_vector[1]
return inds_btree
end
treenode_to_v = Dict{Union{Vector,Index},Int}()
graph = DataGraph(NamedDiGraph(), Index)
v = 1
for treenode in PostOrderDFS(inds_tree_vector)
add_vertex!(graph, v)
treenode_to_v[treenode] = v
if treenode isa Index
graph[v] = treenode
else
@assert length(treenode) == 2
add_edge!(graph, v, treenode_to_v[treenode[1]])
add_edge!(graph, v, treenode_to_v[treenode[2]])
end
v += 1
end
return graph
end

"""
Given a tn and outinds, returns a vector of indices representing MPS inds ordering.
"""
Expand Down
29 changes: 12 additions & 17 deletions test/test_binary_tree_partition.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using ITensors
using ITensorNetworks:
_binary_tree_partition_inds, _mps_partition_inds_order, _mincut_partitions
_mps_partition_inds_order, _mincut_partitions, _is_directed_binary_tree

@testset "test mincut functions on top of MPS" begin
i = Index(2, "i")
Expand All @@ -15,12 +15,11 @@ using ITensorNetworks:
T = randomITensor(i, j, k, l, m, n, o, p)
M = MPS(T, (i, j, k, l, m, n, o, p); cutoff=1e-5, maxdim=500)
tn = ITensorNetwork(M[:])
out = _binary_tree_partition_inds(
tn, [i, j, k, l, m, n, o, p]; maximally_unbalanced=false
)
@test length(out) == 2
out = _binary_tree_partition_inds(tn, [i, j, k, l, m, n, o, p]; maximally_unbalanced=true)
@test length(out) == 2
for out in [binary_tree_structure(tn), path_graph_structure(tn)]
@test out isa DataGraph
@test _is_directed_binary_tree(out)
@test length(vertex_data(out).values) == 8
end
out = _mps_partition_inds_order(tn, [o, p, i, j, k, l, m, n])
@test out in [[i, j, k, l, m, n, o, p], [p, o, n, m, l, k, j, i]]
p1, p2 = _mincut_partitions(tn, [k, l], [m, n])
Expand All @@ -44,14 +43,11 @@ end
tn[v...] = network[v...]
end
tn = ITensorNetwork(vec(tn[:, :, 1]))
out = _binary_tree_partition_inds(
tn, noncommoninds(Vector{ITensor}(tn)...); maximally_unbalanced=false
)
@test length(out) == 2
out = _binary_tree_partition_inds(
tn, noncommoninds(Vector{ITensor}(tn)...); maximally_unbalanced=true
)
@test length(out) == 2
for out in [binary_tree_structure(tn), path_graph_structure(tn)]
@test out isa DataGraph
@test _is_directed_binary_tree(out)
@test length(vertex_data(out).values) == 9
end
end

@testset "test binary_tree_partition" begin
Expand All @@ -65,8 +61,7 @@ end
network = M[:]
out1 = contract(network...)
tn = ITensorNetwork(network)
inds_btree = _binary_tree_partition_inds(tn, [i, j, k, l, m]; maximally_unbalanced=false)
par = binary_tree_partition(tn, inds_btree)
par = partition(tn, binary_tree_structure(tn); alg="mincut_recursive_bisection")
networks = [Vector{ITensor}(par[v]) for v in vertices(par)]
network2 = vcat(networks...)
out2 = contract(network2...)
Expand Down