Skip to content

Additions needed for sweeping algorithms on trees #18

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Jan 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
c81a211
Sweeping algorithms for tree tensor networks: draft
Oct 31, 2022
b3ca106
Switch off numerical truncation by default in OpSum->TTNO converter
leburgel Nov 1, 2022
b389589
Merge branch 'indsnetwork_additions' into tree_sweeping
leburgel Nov 3, 2022
e00c345
Remove functionalities that were merged into NamedGraphs.jl
leburgel Nov 3, 2022
74dc679
Merge branch 'indsnetwork_additions' into tree_sweeping
leburgel Nov 3, 2022
f883124
Merge remote-tracking branch 'origin' into tree_sweeping
leburgel Nov 4, 2022
0eb9ef9
Remove temporary `expect` in favor of `expect.jl` merged from main
leburgel Nov 4, 2022
d337f7c
Use `convert_eltype` instead of `convert_leaf_eltype` where appropria…
leburgel Nov 28, 2022
77be11c
Change `linkdims` to return `NamedDimDataGraph`
leburgel Nov 28, 2022
d6b9633
Update
leburgel Dec 12, 2022
26ec817
Merge branch 'main' into tree_sweeping
leburgel Dec 12, 2022
40a9975
Move implementation to `loginner`, make `logdot` the alias.
leburgel Dec 21, 2022
a04d3cf
No more mutables; remove many in-place operations.
leburgel Dec 22, 2022
152db53
Remove in-place operations.
leburgel Jan 9, 2023
d490aed
Merge remote-tracking branch 'origin' into tree_sweeping
leburgel Jan 9, 2023
e9f0c7a
Remove unnecessary comment.
leburgel Jan 9, 2023
914b711
Merge branch 'main' into tree_sweeping
mtfishman Jan 9, 2023
6cb3e36
Fix imports.jl
mtfishman Jan 9, 2023
591cd3c
Update directory name from treetensornetwork to treetensornetworks
mtfishman Jan 9, 2023
105cec6
Fix some warnings
mtfishman Jan 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,17 @@ Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5"
IsApprox = "28f27b66-4bd8-47e7-9110-e2746eb8bed7"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19"
Observers = "338f10d5-c7f1-4033-a7d1-f9dec39bcaa0"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
SparseArrayKit = "a9a3c162-d163-4c15-8926-b8794fbefed2"
SplitApplyCombine = "03a91e81-4c3e-53e1-a0a4-9c0c8f19dd66"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"

Expand Down
15 changes: 14 additions & 1 deletion src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using Dictionaries
using DocStringExtensions
using Graphs
using Graphs.SimpleGraphs # AbstractSimpleGraph
using IsApprox
using ITensors
using ITensors.ContractionSequenceOptimization
using ITensors.ITensorVisualizationCore
Expand All @@ -15,7 +16,9 @@ using Observers
using Printf
using Requires
using SimpleTraits
using SparseArrayKit
using SplitApplyCombine
using StaticArrays
using Suppressor
using TimerOutputs

Expand All @@ -27,6 +30,7 @@ using ITensors:
@timeit_debug,
AbstractMPS,
Algorithm,
OneITensor,
check_hascommoninds,
commontags,
orthocenter,
Expand Down Expand Up @@ -69,11 +73,20 @@ include("expect.jl")
include("models.jl")
include("tebd.jl")
include("itensornetwork.jl")
include("utility.jl")
include("specialitensornetworks.jl")
include("renameitensornetwork.jl")
include("boundarymps.jl")
include("beliefpropagation.jl")
include(joinpath("treetensornetworks", "treetensornetwork.jl"))
include(joinpath("treetensornetworks", "abstracttreetensornetwork.jl"))
# include(joinpath("treetensornetworks", "treetensornetwork.jl"))
include(joinpath("treetensornetworks", "ttns.jl"))
include(joinpath("treetensornetworks", "ttno.jl"))
include(joinpath("treetensornetworks", "opsum_to_ttno.jl"))
include(joinpath("treetensornetworks", "abstractprojttno.jl"))
include(joinpath("treetensornetworks", "projttno.jl"))
include(joinpath("treetensornetworks", "projttnosum.jl"))
include(joinpath("treetensornetworks", "projttno_apply.jl"))
# Compatibility of ITensor observer and Observers
# TODO: Delete this
include(joinpath("treetensornetworks", "solvers", "update_observer.jl"))
Expand Down
2 changes: 1 addition & 1 deletion src/abstractindsnetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ edge_data_type(::Type{<:AbstractIndsNetwork{V,I}}) where {V,I} = Vector{I}

