Skip to content

Commit

Permalink
Update approx_itensornetwork.jl caching implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
LinjianMa committed Mar 15, 2023
1 parent 20b7383 commit fc5b85a
Showing 1 changed file with 70 additions and 130 deletions.
200 changes: 70 additions & 130 deletions src/approx_itensornetwork.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
"""
The struct is used to store cached density matrices in `approx_binary_tree_itensornetwork`.
tensor: the cached symmetric density matric tensor
root: the root vertex of which the density matrix tensor is computed
children: the children vertices of the root where the density matrix tensor is computed
The struct contains cached density matrices and cached partial density matrices
for each edge / set of edges in the tensor network.
Example:
Density matrix example:
Consider a tensor network below,
1
/\
Expand All @@ -14,9 +12,7 @@ Example:
/| /\
4 5 7 8
/ | | \
The density matrix for the root 3, children [4, 5] squares the subgraph
with vertices 3, 4, 5
The density matrix for the edge `NamedEdge(2, 3)` squares the subgraph with vertices 3, 4, 5
|
3
/|
Expand All @@ -26,8 +22,7 @@ Example:
|/
3
|
The density matrix for the root 3, children [2, 4] squares the subgraph
The density matrix for the edge `NamedEdge(5, 3)` squares the subgraph
with vertices 1, 2, 3, 4, 6, 7, 8, 9
1
/\
Expand All @@ -46,8 +41,7 @@ Example:
9 /
|/
1
The density matrix for the root 3, children [2, 5] squares the subgraph
The density matrix for the edge `NamedEdge(4, 3)` squares the subgraph
with vertices 1, 2, 3, 5, 6, 7, 8, 9
1
/\
Expand All @@ -66,20 +60,8 @@ Example:
9 /
|/
1
"""
struct _DensityMatrix
tensor::ITensor
root::Any
children::Vector
end

"""
The struct is used to store cached partial density matrices in `approx_binary_tree_itensornetwork`.
tensor: the cached partial density matric tensor
root: the root vertex of which the partial density matrix tensor is computed
child: the child vertex of the root where the density matrix tensor is computed
Example:
Partial density matrix example:
Consider a tensor network below,
1
/\
Expand All @@ -89,16 +71,14 @@ Example:
/| /\
4 5 7 8
/ | | \
The partial density matrix for the root 3, child 4 squares the subgraph
with vertices 4, and contract with the tensor 3
The partial density matrix for the Edge set `Set([NamedEdge(2, 3), NamedEdge(5, 3)])`
squares the subgraph with vertices 4, and contract with the tensor 3
|
3
/
4 - 4 -
The partial density matrix for the root 3, child 2 squares the subgraph
with vertices 1, 2, 6, 7, 8, 9, and contract with the tensor 3
The partial density matrix for the Edge set `Set([NamedEdge(4, 3), NamedEdge(5, 3)])`
squares the subgraph with vertices 1, 2, 6, 7, 8, 9, and contract with the tensor 3
1
/\
/ 2
Expand All @@ -116,40 +96,22 @@ Example:
9 /
|/
1
The density matrix for the root 3, children 5 squares the subgraph
with vertices 5. and contract with the tensor 3
The density matrix for the Edge set `Set([NamedEdge(4, 3), NamedEdge(2, 3)])`
squares the subgraph with vertices 5. and contract with the tensor 3
|
3
/
5 - 5 -
"""
struct _PartialDensityMatrix
tensor::ITensor
root::Any
child::Any
end

"""
The struct contains cached density matrices and cached partial density matrices
for each vertex in the tensor network.
"""
struct _DensityMatrixAlgCaches
v_to_cdm::Dict{Any,_DensityMatrix}
v_to_cpdms::Dict{Any,Vector{_PartialDensityMatrix}}
e_to_dm::Dict{NamedEdge,ITensor}
es_to_pdm::Dict{Set{NamedEdge},ITensor}
end

function _DensityMatrixAlgCaches()
v_to_cdm = Dict{Any,_DensityMatrix}()
v_to_cpdms = Dict{Any,Vector{_PartialDensityMatrix}}()
return _DensityMatrixAlgCaches(v_to_cdm, v_to_cpdms)
end

"""
Remove cached partial density matrices from `cpdms` whose child is in `children`
"""
function _remove_cpdms(cpdms::Vector, children)
return filter(pdm -> !(pdm.child in children), cpdms)
e_to_dm = Dict{NamedEdge,ITensor}()
es_to_pdm = Dict{Set{NamedEdge},ITensor}()
return _DensityMatrixAlgCaches(e_to_dm, es_to_pdm)
end

