Skip to content

Commit 1bd3e44

Browse files
authored
Define conj for AbstractITensorNetwork and @preserve_graph macro (#185)
* Define conj for AbstractITensorNetwork and @preserve_graph macro * Bump to v0.11.11
1 parent de7d4dd commit 1bd3e44

File tree

3 files changed

+69
-20
lines changed

3 files changed

+69
-20
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ITensorNetworks"
22
uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7"
33
authors = ["Matthew Fishman <mfishman@flatironinstitute.org>, Joseph Tindall <jtindall@flatironinstitute.org> and contributors"]
4-
version = "0.11.10"
4+
version = "0.11.11"
55

66
[deps]
77
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -19,6 +19,7 @@ IsApprox = "28f27b66-4bd8-47e7-9110-e2746eb8bed7"
1919
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
2020
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
2121
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
22+
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
2223
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
2324
NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19"
2425
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
@@ -63,6 +64,7 @@ ITensors = "0.6.8"
6364
IsApprox = "0.1"
6465
IterTools = "1.4.0"
6566
KrylovKit = "0.6, 0.7"
67+
MacroTools = "0.5"
6668
NDTensors = "0.3"
6769
NamedGraphs = "0.6.0"
6870
OMEinsumContractionOrders = "0.8.3"

src/abstractitensornetwork.jl

Lines changed: 55 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ using ITensors:
3939
using ITensorMPS: ITensorMPS, add, linkdim, linkinds, siteinds
4040
using .ITensorsExtensions: ITensorsExtensions, indtype, promote_indtype
4141
using LinearAlgebra: LinearAlgebra, factorize
42+
using MacroTools: @capture
4243
using NamedGraphs: NamedGraphs, NamedGraph, not_implemented
4344
using NamedGraphs.GraphsExtensions:
4445
, directed_graph, incident_edges, rename_vertices, vertextype
@@ -138,6 +139,30 @@ function setindex_preserve_graph!(tn::AbstractITensorNetwork, value, vertex)
138139
return tn
139140
end
140141

142+
# TODO: Move to `BaseExtensions` module.
143+
function is_setindex!_expr(expr::Expr)
144+
return is_assignment_expr(expr) && is_getindex_expr(first(expr.args))
145+
end
146+
is_setindex!_expr(x) = false
147+
is_getindex_expr(expr::Expr) = (expr.head === :ref)
148+
is_getindex_expr(x) = false
149+
is_assignment_expr(expr::Expr) = (expr.head === :(=))
150+
is_assignment_expr(expr) = false
151+
152+
# TODO: Define this in terms of a function mapping
153+
# preserve_graph_function(::typeof(setindex!)) = setindex!_preserve_graph
154+
# preserve_graph_function(::typeof(map_vertex_data)) = map_vertex_data_preserve_graph
155+
# Also allow annotating codeblocks like `@views`.
156+
macro preserve_graph(expr)
157+
if !is_setindex!_expr(expr)
158+
error(
159+
"preserve_graph must be used with setindex! syntax (as @preserve_graph a[i,j,...] = value)",
160+
)
161+
end
162+
@capture(expr, array_[indices__] = value_)
163+
return :(setindex_preserve_graph!($(esc(array)), $(esc(value)), $(esc.(indices)...)))
164+
end
165+
141166
function ITensors.hascommoninds(tn::AbstractITensorNetwork, edge::Pair)
142167
return hascommoninds(tn, edgetype(tn)(edge))
143168
end
@@ -148,7 +173,7 @@ end
148173

149174
function Base.setindex!(tn::AbstractITensorNetwork, value, v)
150175
# v = to_vertex(tn, index...)
151-
setindex_preserve_graph!(tn, value, v)
176+
@preserve_graph tn[v] = value
152177
for edge in incident_edges(tn, v)
153178
rem_edge!(tn, edge)
154179
end
@@ -297,12 +322,12 @@ function ITensors.replaceinds(
297322
@assert underlying_graph(is) == underlying_graph(is′)
298323
for v in vertices(is)
299324
isassigned(is, v) || continue
300-
setindex_preserve_graph!(tn, replaceinds(tn[v], is[v] => is′[v]), v)
325+
@preserve_graph tn[v] = replaceinds(tn[v], is[v] => is′[v])
301326
end
302327
for e in edges(is)
303328
isassigned(is, e) || continue
304329
for v in (src(e), dst(e))
305-
setindex_preserve_graph!(tn, replaceinds(tn[v], is[e] => is′[e]), v)
330+
@preserve_graph tn[v] = replaceinds(tn[v], is[e] => is′[e])
306331
end
307332
end
308333
return tn
@@ -361,13 +386,31 @@ end
361386

362387
LinearAlgebra.adjoint(tn::Union{IndsNetwork,AbstractITensorNetwork}) = prime(tn)
363388

364-
#dag(tn::AbstractITensorNetwork) = map_vertex_data(dag, tn)
365-
function ITensors.dag(tn::AbstractITensorNetwork)
366-
tndag = copy(tn)
367-
for v in vertices(tndag)
368-
setindex_preserve_graph!(tndag, dag(tndag[v]), v)
389+
function map_vertex_data(f, tn::AbstractITensorNetwork)
390+
tn = copy(tn)
391+
for v in vertices(tn)
392+
tn[v] = f(tn[v])
369393
end
370-
return tndag
394+
return tn
395+
end
396+
397+
# TODO: Define `@preserve_graph map_vertex_data(f, tn)`
398+
function map_vertex_data_preserve_graph(f, tn::AbstractITensorNetwork)
399+
tn = copy(tn)
400+
for v in vertices(tn)
401+
@preserve_graph tn[v] = f(tn[v])
402+
end
403+
return tn
404+
end
405+
406+
function Base.conj(tn::AbstractITensorNetwork)
407+
# TODO: Use `@preserve_graph map_vertex_data(f, tn)`
408+
return map_vertex_data_preserve_graph(conj, tn)
409+
end
410+
411+
function ITensors.dag(tn::AbstractITensorNetwork)
412+
# TODO: Use `@preserve_graph map_vertex_data(f, tn)`
413+
return map_vertex_data_preserve_graph(dag, tn)
371414
end
372415

373416
# TODO: should this make sure that internal indices
@@ -442,9 +485,7 @@ function NDTensors.contract(
442485
for n_dst in neighbors_dst
443486
add_edge!(tn, merged_vertex => n_dst)
444487
end
445-
446-
setindex_preserve_graph!(tn, new_itensor, merged_vertex)
447-
488+
@preserve_graph tn[merged_vertex] = new_itensor
448489
return tn
449490
end
450491

@@ -533,13 +574,8 @@ function LinearAlgebra.factorize(
533574
add_edge!(tn, X_vertex => nX)
534575
end
535576
add_edge!(tn, Y_vertex => dst(edge))
536-
537-
# tn[X_vertex] = X
538-
setindex_preserve_graph!(tn, X, X_vertex)
539-
540-
# tn[Y_vertex] = Y
541-
setindex_preserve_graph!(tn, Y, Y_vertex)
542-
577+
@preserve_graph tn[X_vertex] = X
578+
@preserve_graph tn[Y_vertex] = Y
543579
return tn
544580
end
545581

test/test_itensornetwork.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,17 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
175175
return inds -> itensor(randn(rng, elt, dim.(inds)...), inds)
176176
end
177177
@test eltype(ψ[first(vertices(ψ))]) == elt
178+
179+
ψc = conj(ψ)
180+
for v in vertices(ψ)
181+
@test ψc[v] == conj(ψ[v])
182+
end
183+
184+
ψd = dag(ψ)
185+
for v in vertices(ψ)
186+
@test ψd[v] == dag(ψ[v])
187+
end
188+
178189
rng = StableRNG(1234)
179190
ψ = ITensorNetwork(g; kwargs...) do v
180191
return inds -> itensor(randn(rng, dim.(inds)...), inds)

0 commit comments

Comments
 (0)