Skip to content

Commit 536faaa

Browse files
authored
Introduce _contract_deltas for removing trivial delta from a tensor network
1 parent e82baef commit 536faaa

File tree

6 files changed

+240
-144
lines changed

6 files changed

+240
-144
lines changed

src/Graphs/abstractgraph.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,6 @@ function internal_edges(g::AbstractGraph)
1818
return filter(e -> !is_leaf_edge(g, e), edges(g))
1919
end
2020

21-
"""Get all vertices which are leaves of a graph"""
22-
function leaf_vertices(g::AbstractGraph)
23-
return vertices(g)[findall(==(1), [is_leaf(g, v) for v in vertices(g)])]
24-
end
25-
2621
"""Get distance of a vertex from a leaf"""
2722
function distance_to_leaf(g::AbstractGraph, v)
2823
leaves = leaf_vertices(g)

src/ITensorNetworks.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ include("models.jl")
8383
include("tebd.jl")
8484
include("itensornetwork.jl")
8585
include("mincut.jl")
86+
include("contract_deltas.jl")
8687
include("binary_tree_partition.jl")
8788
include("utility.jl")
8889
include("specialitensornetworks.jl")

src/binary_tree_partition.jl

Lines changed: 32 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -1,126 +1,25 @@
1-
"""
2-
Rewrite of the function
3-
`DataStructures.root_union!(s::IntDisjointSet{T}, x::T, y::T) where {T<:Integer}`.
4-
"""
5-
function _introot_union!(s::DataStructures.IntDisjointSets, x, y; left_root=true)
6-
parents = s.parents
7-
rks = s.ranks
8-
@inbounds xrank = rks[x]
9-
@inbounds yrank = rks[y]
10-
if !left_root
11-
x, y = y, x
12-
end
13-
@inbounds parents[y] = x
14-
s.ngroups -= 1
15-
return x
16-
end
17-
18-
"""
19-
Rewrite of the function `DataStructures.root_union!(s::DisjointSet{T}, x::T, y::T)`.
20-
The difference is that in the output of `_root_union!`, x is guaranteed to be the root of y when
21-
setting `left_root=true`, and y will be the root of x when setting `left_root=false`.
22-
In `DataStructures.root_union!`, the root value cannot be specified.
23-
A specified root is useful in functions such as `_remove_deltas`, where when we union two
24-
indices into one disjointset, we want the index that is the outinds if the given tensor network
25-
to always be the root in the DisjointSets.
26-
"""
27-
function _root_union!(s::DisjointSets, x, y; left_root=true)
28-
return s.revmap[_introot_union!(s.internal, s.intmap[x], s.intmap[y]; left_root=true)]
29-
end
30-
31-
"""
32-
Partition the input network containing both `tn` and `deltas` (a vector of delta tensors)
33-
into two partitions, one adjacent to source_inds and the other adjacent to other external
34-
inds of the network.
35-
"""
36-
function _binary_partition(
37-
tn::ITensorNetwork, deltas::Vector{ITensor}, source_inds::Vector{<:Index}
38-
)
39-
all_tensors = [Vector{ITensor}(tn)..., deltas...]
40-
external_inds = noncommoninds(all_tensors...)
1+
function _binary_partition(tn::ITensorNetwork, source_inds::Vector{<:Index})
2+
external_inds = noncommoninds(Vector{ITensor}(tn)...)
413
# add delta tensor to each external ind
424
external_sim_ind = [sim(ind) for ind in external_inds]
5+
tn = map_data(t -> replaceinds(t, external_inds => external_sim_ind), tn; edges=[])
6+
tn_wo_deltas = rename_vertices(v -> v[1], subgraph(v -> v[2] == 1, tn))
7+
deltas = Vector{ITensor}(subgraph(v -> v[2] == 2, tn))
438
new_deltas = [
449
delta(external_inds[i], external_sim_ind[i]) for i in 1:length(external_inds)
4510
]
46-
deltas = map(t -> replaceinds(t, external_inds => external_sim_ind), deltas)
4711
deltas = [deltas..., new_deltas...]
48-
tn = map_data(t -> replaceinds(t, external_inds => external_sim_ind), tn; edges=[])
12+
tn = disjoint_union(tn_wo_deltas, ITensorNetwork(deltas))
4913
p1, p2 = _mincut_partition_maxweightoutinds(
50-
disjoint_union(tn, ITensorNetwork(deltas)),
51-
source_inds,
52-
setdiff(external_inds, source_inds),
14+
tn, source_inds, setdiff(external_inds, source_inds)
5315
)
54-
tn_vs = [v[1] for v in p1 if v[2] == 1]
55-
source_tn = subgraph(tn, tn_vs)
56-
delta_indices = [v[1] for v in p1 if v[2] == 2]
57-
source_deltas = Vector{ITensor}([deltas[i] for i in delta_indices])
58-
source_tn, source_deltas = _remove_deltas(source_tn, source_deltas)
59-
tn_vs = [v[1] for v in p2 if v[2] == 1]
60-
remain_tn = subgraph(tn, tn_vs)
61-
delta_indices = [v[1] for v in p2 if v[2] == 2]
62-
remain_deltas = Vector{ITensor}([deltas[i] for i in delta_indices])
63-
remain_tn, remain_deltas = _remove_deltas(remain_tn, remain_deltas)
16+
source_tn = _contract_deltas(subgraph(tn, p1))
17+
remain_tn = _contract_deltas(subgraph(tn, p2))
6418
@assert (
65-
length(noncommoninds(all_tensors...)) == length(
66-
noncommoninds(
67-
Vector{ITensor}(source_tn)...,
68-
source_deltas...,
69-
Vector{ITensor}(remain_tn)...,
70-
remain_deltas...,
71-
),
72-
)
73-
)
74-
return source_tn, source_deltas, remain_tn, remain_deltas
75-
end
76-
77-
"""
78-
Given an input tensor network containing tensors in the input `tn`` and
79-
tensors in `deltas``, remove redundent delta tensors in `deltas` and change
80-
inds accordingly to make the output `tn` and `out_deltas` represent the same
81-
tensor network but with less delta tensors.
82-
Note: inds of tensors in `tn` and `deltas` may be changed, and `out_deltas`
83-
may still contain necessary delta tensors.
84-
85-
========
86-
Example:
87-
julia> is = [Index(2, "i") for i in 1:6]
88-
julia> a = ITensor(is[1], is[2])
89-
julia> b = ITensor(is[2], is[3])
90-
julia> delta1 = delta(is[3], is[4])
91-
julia> delta2 = delta(is[5], is[6])
92-
julia> tn = ITensorNetwork([a,b])
93-
julia> tn, out_deltas = ITensorNetworks._remove_deltas(tn, [delta1, delta2])
94-
julia> noncommoninds(Vector{ITensor}(tn)...)
95-
2-element Vector{Index{Int64}}:
96-
(dim=2|id=339|"1")
97-
(dim=2|id=489|"4")
98-
julia> length(out_deltas)
99-
1
100-
"""
101-
function _remove_deltas(tn::ITensorNetwork, deltas::Vector{ITensor})
102-
out_delta_inds = Vector{Pair}()
103-
network = [Vector{ITensor}(tn)..., deltas...]
104-
outinds = noncommoninds(network...)
105-
inds_list = map(t -> collect(inds(t)), deltas)
106-
deltainds = collect(Set(vcat(inds_list...)))
107-
ds = DisjointSets(deltainds)
108-
for t in deltas
109-
i1, i2 = inds(t)
110-
if find_root!(ds, i1) in outinds && find_root!(ds, i2) in outinds
111-
push!(out_delta_inds, find_root!(ds, i1) => find_root!(ds, i2))
112-
end
113-
if find_root!(ds, i1) in outinds
114-
_root_union!(ds, find_root!(ds, i1), find_root!(ds, i2))
115-
else
116-
_root_union!(ds, find_root!(ds, i2), find_root!(ds, i1))
117-
end
118-
end
119-
tn = map_data(
120-
t -> replaceinds(t, deltainds => [find_root!(ds, i) for i in deltainds]), tn; edges=[]
19+
length(external_inds) ==
20+
length(noncommoninds(Vector{ITensor}(source_tn)..., Vector{ITensor}(remain_tn)...))
12121
)
122-
out_deltas = Vector{ITensor}([delta(i.first, i.second) for i in out_delta_inds])
123-
return tn, out_deltas
22+
return source_tn, remain_tn
12423
end
12524