"""
Expand Down Expand Up @@ -237,87 +199,55 @@ function _sim(partial_dm_tensor::ITensor, inds_to_siminds)
end

"""
Return the partial density matrix whose root is `v` and root child is `child_v`.
If the tensor is in `partial_dms`, just return the tensor without contraction.
"""
function _get_pdm(
partial_dms::Vector{_PartialDensityMatrix},
v,
child_v,
child_dm_tensor,
network;
contraction_sequence_alg,
contraction_sequence_kwargs,
)
for partial_dm in partial_dms
if partial_dm.child == child_v
return partial_dm
end
end
tensor = _optcontract(
[child_dm_tensor, network...]; contraction_sequence_alg, contraction_sequence_kwargs
)
return _PartialDensityMatrix(tensor, v, child_v)
end

"""
Update `caches.v_to_cdm[v]` and `caches.v_to_cpdms[v]`.
Update `caches.e_to_dm[e]` and `caches.es_to_pdm[es]`.
caches: the caches of the density matrix algorithm.
v: the density matrix root
children: the children vertices of `v` in the dfs_tree
root: the root vertex of the truncation algorithm
network: the tensor network at vertex `v`
edge: the edge defining the density matrix
children: the children vertices of `dst(edge)` in the dfs_tree
network: the tensor network at vertex `dst(edge)`
inds_to_sim: a dict mapping inds to sim inds
"""
function _update!(
caches::_DensityMatrixAlgCaches,
v::Any,
edge::NamedEdge,
children::Vector,
root::Any,
network::Vector{ITensor},
inds_to_sim;
contraction_sequence_alg,
contraction_sequence_kwargs,
)
if haskey(caches.v_to_cdm, v) && caches.v_to_cdm[v].children == children && v != root
@assert haskey(caches.v_to_cdm, v)
v = dst(edge)
if haskey(caches.e_to_dm, edge)
return nothing
end
child_to_dm = [c => caches.v_to_cdm[c].tensor for c in children]
if !haskey(caches.v_to_cpdms, v)
caches.v_to_cpdms[v] = []
child_to_dm = [c => caches.e_to_dm[NamedEdge(v, c)] for c in children]
pdms = []
for (child_v, dm_tensor) in child_to_dm
es = [NamedEdge(src_v, v) for src_v in setdiff(children, child_v)]
es = Set(vcat(es, [edge]))
if !haskey(caches.es_to_pdm, es)
caches.es_to_pdm[es] = _optcontract(
[dm_tensor, network...]; contraction_sequence_alg, contraction_sequence_kwargs
)
end
push!(pdms, caches.es_to_pdm[es])
end
cpdms = [
_get_pdm(
caches.v_to_cpdms[v],
v,
child_v,
dm_tensor,
network;
contraction_sequence_alg,
contraction_sequence_kwargs,
) for (child_v, dm_tensor) in child_to_dm
]
if length(cpdms) == 0
if length(pdms) == 0
sim_network = map(x -> replaceinds(x, inds_to_sim), network)
density_matrix = _optcontract(
[network..., sim_network...]; contraction_sequence_alg, contraction_sequence_kwargs
)
elseif length(cpdms) == 1
elseif length(pdms) == 1
sim_network = map(x -> replaceinds(x, inds_to_sim), network)
density_matrix = _optcontract(
[cpdms[1].tensor, sim_network...];
contraction_sequence_alg,
contraction_sequence_kwargs,
[pdms[1], sim_network...]; contraction_sequence_alg, contraction_sequence_kwargs
)
else
simtensor = _sim(cpdms[2].tensor, inds_to_sim)
simtensor = _sim(pdms[2], inds_to_sim)
density_matrix = _optcontract(
[cpdms[1].tensor, simtensor]; contraction_sequence_alg, contraction_sequence_kwargs
[pdms[1], simtensor]; contraction_sequence_alg, contraction_sequence_kwargs
)
end
caches.v_to_cdm[v] = _DensityMatrix(density_matrix, v, children)
caches.v_to_cpdms[v] = cpdms
caches.e_to_dm[edge] = density_matrix
return nothing
end

Expand Down Expand Up @@ -365,17 +295,16 @@ function _rem_vertex!(
network = Vector{ITensor}(alg_graph.partition[v])
_update!(
caches,
v,
NamedEdge(parent_vertex(dm_dfs_tree, v), v),
children,
root,
Vector{ITensor}(network),
inds_to_sim;
contraction_sequence_alg,
contraction_sequence_kwargs,
)
end
U = _get_low_rank_projector(
caches.v_to_cdm[root].tensor,
caches.e_to_dm[NamedEdge(nothing, root)],
collect(values(outinds_root_to_sim)),
collect(keys(outinds_root_to_sim));
cutoff,
Expand All @@ -393,22 +322,33 @@ function _rem_vertex!(
)
rem_vertex!(alg_graph.partition, root)
rem_vertex!(alg_graph.out_tree, root)
# update v_to_cpdms[new_root]
delete!(caches.v_to_cpdms, root)
# update es_to_pdm
truncate_dfs_tree = dfs_tree(alg_graph.out_tree, alg_graph.root)
caches.v_to_cpdms[new_root] = _remove_cpdms(
caches.v_to_cpdms[new_root], child_vertices(truncate_dfs_tree, new_root)
)
@assert length(caches.v_to_cpdms[new_root]) <= 1
caches.v_to_cpdms[new_root] = [
_PartialDensityMatrix(
_optcontract(
[cpdm.tensor, root_tensor]; contraction_sequence_alg, contraction_sequence_kwargs
),
new_root,
cpdm.child,
) for cpdm in caches.v_to_cpdms[new_root]
]
for es in keys(caches.es_to_pdm)
if dst(first(es)) == root
delete!(caches.es_to_pdm, es)
elseif dst(first(es)) == new_root
parent_edge = NamedEdge(parent_vertex(truncate_dfs_tree, new_root), new_root)
edge_to_remove = NamedEdge(root, new_root)
if intersect(es, [parent_edge]) == []
new_es = setdiff(es, [edge_to_remove])
caches.es_to_pdm[new_es] = _optcontract(
[caches.es_to_pdm[es], root_tensor];
contraction_sequence_alg,
contraction_sequence_kwargs,
)
end
# Remove old caches since they won't be used anymore,
# and removing them saves later contraction costs.
delete!(caches.es_to_pdm, es)
end
end
# update e_to_dm
for edge in keys(caches.e_to_dm)
if dst(edge) in [root, new_root]
delete!(caches.e_to_dm, edge)
end
end
return U
end

Expand Down

0 comments on commit fc5b85a

Please sign in to comment.