Skip to content

Commit

Permalink
Let approx_itensornetwork support complex tensor network (#81)
Browse files Browse the repository at this point in the history
  • Loading branch information
LinjianMa authored Apr 7, 2023
1 parent 756d09b commit bb36a6a
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 27 deletions.
13 changes: 8 additions & 5 deletions src/approx_itensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,9 @@ end
function _get_low_rank_projector(tensor, inds1, inds2; cutoff, maxdim)
@assert length(inds(tensor)) <= 4
@timeit_debug ITensors.timer "[approx_binary_tree_itensornetwork]: eigen" begin
diag, U = eigen(tensor, inds1, inds2; cutoff=cutoff, maxdim=maxdim, ishermitian=true)
F = eigen(tensor, inds1, inds2; cutoff=cutoff, maxdim=maxdim, ishermitian=true)
end
return U
return F.Vt
end

"""
Expand Down Expand Up @@ -233,16 +233,19 @@ function _update!(
end
if length(pdms) == 0
sim_network = map(x -> replaceinds(x, inds_to_sim), network)
sim_network = map(dag, sim_network)
density_matrix = _optcontract(
[network..., sim_network...]; contraction_sequence_alg, contraction_sequence_kwargs
)
elseif length(pdms) == 1
sim_network = map(x -> replaceinds(x, inds_to_sim), network)
sim_network = map(dag, sim_network)
density_matrix = _optcontract(
[pdms[1], sim_network...]; contraction_sequence_alg, contraction_sequence_kwargs
)
else
simtensor = _sim(pdms[2], inds_to_sim)
simtensor = dag(simtensor)
density_matrix = _optcontract(
[pdms[1], simtensor]; contraction_sequence_alg, contraction_sequence_kwargs
)
Expand Down Expand Up @@ -305,14 +308,14 @@ function _rem_vertex!(
end
U = _get_low_rank_projector(
caches.e_to_dm[NamedEdge(nothing, root)],
collect(values(outinds_root_to_sim)),
collect(keys(outinds_root_to_sim));
collect(keys(outinds_root_to_sim)),
collect(values(outinds_root_to_sim));
cutoff,
maxdim,
)
# update partition and out_tree
root_tensor = _optcontract(
[Vector{ITensor}(alg_graph.partition[root])..., U];
[Vector{ITensor}(alg_graph.partition[root])..., dag(U)];
contraction_sequence_alg,
contraction_sequence_kwargs,
)
Expand Down
49 changes: 27 additions & 22 deletions test/test_binary_tree_partition.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,26 +60,31 @@ end
k = Index(2, "k")
l = Index(2, "l")
m = Index(2, "m")
T = randomITensor(i, j, k, l, m)
M = MPS(T, (i, j, k, l, m); cutoff=1e-5, maxdim=5)
network = M[:]
out1 = contract(network...)
tn = ITensorNetwork(network)
inds_btree = binary_tree_structure(tn)
par = partition(tn, inds_btree; alg="mincut_recursive_bisection")
par = _contract_deltas_ignore_leaf_partitions(par; root=_root(inds_btree))
networks = [Vector{ITensor}(par[v]) for v in vertices(par)]
network2 = vcat(networks...)
out2 = contract(network2...)
@test isapprox(out1, out2)
# test approx_itensornetwork
approx_tn, lognorm = approx_itensornetwork(
tn, binary_tree_structure; alg="density_matrix", contraction_sequence_alg="sa_bipartite"
)
network3 = Vector{ITensor}(approx_tn)
out3 = contract(network3...) * exp(lognorm)
i1 = noncommoninds(network...)
i3 = noncommoninds(network3...)
@test (length(i1) == length(i3))
@test isapprox(out1, out3)
for dtype in [Float64, ComplexF64]
T = randomITensor(dtype, i, j, k, l, m)
M = MPS(T, (i, j, k, l, m); cutoff=1e-5, maxdim=5)
network = M[:]
out1 = contract(network...)
tn = ITensorNetwork(network)
inds_btree = binary_tree_structure(tn)
par = partition(tn, inds_btree; alg="mincut_recursive_bisection")
par = _contract_deltas_ignore_leaf_partitions(par; root=_root(inds_btree))
networks = [Vector{ITensor}(par[v]) for v in vertices(par)]
network2 = vcat(networks...)
out2 = contract(network2...)
@test isapprox(out1, out2)
# test approx_itensornetwork
approx_tn, lognorm = approx_itensornetwork(
tn,
binary_tree_structure;
alg="density_matrix",
contraction_sequence_alg="sa_bipartite",
)
network3 = Vector{ITensor}(approx_tn)
out3 = contract(network3...) * exp(lognorm)
i1 = noncommoninds(network...)
i3 = noncommoninds(network3...)
@test (length(i1) == length(i3))
@test isapprox(out1, out3)
end
end

0 comments on commit bb36a6a

Please sign in to comment.