From 472ceb6c3e1dcc2d41e5fa0729b0774fb43d15d1 Mon Sep 17 00:00:00 2001 From: Linjian Ma Date: Mon, 22 May 2023 13:47:42 -0500 Subject: [PATCH] Fix a small bug related to caching in approx_itensornetwork (#89) --- src/approx_itensornetwork.jl | 41 ++++++++++--------- test/test_binary_tree_partition.jl | 65 ++++++++++++++++++++++++------ 2 files changed, 74 insertions(+), 32 deletions(-) diff --git a/src/approx_itensornetwork.jl b/src/approx_itensornetwork.jl index 11f63bd6..5cccbc66 100644 --- a/src/approx_itensornetwork.jl +++ b/src/approx_itensornetwork.jl @@ -327,30 +327,31 @@ function _rem_vertex!( rem_vertex!(alg_graph.out_tree, root) # update es_to_pdm truncate_dfs_tree = dfs_tree(alg_graph.out_tree, alg_graph.root) - for es in keys(caches.es_to_pdm) - if dst(first(es)) == root - delete!(caches.es_to_pdm, es) - elseif dst(first(es)) == new_root - parent_edge = NamedEdge(parent_vertex(truncate_dfs_tree, new_root), new_root) - edge_to_remove = NamedEdge(root, new_root) - if intersect(es, [parent_edge]) == [] - new_es = setdiff(es, [edge_to_remove]) - caches.es_to_pdm[new_es] = _optcontract( - [caches.es_to_pdm[es], root_tensor]; - contraction_sequence_alg, - contraction_sequence_kwargs, - ) + for es in filter(es -> dst(first(es)) == root, keys(caches.es_to_pdm)) + delete!(caches.es_to_pdm, es) + end + for es in filter(es -> dst(first(es)) == new_root, keys(caches.es_to_pdm)) + parent_edge = NamedEdge(parent_vertex(truncate_dfs_tree, new_root), new_root) + edge_to_remove = NamedEdge(root, new_root) + if intersect(es, Set([parent_edge])) == Set() + new_es = setdiff(es, [edge_to_remove]) + if new_es == Set() + new_es = Set([NamedEdge(nothing, new_root)]) end - # Remove old caches since they won't be used anymore, - # and removing them saves later contraction costs. - delete!(caches.es_to_pdm, es) + @assert length(new_es) >= 1 + caches.es_to_pdm[new_es] = _optcontract( + [caches.es_to_pdm[es], root_tensor]; + contraction_sequence_alg, + contraction_sequence_kwargs, + ) end + # Remove old caches since they won't be used anymore, + # and removing them saves later contraction costs. + delete!(caches.es_to_pdm, es) end # update e_to_dm - for edge in keys(caches.e_to_dm) - if dst(edge) in [root, new_root] - delete!(caches.e_to_dm, edge) - end + for edge in filter(e -> dst(e) in [root, new_root], keys(caches.e_to_dm)) + delete!(caches.e_to_dm, edge) end return U end diff --git a/test/test_binary_tree_partition.jl b/test/test_binary_tree_partition.jl index 966920b0..f00ef278 100644 --- a/test/test_binary_tree_partition.jl +++ b/test/test_binary_tree_partition.jl @@ -1,10 +1,14 @@ using ITensors, OMEinsumContractionOrders +using Graphs, NamedGraphs +using ITensors: contract using ITensorNetworks: _root, _mps_partition_inds_order, _mincut_partitions, _is_rooted_directed_binary_tree, - _contract_deltas_ignore_leaf_partitions + _contract_deltas_ignore_leaf_partitions, + _rem_vertex!, + _DensityMartrixAlgGraph @testset "test mincut functions on top of MPS" begin i = Index(2, "i") @@ -83,17 +87,54 @@ end out2 = contract(network2...) @test isapprox(out1, out2) # test approx_itensornetwork (here we call `contract` to test the interface) - approx_tn, lognorm = contract( - tn; - alg="density_matrix", - output_structure=binary_tree_structure, - contraction_sequence_alg="sa_bipartite", + for structure in [path_graph_structure, binary_tree_structure] + approx_tn, lognorm = contract( + tn; + alg="density_matrix", + output_structure=structure, + contraction_sequence_alg="sa_bipartite", + ) + network3 = Vector{ITensor}(approx_tn) + out3 = contract(network3...) * exp(lognorm) + i1 = noncommoninds(network...) + i3 = noncommoninds(network3...) + @test (length(i1) == length(i3)) + @test isapprox(out1, out3) + end + end +end + +@testset "test caching in approx_itensornetwork" begin + i = Index(2, "i") + j = Index(2, "j") + k = Index(2, "k") + l = Index(2, "l") + m = Index(2, "m") + T = randomITensor(i, j, k, l, m) + M = MPS(T, (i, j, k, l, m); cutoff=1e-5, maxdim=5) + tn = ITensorNetwork(M[:]) + out_tree = path_graph_structure(tn) + input_partition = partition(tn, out_tree; alg="mincut_recursive_bisection") + underlying_tree = underlying_graph(input_partition) + # Change type of each partition[v] since they will be updated + # with potential data type chage. + p = DataGraph() + for v in vertices(input_partition) + add_vertex!(p, v) + p[v] = ITensorNetwork{Any}(input_partition[v]) + end + alg_graph = _DensityMartrixAlgGraph(p, underlying_tree, _root(out_tree)) + path = post_order_dfs_vertices(underlying_tree, _root(out_tree)) + for v in path[1:2] + _rem_vertex!( + alg_graph, + v; + cutoff=1e-15, + maxdim=10000, + contraction_sequence_alg="optimal", + contraction_sequence_kwargs=(;), ) - network3 = Vector{ITensor}(approx_tn) - out3 = contract(network3...) * exp(lognorm) - i1 = noncommoninds(network...) - i3 = noncommoninds(network3...) - @test (length(i1) == length(i3)) - @test isapprox(out1, out3) end + # Check that a specific density matrix info has been cached + @test haskey(alg_graph.caches.es_to_pdm, Set([NamedEdge(nothing, path[3])])) end