12625
"""
@@ -143,36 +42,30 @@ function partition(
14342
@assert _is_rooted_directed_binary_tree(inds_btree)
14443
output_tns = Vector{ITensorNetwork}()
14544
output_deltas_vector = Vector{Vector{ITensor}}()
146-
# Mapping each vertex of the binary tree to a tn and a vector of deltas
147-
# representing the partition of the subtree containing this vertex and
148-
# its descendant vertices.
45+
# Mapping each vertex of the binary tree to a tn representing the partition
46+
# of the subtree containing this vertex and its descendant vertices.
14947
leaves = leaf_vertices(inds_btree)
15048
root = _root(inds_btree)
151-
v_to_subtree_tn_deltas = Dict{vertextype(inds_btree),Tuple}()
152-
v_to_subtree_tn_deltas[root] = (tn, Vector{ITensor}())
49+
v_to_subtree_tn = Dict{vertextype(inds_btree),ITensorNetwork}()
50+
v_to_subtree_tn[root] = disjoint_union(tn, ITensorNetwork())
15351
for v in pre_order_dfs_vertices(inds_btree, root)
154-
@assert haskey(v_to_subtree_tn_deltas, v)
155-
input_tn, input_deltas = v_to_subtree_tn_deltas[v]
156-
if is_leaf(inds_btree, v)
157-
push!(output_tns, input_tn)
158-
push!(output_deltas_vector, input_deltas)
159-
continue
52+
@assert haskey(v_to_subtree_tn, v)
53+
input_tn = v_to_subtree_tn[v]
54+
if !is_leaf(inds_btree, v)
55+
c1, c2 = child_vertices(inds_btree, v)
56+
descendant_c1 = pre_order_dfs_vertices(inds_btree, c1)
57+
indices = [inds_btree[l] for l in intersect(descendant_c1, leaves)]
58+
tn1, input_tn = _binary_partition(input_tn, indices)
59+
v_to_subtree_tn[c1] = tn1
60+
descendant_c2 = pre_order_dfs_vertices(inds_btree, c2)
61+
indices = [inds_btree[l] for l in intersect(descendant_c2, leaves)]
62+
tn1, input_tn = _binary_partition(input_tn, indices)
63+
v_to_subtree_tn[c2] = tn1
16064
end
161-
c1, c2 = child_vertices(inds_btree, v)
162-
descendant_c1 = pre_order_dfs_vertices(inds_btree, c1)
163-
indices = [inds_btree[l] for l in intersect(descendant_c1, leaves)]
164-
tn1, deltas1, input_tn, input_deltas = _binary_partition(
165-
input_tn, input_deltas, indices
166-
)
167-
v_to_subtree_tn_deltas[c1] = (tn1, deltas1)
168-
descendant_c2 = pre_order_dfs_vertices(inds_btree, c2)
169-
indices = [inds_btree[l] for l in intersect(descendant_c2, leaves)]
170-
tn1, deltas1, input_tn, input_deltas = _binary_partition(
171-
input_tn, input_deltas, indices
172-
)
173-
v_to_subtree_tn_deltas[c2] = (tn1, deltas1)
174-
push!(output_tns, input_tn)
175-
push!(output_deltas_vector, input_deltas)
65+
tn = rename_vertices(u -> u[1], subgraph(u -> u[2] == 1, input_tn))
66+
deltas = Vector{ITensor}(subgraph(u -> u[2] == 2, input_tn))
67+
push!(output_tns, tn)
68+
push!(output_deltas_vector, deltas)
17669
end
17770
# In subgraph_vertices, each element is a vector of vertices to be
17871
# grouped in one partition.

src/contract_deltas.jl

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
"""
2+
Rewrite of the function
3+
`DataStructures.root_union!(s::IntDisjointSet{T}, x::T, y::T) where {T<:Integer}`.
4+
"""
5+
function _introot_union!(s::DataStructures.IntDisjointSets, x, y; left_root=true)
6+
parents = s.parents
7+
rks = s.ranks
8+
@inbounds xrank = rks[x]
9+
@inbounds yrank = rks[y]
10+
if !left_root
11+
x, y = y, x
12+
end
13+
@inbounds parents[y] = x
14+
s.ngroups -= 1
15+
return x
16+
end
17+
18+
"""
19+
Rewrite of the function `DataStructures.root_union!(s::DisjointSet{T}, x::T, y::T)`.
20+
The difference is that in the output of `_root_union!`, x is guaranteed to be the root of y when
21+
setting `left_root=true`, and y will be the root of x when setting `left_root=false`.
22+
In `DataStructures.root_union!`, the root value cannot be specified.
23+
A specified root is useful in functions such as `_remove_deltas`, where when we union two
24+
indices into one disjointset, we want the index that is the outinds if the given tensor network
25+
to always be the root in the DisjointSets.
26+
"""
27+
function _root_union!(s::DisjointSets, x, y; left_root=true)
28+
return s.revmap[_introot_union!(s.internal, s.intmap[x], s.intmap[y]; left_root=true)]
29+
end
30+
31+
"""
32+
Given a list of delta tensors `deltas`, return a `DisjointSets` of all its indices
33+
such that each pair of indices adjacent to any delta tensor must be in the same disjoint set.
34+
If a disjoint set contains indices in `rootinds`, then one of such indices in `rootinds`
35+
must be the root of this set.
36+
"""
37+
function _delta_inds_disjointsets(deltas::Vector{<:ITensor}, rootinds::Vector{<:Index})
38+
if deltas == []
39+
return DisjointSets()
40+
end
41+
inds_list = map(t -> collect(inds(t)), deltas)
42+
deltainds = collect(Set(vcat(inds_list...)))
43+
ds = DisjointSets(deltainds)
44+
for t in deltas
45+
i1, i2 = inds(t)
46+
if find_root!(ds, i1) in rootinds
47+
_root_union!(ds, find_root!(ds, i1), find_root!(ds, i2))
48+
else
49+
_root_union!(ds, find_root!(ds, i2), find_root!(ds, i1))
50+
end
51+
end
52+
return ds
53+
end
54+
55+
"""
56+
Given an input tensor network `tn`, remove redundent delta tensors
57+
in `tn` and change inds accordingly to make the output `tn` represent
58+
the same tensor network but with less delta tensors.
59+
60+
========
61+
Example:
62+
julia> is = [Index(2, string(i)) for i in 1:6]
63+
julia> a = ITensor(is[1], is[2])
64+
julia> b = ITensor(is[2], is[3])
65+
julia> delta1 = delta(is[3], is[4])
66+
julia> delta2 = delta(is[5], is[6])
67+
julia> tn = ITensorNetwork([a, b, delta1, delta2])
68+
julia> ITensorNetworks._contract_deltas(tn)
69+
ITensorNetwork{Int64} with 3 vertices:
70+
3-element Vector{Int64}:
71+
1
72+
2
73+
4
74+
75+
and 1 edge(s):
76+
1 => 2
77+
78+
with vertex data:
79+
3-element Dictionaries.Dictionary{Int64, Any}
80+
1 │ ((dim=2|id=457|"1"), (dim=2|id=296|"2"))
81+
2 │ ((dim=2|id=296|"2"), (dim=2|id=613|"4"))
82+
4 │ ((dim=2|id=626|"6"), (dim=2|id=237|"5"))
83+
"""
84+
function _contract_deltas(tn::ITensorNetwork)
85+
tn = copy(tn)
86+
network = Vector{ITensor}(tn)
87+
deltas = filter(t -> is_delta(t), network)
88+
outinds = noncommoninds(network...)
89+
ds = _delta_inds_disjointsets(deltas, outinds)
90+
deltainds = [ds...]
91+
sim_deltainds = [find_root!(ds, i) for i in deltainds]
92+
# `rem_vertex!(tn, v)` changes `vertices(tn)` in place.
93+
# We copy it here so that the enumeration won't be affected.
94+
for v in copy(vertices(tn))
95+
if !is_delta(tn[v])
96+
tn[v] = replaceinds(tn[v], deltainds, sim_deltainds)
97+
continue
98+
end
99+
i1, i2 = inds(tn[v])
100+
root = find_root!(ds, i1)
101+
@assert root === find_root!(ds, i2)
102+
if i1 != root && i1 in outinds
103+
tn[v] = delta(i1, root)
104+
elseif i2 != root && i2 in outinds
105+
tn[v] = delta(i2, root)
106+
else
107+
rem_vertex!(tn, v)
108+
end
109+
end
110+
return tn
111+
end
112+
113+
"""
114+
TODO: do we want to make it a public function?
115+
"""
116+
function _noncommoninds(partition::DataGraph)
117+
networks = [Vector{ITensor}(partition[v]) for v in vertices(partition)]
118+
network = vcat(networks...)
119+
return noncommoninds(network...)
120+
end
121+
122+
"""
123+
Given an input `partition`, contract redundent delta tensors of non-leaf vertices
124+
in `partition` without changing the tensor network value.
125+
`root` is the root of the dfs_tree that defines the leaves.
126+
Note: for each vertex `v` of `partition`, the number of non-delta tensors
127+
in `partition[v]` will not be changed.
128+
Note: only delta tensors of non-leaf vertices will be contracted.
129+
Note: this function assumes that all noncommoninds of the partition are in leaf partitions.
130+
"""
131+
function _contract_deltas_ignore_leaf_partitions(
132+
partition::DataGraph; root=first(vertices(partition))
133+
)
134+
partition = copy(partition)
135+
leaves = leaf_vertices(dfs_tree(partition, root))
136+
nonleaves = setdiff(vertices(partition), leaves)
137+
rootinds = _noncommoninds(subgraph(partition, nonleaves))
138+
# check rootinds are not noncommoninds of the partition
139+
@assert intersect(rootinds, _noncommoninds(partition)) == []
140+
nonleaves_tn = _contract_deltas(reduce(union, [partition[v] for v in nonleaves]))
141+
nondelta_vs = filter(v -> !is_delta(nonleaves_tn[v]), vertices(nonleaves_tn))
142+
for v in nonleaves
143+
partition[v] = subgraph(nonleaves_tn, intersect(nondelta_vs, vertices(partition[v])))
144+
end
145+
# Note: we also need to change inds in the leaves since they can be connected by deltas
146+
# in nonleaf vertices
147+
delta_vs = setdiff(vertices(nonleaves_tn), nondelta_vs)
148+
if delta_vs == []
149+
return partition
150+
end
151+
ds = _delta_inds_disjointsets(
152+
Vector{ITensor}(subgraph(nonleaves_tn, delta_vs)), Vector{Index}()
153+
)
154+
deltainds = [ds...]
155+
sim_deltainds = [find_root!(ds, ind) for ind in deltainds]
156+
for tn_v in leaves
157+
partition[tn_v] = map_data(
158+
t -> replaceinds(t, deltainds, sim_deltainds), partition[tn_v]; edges=[]
159+
)
160+
end
161+
return partition
162+
end

src/itensors.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,9 @@ trivial_space(x::Index) = _trivial_space(x)
4141
trivial_space(x::Vector{<:Index}) = _trivial_space(x)
4242
trivial_space(x::ITensor) = trivial_space(inds(x))
4343
trivial_space(x::Tuple{Vararg{Index}}) = trivial_space(first(x))
44+
45+
is_delta(it::ITensor) = is_delta(ITensors.tensor(it))
46+
is_delta(t::ITensors.Tensor) = false
47+
function is_delta(t::ITensors.NDTensors.UniformDiagTensor)
48+
return isone(ITensors.NDTensors.getdiagindex(t, 1))
49+
end

0 commit comments

Comments
 (0)