function uniqueinds(is::AbstractIndsNetwork, edge::AbstractEdge)
inds = IndexSet(get(is, src(edge), Index[]))
for ei in setdiff(incident_edges(is, src(edge)...), [edge])
for ei in setdiff(incident_edges(is, src(edge)), [edge])
inds = unioninds(inds, get(is, ei, Index[]))
end
return inds
Expand Down
181 changes: 174 additions & 7 deletions src/abstractitensornetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@ end
# Iteration
#

# TODO: iteration

# TODO: different `map` functionalities as defined for ITensors.AbstractMPS

# TODO: broadcasting

function union(tn1::AbstractITensorNetwork, tn2::AbstractITensorNetwork; kwargs...)
tn = ITensorNetwork(union(data_graph(tn1), data_graph(tn2)); kwargs...)
# Add any new edges that are introduced during the union
Expand Down Expand Up @@ -104,6 +110,36 @@ end
# Convenience wrapper
itensors(tn::AbstractITensorNetwork) = Vector{ITensor}(tn)

#
# Promotion and conversion
#

# Element type jointly promoted over every tensor in the network, computed
# by delegating to `promote_leaf_eltypes` on the network's tensor collection.
function LinearAlgebra.promote_leaf_eltypes(tn::AbstractITensorNetwork)
  all_tensors = itensors(tn)
  return LinearAlgebra.promote_leaf_eltypes(all_tensors)
end

# ITensors-facing alias: the promoted ITensor element type of a network is
# the leaf element type promoted over its tensors.
function ITensors.promote_itensor_eltype(tn::AbstractITensorNetwork)
  return LinearAlgebra.promote_leaf_eltypes(itensors(tn))
end

# `scalartype` of a network: the element type promoted across all tensors.
function ITensors.scalartype(tn::AbstractITensorNetwork)
  return LinearAlgebra.promote_leaf_eltypes(tn)
end

# TODO: eltype(::AbstractITensorNetwork) (cannot behave the same as eltype(::ITensors.AbstractMPS))

# TODO: mimic ITensors.AbstractMPS implementation using map
# Return a copy of `tn` whose tensors have their leaf element type converted
# to `eltype` via `convert_eltype`. The input network is left unmodified.
function ITensors.convert_leaf_eltype(eltype::Type, tn::AbstractITensorNetwork)
  # Copy first so the broadcast below mutates only the new network's storage.
  tn = copy(tn)
  # `Ref` keeps `eltype` a scalar under broadcasting over the vertex data.
  # NOTE(review): assumes `vertex_data(tn) .= ...` writes through to the
  # copied network's data graph — confirm against the data-graph storage.
  vertex_data(tn) .= convert_eltype.(Ref(eltype), vertex_data(tn))
  return tn
end

# TODO: mimic ITensors.AbstractMPS implementation using map
# Return a copy of `tn` with each tensor's scalar type adapted to `eltype`
# using `ITensors.adapt`. The input network is left unmodified.
function NDTensors.convert_scalartype(eltype::Type{<:Number}, tn::AbstractITensorNetwork)
  # Copy first so the broadcast below mutates only the new network's storage.
  tn = copy(tn)
  # `Ref` prevents broadcasting over the type argument itself.
  vertex_data(tn) .= ITensors.adapt.(Ref(eltype), vertex_data(tn))
  return tn
end

