@@ -86,19 +86,19 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
86
86
function pullback (ΔC′)
87
87
ΔC = unthunk (ΔC′)
88
88
dC = @thunk projectC (scale (ΔC, conj (β)))
89
- dA = @thunk begin
89
+ dA = @thunk let
90
90
ipA = invperm (linearize (pA))
91
91
_dA = zerovector (A, VectorInterface. promote_add (ΔC, α))
92
92
_dA = tensoradd! (_dA, ΔC, (ipA, ()), conjA, conjA ? α : conj (α), Zero (), ba... )
93
93
return projectA (_dA)
94
94
end
95
- dα = @thunk begin
95
+ dα = @thunk let
96
96
_dα = tensorscalar (tensorcontract (A, ((), linearize (pA)), ! conjA,
97
97
ΔC, (trivtuple (numind (pA)), ()), false ,
98
98
((), ()), One (), ba... ))
99
99
return projectα (_dα)
100
100
end
101
- dβ = @thunk begin
101
+ dβ = @thunk let
102
102
# TODO : consider using `inner`
103
103
_dβ = tensorscalar (tensorcontract (C, ((), trivtuple (numind (pA))), true ,
104
104
ΔC, (trivtuple (numind (pA)), ()), false ,
@@ -165,7 +165,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
165
165
pΔC = (TupleTools. getindices (ipAB, trivtuple (numout (pA))),
166
166
TupleTools. getindices (ipAB, numout (pA) .+ trivtuple (numin (pB))))
167
167
dC = @thunk projectC (scale (ΔC, conj (β)))
168
- dA = @thunk begin
168
+ dA = @thunk let
169
169
ipA = (invperm (linearize (pA)), ())
170
170
conjΔC = conjA
171
171
conjB′ = conjA ? conjB : ! conjB
@@ -177,7 +177,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
177
177
conjA ? α : conj (α), Zero (), ba... )
178
178
return projectA (_dA)
179
179
end
180
- dB = @thunk begin
180
+ dB = @thunk let
181
181
ipB = (invperm (linearize (pB)), ())
182
182
conjΔC = conjB
183
183
conjA′ = conjB ? conjA : ! conjA
@@ -189,15 +189,15 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
189
189
conjB ? α : conj (α), Zero (), ba... )
190
190
return projectB (_dB)
191
191
end
192
- dα = @thunk begin
192
+ dα = @thunk let
193
193
C_αβ = tensorcontract (A, pA, conjA, B, pB, conjB, pAB, One (), ba... )
194
194
# TODO : consider using `inner`
195
195
_dα = tensorscalar (tensorcontract (C_αβ, ((), trivtuple (numind (pAB))), true ,
196
196
ΔC, (trivtuple (numind (pAB)), ()), false ,
197
197
((), ()), One (), ba... ))
198
198
return projectα (_dα)
199
199
end
200
- dβ = @thunk begin
200
+ dβ = @thunk let
201
201
# TODO : consider using `inner`
202
202
_dβ = tensorscalar (tensorcontract (C, ((), trivtuple (numind (pAB))), true ,
203
203
ΔC, (trivtuple (numind (pAB)), ()), false ,
@@ -249,7 +249,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
249
249
function pullback (ΔC′)
250
250
ΔC = unthunk (ΔC′)
251
251
dC = @thunk projectC (scale (ΔC, conj (β)))
252
- dA = @thunk begin
252
+ dA = @thunk let
253
253
ip = invperm ((linearize (p)... , q[1 ]. .. , q[2 ]. .. ))
254
254
Es = map (q[1 ], q[2 ]) do i1, i2
255
255
return one (TensorOperations. tensoralloc_add (scalartype (A), A,
@@ -263,15 +263,15 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
263
263
conjA ? α : conj (α), Zero (), ba... )
264
264
return projectA (_dA)
265
265
end
266
- dα = @thunk begin
266
+ dα = @thunk let
267
267
C_αβ = tensortrace (A, p, q, false , One (), ba... )
268
268
_dα = tensorscalar (tensorcontract (C_αβ, ((), trivtuple (numind (p))),
269
269
! conjA,
270
270
ΔC, (trivtuple (numind (p)), ()), false ,
271
271
((), ()), One (), ba... ))
272
272
return projectα (_dα)
273
273
end
274
- dβ = @thunk begin
274
+ dβ = @thunk let
275
275
_dβ = tensorscalar (tensorcontract (C, ((), trivtuple (numind (p))), true ,
276
276
ΔC, (trivtuple (numind (p)), ()), false ,
277
277
((), ()), One (), ba... ))
0 commit comments