Skip to content

[binary_tree_partition] [1/2]: Add mincut helper functions and introduce _binary_tree_partition_inds #42

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ version = "0.3.1"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a"
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
GraphsFlows = "06909019-6f44-4949-96fc-b9d9aaa02889"
ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5"
IsApprox = "28f27b66-4bd8-47e7-9110-e2746eb8bed7"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
Expand All @@ -28,11 +30,13 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"

[compat]
AbstractTrees = "0.4.4"
Combinatorics = "1"
Compat = "3, 4"
DataGraphs = "0.1.7"
Dictionaries = "0.3.15"
DocStringExtensions = "0.8, 0.9"
Graphs = "1.6"
GraphsFlows = "0.1.1"
ITensors = "0.3.23"
IsApprox = "0.1.7"
IterTools = "1.4.0"
Expand Down
3 changes: 3 additions & 0 deletions src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
module ITensorNetworks

using AbstractTrees
using Combinatorics
using Compat
using DataGraphs
using Dictionaries
using DocStringExtensions
using Graphs
using GraphsFlows
using Graphs.SimpleGraphs # AbstractSimpleGraph
using IsApprox
using ITensors
Expand Down Expand Up @@ -77,6 +79,7 @@ include("expect.jl")
include("models.jl")
include("tebd.jl")
include("itensornetwork.jl")
include("mincut.jl")
include("utility.jl")
include("specialitensornetworks.jl")
include("renameitensornetwork.jl")
Expand Down
3 changes: 2 additions & 1 deletion src/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ export Key,
incident_edges,
comb_tree,
named_comb_tree,
subgraph
subgraph,
mincut_partitions

# DataGraphs
export DataGraph, vertex_data, edge_data, underlying_graph
Expand Down
1 change: 1 addition & 0 deletions src/imports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import NamedGraphs:
vertex_to_parent_vertex,
rename_vertices,
disjoint_union,
mincut_partitions,
incident_edges

import .DataGraphs:
Expand Down
225 changes: 225 additions & 0 deletions src/mincut.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
# a large number to prevent this edge being a cut
MAX_WEIGHT = 1e32

"""
Calculate the mincut between two subsets of the uncontracted inds
(source_inds and terminal_inds) of the input tn.
Mincut of two inds list is defined as the mincut of two newly added vertices,
each one neighboring to one inds subset.
"""
function _mincut(
tn::ITensorNetwork, source_inds::Vector{<:Index}, terminal_inds::Vector{<:Index}
)
@assert length(source_inds) >= 1
@assert length(terminal_inds) >= 1
noncommon_inds = noncommoninds(Vector{ITensor}(tn)...)
@assert issubset(source_inds, noncommon_inds)
@assert issubset(terminal_inds, noncommon_inds)
tn = disjoint_union(
ITensorNetwork([ITensor(source_inds...), ITensor(terminal_inds...)]), tn
)
return GraphsFlows.mincut(tn, (1, 1), (2, 1), weights(tn))
end

"""
Calculate the mincut_partitions between two subsets of the uncontracted inds
(source_inds and terminal_inds) of the input tn.
"""
function _mincut_partitions(
tn::ITensorNetwork, source_inds::Vector{<:Index}, terminal_inds::Vector{<:Index}
)
p1, p2, cut = _mincut(tn, source_inds, terminal_inds)
p1 = [v[1] for v in p1 if v[2] == 2]
p2 = [v[1] for v in p2 if v[2] == 2]
return p1, p2
end

"""
Sum of shortest path distances among all outinds.
"""
function _distance(tn::ITensorNetwork, outinds::Vector{<:Index})
@assert length(outinds) >= 1
@assert issubset(outinds, noncommoninds(Vector{ITensor}(tn)...))
if length(outinds) == 1
return 0.0
end
new_tensors = [ITensor(i) for i in outinds]
tn = disjoint_union(ITensorNetwork(new_tensors), tn)
distances = 0.0
for i in 1:(length(new_tensors) - 1)
ds = dijkstra_shortest_paths(tn, [(i, 1)], weights(tn))
for j in (i + 1):length(new_tensors)
distances += ds.dists[(j, 1)]
end
end
return distances
end

"""
create a tn with empty ITensors whose outinds weights are MAX_WEIGHT
The maxweight_tn is constructed so that only commoninds of the tn
will be considered in mincut.
"""
function _maxweightoutinds_tn(tn::ITensorNetwork, outinds::Union{Nothing,Vector{<:Index}})
@assert issubset(outinds, noncommoninds(Vector{ITensor}(tn)...))
out_to_maxweight_ind = Dict{Index,Index}()
for ind in outinds
out_to_maxweight_ind[ind] = Index(MAX_WEIGHT, ind.tags)
end
maxweight_tn = copy(tn)
for v in vertices(maxweight_tn)
t = maxweight_tn[v]
inds1 = [i for i in inds(t) if !(i in outinds)]
inds2 = [out_to_maxweight_ind[i] for i in inds(t) if i in outinds]
newt = ITensor(inds1..., inds2...)
maxweight_tn[v] = newt
end
return maxweight_tn, out_to_maxweight_ind
end

