Skip to content

Commit

Permalink
Fix a small bug related to caching in approx_itensornetwork (#89)
Browse files Browse the repository at this point in the history
  • Loading branch information
LinjianMa authored May 22, 2023
1 parent 176ed00 commit 472ceb6
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 32 deletions.
41 changes: 21 additions & 20 deletions src/approx_itensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -327,30 +327,31 @@ function _rem_vertex!(
rem_vertex!(alg_graph.out_tree, root)
# update es_to_pdm
truncate_dfs_tree = dfs_tree(alg_graph.out_tree, alg_graph.root)
for es in keys(caches.es_to_pdm)
if dst(first(es)) == root
delete!(caches.es_to_pdm, es)
elseif dst(first(es)) == new_root
parent_edge = NamedEdge(parent_vertex(truncate_dfs_tree, new_root), new_root)
edge_to_remove = NamedEdge(root, new_root)
if intersect(es, [parent_edge]) == []
new_es = setdiff(es, [edge_to_remove])
caches.es_to_pdm[new_es] = _optcontract(
[caches.es_to_pdm[es], root_tensor];
contraction_sequence_alg,
contraction_sequence_kwargs,
)
for es in filter(es -> dst(first(es)) == root, keys(caches.es_to_pdm))
delete!(caches.es_to_pdm, es)
end
for es in filter(es -> dst(first(es)) == new_root, keys(caches.es_to_pdm))
parent_edge = NamedEdge(parent_vertex(truncate_dfs_tree, new_root), new_root)
edge_to_remove = NamedEdge(root, new_root)
if intersect(es, Set([parent_edge])) == Set()
new_es = setdiff(es, [edge_to_remove])
if new_es == Set()
new_es = Set([NamedEdge(nothing, new_root)])
end
# Remove old caches since they won't be used anymore,
# and removing them saves later contraction costs.
delete!(caches.es_to_pdm, es)
@assert length(new_es) >= 1
caches.es_to_pdm[new_es] = _optcontract(
[caches.es_to_pdm[es], root_tensor];
contraction_sequence_alg,
contraction_sequence_kwargs,
)
end
# Remove old caches since they won't be used anymore,
# and removing them saves later contraction costs.
delete!(caches.es_to_pdm, es)
end
# update e_to_dm
for edge in keys(caches.e_to_dm)
if dst(edge) in [root, new_root]
delete!(caches.e_to_dm, edge)
end
for edge in filter(e -> dst(e) in [root, new_root], keys(caches.e_to_dm))
delete!(caches.e_to_dm, edge)
end
return U
end
Expand Down
65 changes: 53 additions & 12 deletions test/test_binary_tree_partition.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
using ITensors, OMEinsumContractionOrders
using Graphs, NamedGraphs
using ITensors: contract
using ITensorNetworks:
_root,
_mps_partition_inds_order,
_mincut_partitions,
_is_rooted_directed_binary_tree,
_contract_deltas_ignore_leaf_partitions
_contract_deltas_ignore_leaf_partitions,
_rem_vertex!,
_DensityMartrixAlgGraph

@testset "test mincut functions on top of MPS" begin
i = Index(2, "i")
Expand Down Expand Up @@ -83,17 +87,54 @@ end
out2 = contract(network2...)
@test isapprox(out1, out2)
# test approx_itensornetwork (here we call `contract` to test the interface)
approx_tn, lognorm = contract(
tn;
alg="density_matrix",
output_structure=binary_tree_structure,
contraction_sequence_alg="sa_bipartite",
for structure in [path_graph_structure, binary_tree_structure]
approx_tn, lognorm = contract(
tn;
alg="density_matrix",
output_structure=structure,
contraction_sequence_alg="sa_bipartite",
)
network3 = Vector{ITensor}(approx_tn)
out3 = contract(network3...) * exp(lognorm)
i1 = noncommoninds(network...)
i3 = noncommoninds(network3...)
@test (length(i1) == length(i3))
@test isapprox(out1, out3)
end
end
end

@testset "test caching in approx_itensornetwork" begin
i = Index(2, "i")
j = Index(2, "j")
k = Index(2, "k")
l = Index(2, "l")
m = Index(2, "m")
T = randomITensor(i, j, k, l, m)
M = MPS(T, (i, j, k, l, m); cutoff=1e-5, maxdim=5)
tn = ITensorNetwork(M[:])
out_tree = path_graph_structure(tn)
input_partition = partition(tn, out_tree; alg="mincut_recursive_bisection")
underlying_tree = underlying_graph(input_partition)
# Change type of each partition[v] since they will be updated
# with potential data type chage.
p = DataGraph()
for v in vertices(input_partition)
add_vertex!(p, v)
p[v] = ITensorNetwork{Any}(input_partition[v])
end
alg_graph = _DensityMartrixAlgGraph(p, underlying_tree, _root(out_tree))
path = post_order_dfs_vertices(underlying_tree, _root(out_tree))
for v in path[1:2]
_rem_vertex!(
alg_graph,
v;
cutoff=1e-15,
maxdim=10000,
contraction_sequence_alg="optimal",
contraction_sequence_kwargs=(;),
)
network3 = Vector{ITensor}(approx_tn)
out3 = contract(network3...) * exp(lognorm)
i1 = noncommoninds(network...)
i3 = noncommoninds(network3...)
@test (length(i1) == length(i3))
@test isapprox(out1, out3)
end
# Check that a specific density matrix info has been cached
@test haskey(alg_graph.caches.es_to_pdm, Set([NamedEdge(nothing, path[3])]))
end

0 comments on commit 472ceb6

Please sign in to comment.