-
Notifications
You must be signed in to change notification settings - Fork 13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] Add approximation tensor network contraction subpackage #11
Conversation
end | ||
end | ||
|
||
# ctree: contraction tree |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nested vector if itensors, e.g. [[A,B,C,], [D,E]]
|
||
# ctree: contraction tree | ||
# tn: vector of tensors representing a tensor network | ||
# tn_tree: a dict maps each index tree in the tn to a tensor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
used in the caching
# contract_ig: the index group to be contracted next | ||
# ig_tree: an index group with a tree hierarchy | ||
function approximate_contract(ctree::Vector; kwargs...) | ||
tn_leaves = get_leaves(ctree) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tn_leaves
are all the groups
# ig_tree: an index group with a tree hierarchy | ||
function approximate_contract(ctree::Vector; kwargs...) | ||
tn_leaves = get_leaves(ctree) | ||
ctrees = topo_sort(ctree; leaves=tn_leaves) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ctrees
: A list of contraction tree, it defines the contraction path
) | ||
end | ||
for c in ctrees | ||
if ctree_to_igs[c] == [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
772-802 are to get caching information
…ntract, remove project_boundary
|
||
# merge two trees | ||
# new tree: | ||
# s | ||
# / \ | ||
# t1 t2 | ||
function merge_tree(t1::Vector, t2::Vector; append=false) | ||
if t2 == [] | ||
return t1 | ||
end | ||
if t1 == [] | ||
return t2 | ||
end | ||
if isleaf(t1) && isleaf(t2) | ||
return [t1, t2] | ||
end | ||
if isleaf(t1) | ||
return append ? [t1, t2...] : [t1, t2] | ||
end | ||
if isleaf(t2) | ||
return append ? [t1..., t2] : [t1, t2] | ||
end | ||
return append ? [t1..., t2...] : [t1, t2] | ||
end | ||
|
||
function isleaf(tree::Vector) | ||
if tree == [] | ||
@info "tree is empty" | ||
return false | ||
end | ||
if all(v -> !(v isa Vector), tree) | ||
return true | ||
end | ||
return false | ||
end | ||
|
||
# get the subtree of tree that is in the subset | ||
# example: | ||
# subtree([[1, 2], [3, 4]], [1, 3]) = ([[1], [3]]) | ||
function subtree(tree::Vector, subset::Union{Vector,Tuple}) | ||
if tree == [] | ||
return [] | ||
end | ||
if isleaf(tree) | ||
return intersect(tree, subset) | ||
end | ||
tree = [subtree(i, subset) for i in tree] | ||
tree = filter(t -> t != [], tree) | ||
if length(tree) == 1 && tree[1] isa Vector | ||
return tree[1] | ||
end | ||
return tree | ||
end | ||
|
||
# vectorize a tree | ||
# example: [[1,2], [3,4]] = [1, 2, 3, 4] | ||
function vectorize(tree) | ||
@assert tree != [] | ||
if !(tree isa Vector) | ||
return [tree] | ||
end | ||
return mapreduce(vectorize, vcat, tree) | ||
end | ||
|
||
# example: [[[1,2], [3,4]], [[5,6], [7,8]]] = [[1,2], [3,4], [5,6], [7,8]] | ||
function get_leaves(tree::Vector) | ||
if !(tree isa Vector{<:Vector}) | ||
return [tree] | ||
end | ||
return mapreduce(get_leaves, vcat, tree) | ||
end | ||
|
||
function line_to_tree(line::Vector) | ||
if length(line) == 1 && line[1] isa Vector | ||
return line[1] | ||
end | ||
if length(line) <= 2 | ||
return line | ||
end | ||
return [line_to_tree(line[1:(end - 1)]), line[end]] | ||
end | ||
|
||
function topo_sort(tn; type=Vector, leaves=[]) | ||
@timeit timer "topo_sort" begin | ||
topo_order = [] | ||
topo_sort_dfs!(tn, topo_order, leaves, type) | ||
return topo_order | ||
end | ||
end | ||
|
||
function topo_sort_dfs!(tn, topo_order, leaves, type) | ||
#Post-order DFS | ||
if (tn in leaves) || !(tn isa type) | ||
return nothing | ||
end | ||
for subtn in tn | ||
topo_sort_dfs!(subtn, topo_order, leaves, type) | ||
end | ||
return append!(topo_order, [tn]) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can any of this make use of functionality from AbstractTrees.jl? It would be nice to use existing packages if possible so we don't have to maintain our own versions of existing functions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the suggestion, I will keep this in mind when refactoring the code
end | ||
|
||
# Note that the children ordering matters here. | ||
mutable struct IndexAdjacencyTree |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@LinjianMa could you explain what this type is for?
Will just close this one since updates are merged gradually in other PRs. |
Still need to get some more tests working, but most of the parts are done.
Package transferred from https://github.com/ITensor/ITensorNetworkAD.jl