@@ -39,6 +39,7 @@ using ITensors:
39
39
using ITensorMPS: ITensorMPS, add, linkdim, linkinds, siteinds
40
40
using . ITensorsExtensions: ITensorsExtensions, indtype, promote_indtype
41
41
using LinearAlgebra: LinearAlgebra, factorize
42
+ using MacroTools: @capture
42
43
using NamedGraphs: NamedGraphs, NamedGraph, not_implemented
43
44
using NamedGraphs. GraphsExtensions:
44
45
⊔ , directed_graph, incident_edges, rename_vertices, vertextype
@@ -138,6 +139,30 @@ function setindex_preserve_graph!(tn::AbstractITensorNetwork, value, vertex)
138
139
return tn
139
140
end
140
141
142
+ # TODO : Move to `BaseExtensions` module.
143
+ function is_setindex!_expr (expr:: Expr )
144
+ return is_assignment_expr (expr) && is_getindex_expr (first (expr. args))
145
+ end
146
+ is_setindex!_expr (x) = false
147
+ is_getindex_expr (expr:: Expr ) = (expr. head === :ref )
148
+ is_getindex_expr (x) = false
149
+ is_assignment_expr (expr:: Expr ) = (expr. head === :(= ))
150
+ is_assignment_expr (expr) = false
151
+
152
+ # TODO : Define this in terms of a function mapping
153
+ # preserve_graph_function(::typeof(setindex!)) = setindex!_preserve_graph
154
+ # preserve_graph_function(::typeof(map_vertex_data)) = map_vertex_data_preserve_graph
155
+ # Also allow annotating codeblocks like `@views`.
156
+ macro preserve_graph (expr)
157
+ if ! is_setindex!_expr (expr)
158
+ error (
159
+ " preserve_graph must be used with setindex! syntax (as @preserve_graph a[i,j,...] = value)" ,
160
+ )
161
+ end
162
+ @capture (expr, array_[indices__] = value_)
163
+ return :(setindex_preserve_graph! ($ (esc (array)), $ (esc (value)), $ (esc .(indices)... )))
164
+ end
165
+
141
166
function ITensors. hascommoninds (tn:: AbstractITensorNetwork , edge:: Pair )
142
167
return hascommoninds (tn, edgetype (tn)(edge))
143
168
end
148
173
149
174
function Base. setindex! (tn:: AbstractITensorNetwork , value, v)
150
175
# v = to_vertex(tn, index...)
151
- setindex_preserve_graph! (tn, value, v)
176
+ @preserve_graph tn[v] = value
152
177
for edge in incident_edges (tn, v)
153
178
rem_edge! (tn, edge)
154
179
end
@@ -297,12 +322,12 @@ function ITensors.replaceinds(
297
322
@assert underlying_graph (is) == underlying_graph (is′)
298
323
for v in vertices (is)
299
324
isassigned (is, v) || continue
300
- setindex_preserve_graph! (tn, replaceinds (tn[v], is[v] => is′[v]), v )
325
+ @preserve_graph tn[v] = replaceinds (tn[v], is[v] => is′[v])
301
326
end
302
327
for e in edges (is)
303
328
isassigned (is, e) || continue
304
329
for v in (src (e), dst (e))
305
- setindex_preserve_graph! (tn, replaceinds (tn[v], is[e] => is′[e]), v )
330
+ @preserve_graph tn[v] = replaceinds (tn[v], is[e] => is′[e])
306
331
end
307
332
end
308
333
return tn
@@ -361,13 +386,31 @@ end
361
386
362
387
LinearAlgebra. adjoint (tn:: Union{IndsNetwork,AbstractITensorNetwork} ) = prime (tn)
363
388
364
- # dag(tn::AbstractITensorNetwork) = map_vertex_data(dag, tn)
365
- function ITensors. dag (tn:: AbstractITensorNetwork )
366
- tndag = copy (tn)
367
- for v in vertices (tndag)
368
- setindex_preserve_graph! (tndag, dag (tndag[v]), v)
389
+ function map_vertex_data (f, tn:: AbstractITensorNetwork )
390
+ tn = copy (tn)
391
+ for v in vertices (tn)
392
+ tn[v] = f (tn[v])
369
393
end
370
- return tndag
394
+ return tn
395
+ end
396
+
397
+ # TODO : Define `@preserve_graph map_vertex_data(f, tn)`
398
+ function map_vertex_data_preserve_graph (f, tn:: AbstractITensorNetwork )
399
+ tn = copy (tn)
400
+ for v in vertices (tn)
401
+ @preserve_graph tn[v] = f (tn[v])
402
+ end
403
+ return tn
404
+ end
405
+
406
+ function Base. conj (tn:: AbstractITensorNetwork )
407
+ # TODO : Use `@preserve_graph map_vertex_data(f, tn)`
408
+ return map_vertex_data_preserve_graph (conj, tn)
409
+ end
410
+
411
+ function ITensors. dag (tn:: AbstractITensorNetwork )
412
+ # TODO : Use `@preserve_graph map_vertex_data(f, tn)`
413
+ return map_vertex_data_preserve_graph (dag, tn)
371
414
end
372
415
373
416
# TODO : should this make sure that internal indices
@@ -442,9 +485,7 @@ function NDTensors.contract(
442
485
for n_dst in neighbors_dst
443
486
add_edge! (tn, merged_vertex => n_dst)
444
487
end
445
-
446
- setindex_preserve_graph! (tn, new_itensor, merged_vertex)
447
-
488
+ @preserve_graph tn[merged_vertex] = new_itensor
448
489
return tn
449
490
end
450
491
@@ -533,13 +574,8 @@ function LinearAlgebra.factorize(
533
574
add_edge! (tn, X_vertex => nX)
534
575
end
535
576
add_edge! (tn, Y_vertex => dst (edge))
536
-
537
- # tn[X_vertex] = X
538
- setindex_preserve_graph! (tn, X, X_vertex)
539
-
540
- # tn[Y_vertex] = Y
541
- setindex_preserve_graph! (tn, Y, Y_vertex)
542
-
577
+ @preserve_graph tn[X_vertex] = X
578
+ @preserve_graph tn[Y_vertex] = Y
543
579
return tn
544
580
end
545
581
0 commit comments