Skip to content

Commit 33d4d36

Browse files
committed
Fix bug in gradient of Elemwise containing multi-output scalars
1 parent 8051ffb commit 33d4d36

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

pytensor/tensor/elemwise.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,9 @@ def transform(r):
638638
return DimShuffle((), ["x"] * nd)(res)
639639

640640
new_r = Elemwise(node.op, {})(*[transform(ipt) for ipt in node.inputs])
641+
if isinstance(new_r, (list, tuple)):
642+
# Scalar Op with multiple outputs
643+
new_r = new_r[r.owner.outputs.index(r)]
641644
return new_r
642645

643646
ret = []

0 commit comments

Comments
 (0)