"""
Given a tn and outinds (a subset of noncommoninds of tn),
get a binary tree structure of outinds that will be used in the binary tree partition.
If maximally_unbalanced=true, the binary tree will have a line/mps structure.
The binary tree is recursively constructed from leaves to the root.

Example:
# TODO
"""
function _binary_tree_partition_inds(
tn::ITensorNetwork,
outinds::Union{Nothing,Vector{<:Index}};
maximally_unbalanced::Bool=false,
)
if outinds == nothing
outinds = noncommoninds(Vector{ITensor}(tn)...)
end
if length(outinds) == 1
return outinds
end
maxweight_tn, out_to_maxweight_ind = _maxweightoutinds_tn(tn, outinds)
return __binary_tree_partition_inds(
tn => maxweight_tn, out_to_maxweight_ind; maximally_unbalanced=maximally_unbalanced
)
end

function __binary_tree_partition_inds(
tn_pair::Pair{<:ITensorNetwork,<:ITensorNetwork},
out_to_maxweight_ind::Dict{Index,Index};
maximally_unbalanced::Bool=false,
)
if maximally_unbalanced == false
return _binary_tree_partition_inds_mincut(tn_pair, out_to_maxweight_ind)
else
return line_to_tree(
_binary_tree_partition_inds_maximally_unbalanced(tn_pair, out_to_maxweight_ind)
)
end
end

"""
Given a tn and outinds, returns a vector of indices representing MPS inds ordering.
"""
function _mps_partition_inds_order(
tn::ITensorNetwork, outinds::Union{Nothing,Vector{<:Index}}
)
if outinds == nothing
outinds = noncommoninds(Vector{ITensor}(tn)...)
end
if length(outinds) == 1
return outinds
end
tn2, out_to_maxweight_ind = _maxweightoutinds_tn(tn, outinds)
return _binary_tree_partition_inds_maximally_unbalanced(tn => tn2, out_to_maxweight_ind)
end

function _binary_tree_partition_inds_maximally_unbalanced(
tn_pair::Pair{<:ITensorNetwork,<:ITensorNetwork}, out_to_maxweight_ind::Dict{Index,Index}
)
outinds = collect(keys(out_to_maxweight_ind))
@assert length(outinds) >= 1
if length(outinds) <= 2
return outinds
end
first_inds, _ = _mincut_inds(
tn_pair, out_to_maxweight_ind, collect(powerset(outinds, 1, 1))
)
first_ind = first_inds[1]
linear_order = [first_ind]
outinds = setdiff(outinds, linear_order)
while length(outinds) > 1
sourceinds_list = [Vector{Index}([linear_order..., i]) for i in outinds]
target_inds, _ = _mincut_inds(tn_pair, out_to_maxweight_ind, sourceinds_list)
new_ind = setdiff(target_inds, linear_order)[1]
push!(linear_order, new_ind)
outinds = setdiff(outinds, [new_ind])
end
push!(linear_order, outinds[1])
return linear_order
end

function _binary_tree_partition_inds_mincut(
tn_pair::Pair{<:ITensorNetwork,<:ITensorNetwork}, out_to_maxweight_ind::Dict{Index,Index}
)
outinds = collect(keys(out_to_maxweight_ind))
@assert length(outinds) >= 1
if length(outinds) <= 2
return outinds
end
while length(outinds) > 2
tree_list = collect(powerset(outinds, 2, 2))
sourceinds_list = [collect(Leaves(i)) for i in tree_list]
_, i = _mincut_inds(tn_pair, out_to_maxweight_ind, sourceinds_list)
tree = tree_list[i]
outinds = setdiff(outinds, tree)
outinds = vcat([tree], outinds)
end
return outinds
end

