@@ -21,8 +21,8 @@ function processcontractions(ex, treebuilder, treesorter, costcheck)
21
21
elseif isexpr (ex, :macrocall ) && ex. args[1 ] == Symbol (" @notensor" )
22
22
return ex
23
23
elseif isexpr (ex, :call ) && ex. args[1 ] == :tensorscalar
24
- return Expr ( :call , :tensorscalar ,
25
- processcontractions (ex . args[ 2 ], treebuilder, treesorter, costcheck))
24
+ return processcontractions (ex . args[ 2 ], treebuilder, treesorter, costcheck)
25
+ # `tensorscalar` will be reinserted automatically
26
26
elseif isassignment (ex) || isdefinition (ex)
27
27
lhs, rhs = getlhs (ex), getrhs (ex)
28
28
rhs, pre, post = _processcontractions (rhs, treebuilder, treesorter, costcheck)
57
57
58
58
function insertcontractiontrees! (ex, treebuilder, treesorter, costcheck, preexprs,
59
59
postexprs)
60
+ if isexpr (ex, :call ) && ex. args[1 ] == :tensorscalar
61
+ return insertcontractiontrees! (ex. args[2 ], treebuilder, treesorter, costcheck,
62
+ preexprs, postexprs)
63
+ end
60
64
if isexpr (ex, :call )
61
65
args = ex. args
62
66
nargs = length (args)
63
67
ex = Expr (:call , args[1 ],
64
68
(insertcontractiontrees! (args[i], treebuilder, treesorter, costcheck,
65
69
preexprs, postexprs) for i in 2 : nargs). .. )
66
70
end
67
- if istensorcontraction (ex) && length (ex. args) > 3
71
+ if ! istensorcontraction (ex)
72
+ return ex
73
+ end
74
+ if length (ex. args) <= 3
75
+ return isempty (getindices (ex)) ? Expr (:call , :tensorscalar , ex) : ex
76
+ else
68
77
args = ex. args[2 : end ]
69
78
network = map (getindices, args)
70
79
for a in getallindices (ex)
@@ -137,7 +146,6 @@ function insertcontractiontrees!(ex, treebuilder, treesorter, costcheck, preexpr
137
146
push! (postexprs, removelinenumbernode (costcompareex))
138
147
return treeex
139
148
end
140
- return ex
141
149
end
142
150
143
151
function treecost (tree, network, costs)
@@ -175,9 +183,14 @@ function defaulttreesorter(args, tree)
175
183
if isa (tree, Int)
176
184
return args[tree]
177
185
else
178
- return Expr (:call , :* ,
179
- defaulttreesorter (args, tree[1 ]),
180
- defaulttreesorter (args, tree[2 ]))
186
+ ex = Expr (:call , :* ,
187
+ defaulttreesorter (args, tree[1 ]),
188
+ defaulttreesorter (args, tree[2 ]))
189
+ if isempty (getindices (ex))
190
+ return Expr (:call , :tensorscalar , ex)
191
+ else
192
+ return ex
193
+ end
181
194
end
182
195
end
183
196
0 commit comments