Skip to content

Commit 2adad27

Browse files
authored
fix ad type instabilities in new zygote (#202)
1 parent c9aaf45 commit 2adad27

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

ext/TensorOperationsChainRulesCoreExt.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -86,19 +86,19 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
8686
function pullback(ΔC′)
8787
ΔC = unthunk(ΔC′)
8888
dC = @thunk projectC(scale(ΔC, conj(β)))
89-
dA = @thunk begin
89+
dA = @thunk let
9090
ipA = invperm(linearize(pA))
9191
_dA = zerovector(A, VectorInterface.promote_add(ΔC, α))
9292
_dA = tensoradd!(_dA, ΔC, (ipA, ()), conjA, conjA ? α : conj(α), Zero(), ba...)
9393
return projectA(_dA)
9494
end
95-
= @thunk begin
95+
= @thunk let
9696
_dα = tensorscalar(tensorcontract(A, ((), linearize(pA)), !conjA,
9797
ΔC, (trivtuple(numind(pA)), ()), false,
9898
((), ()), One(), ba...))
9999
return projectα(_dα)
100100
end
101-
= @thunk begin
101+
= @thunk let
102102
# TODO: consider using `inner`
103103
_dβ = tensorscalar(tensorcontract(C, ((), trivtuple(numind(pA))), true,
104104
ΔC, (trivtuple(numind(pA)), ()), false,
@@ -165,7 +165,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
165165
pΔC = (TupleTools.getindices(ipAB, trivtuple(numout(pA))),
166166
TupleTools.getindices(ipAB, numout(pA) .+ trivtuple(numin(pB))))
167167
dC = @thunk projectC(scale(ΔC, conj(β)))
168-
dA = @thunk begin
168+
dA = @thunk let
169169
ipA = (invperm(linearize(pA)), ())
170170
conjΔC = conjA
171171
conjB′ = conjA ? conjB : !conjB
@@ -177,7 +177,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
177177
conjA ? α : conj(α), Zero(), ba...)
178178
return projectA(_dA)
179179
end
180-
dB = @thunk begin
180+
dB = @thunk let
181181
ipB = (invperm(linearize(pB)), ())
182182
conjΔC = conjB
183183
conjA′ = conjB ? conjA : !conjA
@@ -189,15 +189,15 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
189189
conjB ? α : conj(α), Zero(), ba...)
190190
return projectB(_dB)
191191
end
192-
= @thunk begin
192+
= @thunk let
193193
C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...)
194194
# TODO: consider using `inner`
195195
_dα = tensorscalar(tensorcontract(C_αβ, ((), trivtuple(numind(pAB))), true,
196196
ΔC, (trivtuple(numind(pAB)), ()), false,
197197
((), ()), One(), ba...))
198198
return projectα(_dα)
199199
end
200-
= @thunk begin
200+
= @thunk let
201201
# TODO: consider using `inner`
202202
_dβ = tensorscalar(tensorcontract(C, ((), trivtuple(numind(pAB))), true,
203203
ΔC, (trivtuple(numind(pAB)), ()), false,
@@ -249,7 +249,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
249249
function pullback(ΔC′)
250250
ΔC = unthunk(ΔC′)
251251
dC = @thunk projectC(scale(ΔC, conj(β)))
252-
dA = @thunk begin
252+
dA = @thunk let
253253
ip = invperm((linearize(p)..., q[1]..., q[2]...))
254254
Es = map(q[1], q[2]) do i1, i2
255255
return one(TensorOperations.tensoralloc_add(scalartype(A), A,
@@ -263,15 +263,15 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
263263
conjA ? α : conj(α), Zero(), ba...)
264264
return projectA(_dA)
265265
end
266-
= @thunk begin
266+
= @thunk let
267267
C_αβ = tensortrace(A, p, q, false, One(), ba...)
268268
_dα = tensorscalar(tensorcontract(C_αβ, ((), trivtuple(numind(p))),
269269
!conjA,
270270
ΔC, (trivtuple(numind(p)), ()), false,
271271
((), ()), One(), ba...))
272272
return projectα(_dα)
273273
end
274-
= @thunk begin
274+
= @thunk let
275275
_dβ = tensorscalar(tensorcontract(C, ((), trivtuple(numind(p))), true,
276276
ΔC, (trivtuple(numind(p)), ()), false,
277277
((), ()), One(), ba...))

0 commit comments

Comments
 (0)