#
# Conversion to Graphs
#
Expand Down Expand Up @@ -185,11 +221,13 @@ end
function replaceinds(tn::AbstractITensorNetwork, is_is′::Pair{<:IndsNetwork,<:IndsNetwork})
tn = copy(tn)
is, is′ = is_is′
# TODO: Check that `is` and `is′` have the same vertices and edges.
@assert underlying_graph(is) == underlying_graph(is′)
for v in vertices(is)
isassigned(is, v) || continue
setindex_preserve_graph!(tn, replaceinds(tn[v], is[v] => is′[v]), v)
end
for e in edges(is)
isassigned(is, e) || continue
for v in (src(e), dst(e))
setindex_preserve_graph!(tn, replaceinds(tn[v], is[e] => is′[e]), v)
end
Expand All @@ -208,7 +246,7 @@ const map_inds_label_functions = [
:setprime,
:noprime,
:replaceprime,
:swapprime,
# :swapprime, # TODO: add @test_broken as a reminder
:addtags,
:removetags,
:replacetags,
Expand All @@ -227,6 +265,24 @@ for f in map_inds_label_functions
function $f(n::Union{IndsNetwork,AbstractITensorNetwork}, args...; kwargs...)
return map_inds($f, n, args...; kwargs...)
end

function $f(
ffilter::typeof(linkinds),
n::Union{IndsNetwork,AbstractITensorNetwork},
args...;
kwargs...,
)
return map_inds($f, n, args...; sites=[], kwargs...)
end

function $f(
ffilter::typeof(siteinds),
n::Union{IndsNetwork,AbstractITensorNetwork},
args...;
kwargs...,
)
return map_inds($f, n, args...; links=[], kwargs...)
end
end
end

Expand Down Expand Up @@ -402,12 +458,19 @@ function factorize(
return factorize(tn, edgetype(tn)(edge); kwargs...)
end

# For ambiguity error; TODO: decide whether to use graph mutating methods when resulting graph is unchanged?
# Orthogonalize across a single edge: factorize the tensor at `src(edge)` into
# an orthogonal factor `X` (kept at the source vertex) and a remainder `Y`,
# which is absorbed into the tensor at `dst(edge)`. Returns a new network;
# the input `tn` is not modified.
#
# NOTE: the merged diff had left the superseded implementation
# (`factorize(tn, edge)` + `contract`) in place ahead of the new body, making
# the code below unreachable; the stale lines and their commented-out
# duplicates are removed here.
function _orthogonalize_edge(tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs...)
  tn = copy(tn)
  # Indices unique to the source side of the edge define the "left" factor.
  left_inds = uniqueinds(tn, edge)
  # Reuse the existing link tags so the new bond index stays labeled consistently.
  ltags = tags(tn, edge)
  X, Y = factorize(tn[src(edge)], left_inds; tags=ltags, ortho="left", kwargs...)
  tn[src(edge)] = X
  tn[dst(edge)] *= Y
  return tn
end

function orthogonalize(tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs...)
Expand All @@ -429,6 +492,25 @@ function orthogonalize(ψ::AbstractITensorNetwork, source_vertex)
return ψ
end

# TODO: decide whether to use graph mutating methods when resulting graph is unchanged?
# Truncate the bond on `edge` via an SVD of the source-vertex tensor: `U`
# stays at the source vertex while the weight `S * V` is absorbed into the
# destination tensor. Returns a modified copy of `tn`.
function _truncate_edge(tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs...)
  tn = copy(tn)
  # Reuse the existing link tags for the new bond index.
  link_tags = tags(tn, edge)
  # Indices unique to the source side select the "left" factor of the SVD.
  source_inds = uniqueinds(tn, edge)
  U, S, V = svd(tn[src(edge)], source_inds; lefttags=link_tags, ortho="left", kwargs...)
  tn[src(edge)] = U
  tn[dst(edge)] *= (S * V)
  return tn
end

# Public entry point for single-edge truncation; delegates to `_truncate_edge`.
truncate(tn::AbstractITensorNetwork, edge::AbstractEdge; kwargs...) = _truncate_edge(tn, edge; kwargs...)

# Convenience method: accept a `Pair` of vertices and promote it to the
# network's edge type before truncating.
function truncate(tn::AbstractITensorNetwork, edge::Pair; kwargs...)
  promoted_edge = edgetype(tn)(edge)
  return truncate(tn, promoted_edge; kwargs...)
end

function Base.:*(c::Number, ψ::AbstractITensorNetwork)
v₁ = first(vertices(ψ))
cψ = copy(ψ)
Expand Down Expand Up @@ -572,6 +654,91 @@ function visualize(
return visualize(Vector{ITensor}(tn), args...; vertex_labels, kwargs...)
end

#
# Link dimensions
#

# Largest bond dimension over all edges of the network.
# Returns 1 when the network has no edges (matching the loop's accumulator seed).
function maxlinkdim(tn::AbstractITensorNetwork)
  return mapreduce(e -> linkdim(tn, e), max, edges(tn); init=1)
end

# Accept a `Pair` edge specification by promoting it to the edge type of `tn`.
function linkdim(tn::AbstractITensorNetwork, edge::Pair)
  promoted_edge = edgetype(tn)(edge)
  return linkdim(tn, promoted_edge)
end

# Total bond dimension across `edge`: the product of the dimensions of all
# link indices on that edge. A `nothing` entry contributes a factor of 1.
function linkdim(tn::AbstractITensorNetwork{V}, edge::AbstractEdge{V}) where {V}
  total = 1
  for l in linkinds(tn, edge)
    total *= isnothing(l) ? 1 : dim(l)
  end
  return total
end

# Map every edge of the network to its bond dimension, returned as a
# `DataGraph` over (a copy of) the same underlying graph.
function linkdims(tn::AbstractITensorNetwork{V}) where {V}
  dims_graph = DataGraph{V,Any,Int}(copy(underlying_graph(tn)))
  for edge in edges(dims_graph)
    dims_graph[edge] = linkdim(tn, edge)
  end
  return dims_graph
end

#
# Common index checking
#

# Whether `A` and `B` share at least one site index on every vertex of `A`.
function hascommoninds(
  ::typeof(siteinds), A::AbstractITensorNetwork{V}, B::AbstractITensorNetwork{V}
) where {V}
  return all(v -> hascommoninds(siteinds(A, v), siteinds(B, v)), vertices(A))
end

# Validate that `A` and `B` have the same number of vertices and share site
# indices on every vertex. Throws `DimensionMismatch` on a vertex-count
# mismatch and `ErrorException` on missing common site indices; returns
# `nothing` on success.
function check_hascommoninds(
  ::typeof(siteinds), A::AbstractITensorNetwork{V}, B::AbstractITensorNetwork{V}
) where {V}
  N = nv(A)
  if nv(B) ≠ N
    throw(
      DimensionMismatch(
        "$(typeof(A)) and $(typeof(B)) have mismatched number of vertices $N and $(nv(B))."
      ),
    )
  end
  # Per-vertex check; the error message reports the offending vertex and both index sets.
  for v in vertices(A)
    !hascommoninds(siteinds(A, v), siteinds(B, v)) && error(
      "$(typeof(A)) A and $(typeof(B)) B must share site indices. On vertex $v, A has site indices $(siteinds(A, v)) while B has site indices $(siteinds(B, v)).",
    )
  end
  return nothing
end

# Whether `A` and `B` have the same number of vertices and identical site
# index sets on every vertex.
function hassameinds(
  ::typeof(siteinds), A::AbstractITensorNetwork{V}, B::AbstractITensorNetwork{V}
) where {V}
  nv(A) == nv(B) || return false
  return all(v -> ITensors.hassameinds(siteinds(A, v), siteinds(B, v)), vertices(A))
end

#
# Site combiners
#

# TODO: will be broken, fix this
# Build a `DataGraph` assigning to each vertex a combiner ITensor over that
# vertex's site indices, tagged with the tags common to those indices.
function site_combiners(tn::AbstractITensorNetwork{V}) where {V}
  Cs = DataGraph{V,ITensor}(copy(underlying_graph(tn)))
  for v in vertices(tn)
    s = siteinds(tn, v)
    # NOTE(review): uses `commontags(s)` as the combined index's tags —
    # confirm behavior when the site indices at `v` share no tags.
    Cs[v] = combiner(s; tags=commontags(s))
  end
  return Cs
end

## # TODO: should this make sure that internal indices
## # don't clash?
## function hvncat(
Expand Down
7 changes: 5 additions & 2 deletions src/expect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@ function expect(
maxdim=nothing,
ortho=false,
sequence=nothing,
sites=vertices(ψ),
)
s = siteinds(ψ)
res = Dictionary(vertices(ψ), Vector{Float64}(undef, nv(ψ)))
ElT = promote_itensor_eltype(ψ)
# ElT = ishermitian(ITensors.op(op, s[sites[1]])) ? real(ElT) : ElT
res = Dictionary(sites, Vector{ElT}(undef, length(sites)))
if isnothing(sequence)
sequence = contraction_sequence(inner_network(ψ, ψ; flatten=true))
end
normψ² = norm_sqr(ψ; sequence)
for v in vertices(ψ)
for v in sites
O = ITensor(Op(op, v), s)
Oψ = apply(O, ψ; cutoff, maxdim, ortho)
res[v] = contract_inner(ψ, Oψ; sequence) / normψ²
Expand Down
Loading