"""
Find a vector of indices within sourceinds_list yielding the mincut of given tn_pair.
Args:
tn_pair: a pair of tns (tn1 => tn2), where tn2 is generated via _maxweightoutinds_tn(tn1)
out_to_maxweight_ind: a dict mapping each out ind in tn1 to out ind in tn2
sourceinds_list: a list of vector of indices to be considered
Note:
For each sourceinds in sourceinds_list, we consider its mincut within both tns (tn1, tn2) given in tn_pair.
The mincut in tn1 represents the rank upper bound when splitting sourceinds with other inds in outinds.
The mincut in tn2 represents the rank upper bound when the weights of outinds are very large.
The first mincut upper_bounds the number of non-zero singular values, while the second empirically reveals the
singular value decay.
We output the sourceinds where the first mincut value is the minimum, the secound mincut value is also
the minimum under the condition that the first mincut is optimal, and the sourceinds have the lowest all-pair shortest path.
"""
function _mincut_inds(
tn_pair::Pair{<:ITensorNetwork,<:ITensorNetwork},
out_to_maxweight_ind::Dict{Index,Index},
sourceinds_list::Vector{<:Vector{<:Index}},
)
function _mincut_value(tn, sinds, outinds)
tinds = setdiff(outinds, sinds)
_, _, cut = _mincut(tn, sinds, tinds)
return cut
end
function _get_weights(source_inds, outinds, maxweight_source_inds, maxweight_outinds)
mincut_val = _mincut_value(tn_pair.first, source_inds, outinds)
maxweight_mincut_val = _mincut_value(
tn_pair.second, maxweight_source_inds, maxweight_outinds
)
dist = _distance(tn_pair.first, source_inds)
return (mincut_val, maxweight_mincut_val, dist)
end

outinds = collect(keys(out_to_maxweight_ind))
maxweight_outinds = collect(values(out_to_maxweight_ind))
weights = []
for source_inds in sourceinds_list
maxweight_source_inds = [out_to_maxweight_ind[i] for i in source_inds]
push!(
weights, _get_weights(source_inds, outinds, maxweight_source_inds, maxweight_outinds)
)
end
i = argmin(weights)
return sourceinds_list[i], i
end
11 changes: 11 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,14 @@ maybe_only(x::Tuple{T}) where {T} = only(x)

front(itr, n=1) = Iterators.take(itr, length(itr) - n)
tail(itr) = Iterators.drop(itr, 1)

# Tree utils
function line_to_tree(line::Vector)
if length(line) == 1 && line[1] isa Vector
return line[1]
end
if length(line) <= 2
return line
end
return [line_to_tree(line[1:(end - 1)]), line[end]]
end
55 changes: 55 additions & 0 deletions test/test_mincut.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
using ITensors
using ITensorNetworks:
_binary_tree_partition_inds, _mps_partition_inds_order, _mincut_partitions

@testset "test mincut functions on top of MPS" begin
i = Index(2, "i")
j = Index(2, "j")
k = Index(2, "k")
l = Index(2, "l")
m = Index(2, "m")
n = Index(2, "n")
o = Index(2, "o")
p = Index(2, "p")

T = randomITensor(i, j, k, l, m, n, o, p)
M = MPS(T, (i, j, k, l, m, n, o, p); cutoff=1e-5, maxdim=500)
tn = ITensorNetwork(M[:])
out = _binary_tree_partition_inds(
tn, [i, j, k, l, m, n, o, p]; maximally_unbalanced=false
)
@test length(out) == 2
out = _binary_tree_partition_inds(tn, [i, j, k, l, m, n, o, p]; maximally_unbalanced=true)
@test length(out) == 2
out = _mps_partition_inds_order(tn, [o, p, i, j, k, l, m, n])
@test out in [[i, j, k, l, m, n, o, p], [p, o, n, m, l, k, j, i]]
p1, p2 = _mincut_partitions(tn, [k, l], [m, n])
# When MPS bond dimensions are large, the partition will not across internal inds
@test (length(p1) == 0) || (length(p2) == 0)

M = MPS(T, (i, j, k, l, m, n, o, p); cutoff=1e-5, maxdim=2)
tn = ITensorNetwork(M[:])
p1, p2 = _mincut_partitions(tn, [k, l], [m, n])
# When MPS bond dimensions are small, the partition will across internal inds
@test sort(p1) == [1, 2, 3, 4]
@test sort(p2) == [5, 6, 7, 8]
end

@testset "test inds_binary_tree of a 2D network" begin
N = (3, 3, 3)
linkdim = 2
network = randomITensorNetwork(IndsNetwork(named_grid(N)); link_space=linkdim)
tn = Array{ITensor,length(N)}(undef, N...)
for v in vertices(network)
tn[v...] = network[v...]
end
tn = ITensorNetwork(vec(tn[:, :, 1]))
out = _binary_tree_partition_inds(
tn, noncommoninds(Vector{ITensor}(tn)...); maximally_unbalanced=false
)
@test length(out) == 2
out = _binary_tree_partition_inds(
tn, noncommoninds(Vector{ITensor}(tn)...); maximally_unbalanced=true
)
@test length(out) == 2
end