Skip to content

Commit 332fb26

Browse files
committed
Specialize string representation of Dimshuffle
1 parent 97d81e9 commit 332fb26

File tree

3 files changed

+30
-22
lines changed

3 files changed

+30
-22
lines changed

pytensor/tensor/elemwise.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -215,10 +215,18 @@ def make_node(self, _input):
215215
return Apply(self, [input], [output])
216216

217217
def __str__(self):
218-
if self.inplace:
219-
return "InplaceDimShuffle{%s}" % ",".join(str(x) for x in self.new_order)
220-
else:
221-
return "DimShuffle{%s}" % ",".join(str(x) for x in self.new_order)
218+
shuffle = sorted(self.shuffle) != self.shuffle
219+
if self.augment and not (shuffle or self.drop):
220+
if len(self.augment) == 1:
221+
return f"ExpandDims{{axis={self.augment[0]}}}"
222+
return f"ExpandDims{{axes={self.augment}}}"
223+
if self.drop and not (self.augment or shuffle):
224+
if len(self.drop) == 1:
225+
return f"DropDims{{axis={self.drop[0]}}}"
226+
return f"DropDims{{axes={self.drop}}}"
227+
if shuffle and not (self.augment or self.drop):
228+
return f"Transpose{{axes={self.shuffle}}}"
229+
return f"DimShuffle{{order=[{','.join(map(str, self.new_order))}]}}"
222230

223231
def perform(self, node, inp, out, params):
224232
(res,) = inp

tests/scan/test_printing.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ def test_debugprint_sitsot():
3737
│ │ │ │ │ └─ Subtensor{int64} [id H]
3838
│ │ │ │ │ ├─ Shape [id I]
3939
│ │ │ │ │ │ └─ Unbroadcast{0} [id J]
40-
│ │ │ │ │ │ └─ InplaceDimShuffle{x,0} [id K]
40+
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id K]
4141
│ │ │ │ │ │ └─ Elemwise{second,no_inplace} [id L]
4242
│ │ │ │ │ │ ├─ A [id M]
43-
│ │ │ │ │ │ └─ InplaceDimShuffle{x} [id N]
43+
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id N]
4444
│ │ │ │ │ │ └─ TensorConstant{1.0} [id O]
4545
│ │ │ │ │ └─ ScalarConstant{0} [id P]
4646
│ │ │ │ └─ Subtensor{int64} [id Q]
@@ -95,10 +95,10 @@ def test_debugprint_sitsot_no_extra_info():
9595
│ │ │ │ │ └─ Subtensor{int64} [id H]
9696
│ │ │ │ │ ├─ Shape [id I]
9797
│ │ │ │ │ │ └─ Unbroadcast{0} [id J]
98-
│ │ │ │ │ │ └─ InplaceDimShuffle{x,0} [id K]
98+
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id K]
9999
│ │ │ │ │ │ └─ Elemwise{second,no_inplace} [id L]
100100
│ │ │ │ │ │ ├─ A [id M]
101-
│ │ │ │ │ │ └─ InplaceDimShuffle{x} [id N]
101+
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id N]
102102
│ │ │ │ │ │ └─ TensorConstant{1.0} [id O]
103103
│ │ │ │ │ └─ ScalarConstant{0} [id P]
104104
│ │ │ │ └─ Subtensor{int64} [id Q]
@@ -264,7 +264,7 @@ def compute_A_k(A, k):
264264
265265
for{cpu,scan_fn} [id B]
266266
← Elemwise{mul,no_inplace} [id Y] (inner_out_nit_sot-0)
267-
├─ InplaceDimShuffle{x} [id Z]
267+
├─ ExpandDims{axis=0} [id Z]
268268
│ └─ *0-<TensorType(float64, ())> [id BA] -> [id S] (inner_in_seqs-0)
269269
└─ Elemwise{pow,no_inplace} [id BB]
270270
├─ Subtensor{int64} [id BC]
@@ -278,10 +278,10 @@ def compute_A_k(A, k):
278278
│ │ │ │ │ │ └─ Subtensor{int64} [id BJ]
279279
│ │ │ │ │ │ ├─ Shape [id BK]
280280
│ │ │ │ │ │ │ └─ Unbroadcast{0} [id BL]
281-
│ │ │ │ │ │ │ └─ InplaceDimShuffle{x,0} [id BM]
281+
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BM]
282282
│ │ │ │ │ │ │ └─ Elemwise{second,no_inplace} [id BN]
283283
│ │ │ │ │ │ │ ├─ *2-<TensorType(float64, (?,))> [id BO] -> [id W] (inner_in_non_seqs-0)
284-
│ │ │ │ │ │ │ └─ InplaceDimShuffle{x} [id BP]
284+
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BP]
285285
│ │ │ │ │ │ │ └─ TensorConstant{1.0} [id BQ]
286286
│ │ │ │ │ │ └─ ScalarConstant{0} [id BR]
287287
│ │ │ │ │ └─ Subtensor{int64} [id BS]
@@ -297,7 +297,7 @@ def compute_A_k(A, k):
297297
│ │ │ └─ *2-<TensorType(float64, (?,))> [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
298298
│ │ └─ ScalarConstant{1} [id BW]
299299
│ └─ ScalarConstant{-1} [id BX]
300-
└─ InplaceDimShuffle{x} [id BY]
300+
└─ ExpandDims{axis=0} [id BY]
301301
└─ *1-<TensorType(int64, ())> [id BZ] -> [id U] (inner_in_seqs-1)
302302
303303
for{cpu,scan_fn} [id BE]
@@ -361,7 +361,7 @@ def compute_A_k(A, k):
361361
→ *2-<TensorType(float64, (?,))> [id BA] -> [id C] (inner_in_non_seqs-0)
362362
→ *3-<TensorType(int32, ())> [id BB] -> [id B] (inner_in_non_seqs-1)
363363
← Elemwise{mul,no_inplace} [id BC] (inner_out_nit_sot-0)
364-
├─ InplaceDimShuffle{x} [id BD]
364+
├─ ExpandDims{axis=0} [id BD]
365365
│ └─ *0-<TensorType(float64, ())> [id Y] (inner_in_seqs-0)
366366
└─ Elemwise{pow,no_inplace} [id BE]
367367
├─ Subtensor{int64} [id BF]
@@ -375,10 +375,10 @@ def compute_A_k(A, k):
375375
│ │ │ │ │ │ └─ Subtensor{int64} [id BL]
376376
│ │ │ │ │ │ ├─ Shape [id BM]
377377
│ │ │ │ │ │ │ └─ Unbroadcast{0} [id BN]
378-
│ │ │ │ │ │ │ └─ InplaceDimShuffle{x,0} [id BO]
378+
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BO]
379379
│ │ │ │ │ │ │ └─ Elemwise{second,no_inplace} [id BP]
380380
│ │ │ │ │ │ │ ├─ *2-<TensorType(float64, (?,))> [id BA] (inner_in_non_seqs-0)
381-
│ │ │ │ │ │ │ └─ InplaceDimShuffle{x} [id BQ]
381+
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BQ]
382382
│ │ │ │ │ │ │ └─ TensorConstant{1.0} [id BR]
383383
│ │ │ │ │ │ └─ ScalarConstant{0} [id BS]
384384
│ │ │ │ │ └─ Subtensor{int64} [id BT]
@@ -394,7 +394,7 @@ def compute_A_k(A, k):
394394
│ │ │ └─ *2-<TensorType(float64, (?,))> [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
395395
│ │ └─ ScalarConstant{1} [id BX]
396396
│ └─ ScalarConstant{-1} [id BY]
397-
└─ InplaceDimShuffle{x} [id BZ]
397+
└─ ExpandDims{axis=0} [id BZ]
398398
└─ *1-<TensorType(int64, ())> [id Z] (inner_in_seqs-1)
399399
400400
for{cpu,scan_fn} [id BH]
@@ -515,10 +515,10 @@ def test_debugprint_mitmot():
515515
│ │ │ │ │ │ │ └─ Subtensor{int64} [id K]
516516
│ │ │ │ │ │ │ ├─ Shape [id L]
517517
│ │ │ │ │ │ │ │ └─ Unbroadcast{0} [id M]
518-
│ │ │ │ │ │ │ │ └─ InplaceDimShuffle{x,0} [id N]
518+
│ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id N]
519519
│ │ │ │ │ │ │ │ └─ Elemwise{second,no_inplace} [id O]
520520
│ │ │ │ │ │ │ │ ├─ A [id P]
521-
│ │ │ │ │ │ │ │ └─ InplaceDimShuffle{x} [id Q]
521+
│ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id Q]
522522
│ │ │ │ │ │ │ │ └─ TensorConstant{1.0} [id R]
523523
│ │ │ │ │ │ │ └─ ScalarConstant{0} [id S]
524524
│ │ │ │ │ │ └─ Subtensor{int64} [id T]
@@ -559,22 +559,22 @@ def test_debugprint_mitmot():
559559
│ │ │ ├─ Elemwise{second,no_inplace} [id BN]
560560
│ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
561561
│ │ │ │ │ └─ ···
562-
│ │ │ │ └─ InplaceDimShuffle{x,x} [id BO]
562+
│ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BO]
563563
│ │ │ │ └─ TensorConstant{0.0} [id BP]
564564
│ │ │ ├─ IncSubtensor{Inc;int64} [id BQ]
565565
│ │ │ │ ├─ Elemwise{second,no_inplace} [id BR]
566566
│ │ │ │ │ ├─ Subtensor{int64::} [id BS]
567567
│ │ │ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
568568
│ │ │ │ │ │ │ └─ ···
569569
│ │ │ │ │ │ └─ ScalarConstant{1} [id BT]
570-
│ │ │ │ │ └─ InplaceDimShuffle{x,x} [id BU]
570+
│ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BU]
571571
│ │ │ │ │ └─ TensorConstant{0.0} [id BV]
572572
│ │ │ │ ├─ Elemwise{second} [id BW]
573573
│ │ │ │ │ ├─ Subtensor{int64} [id BX]
574574
│ │ │ │ │ │ ├─ Subtensor{int64::} [id BS]
575575
│ │ │ │ │ │ │ └─ ···
576576
│ │ │ │ │ │ └─ ScalarConstant{-1} [id BY]
577-
│ │ │ │ │ └─ InplaceDimShuffle{x} [id BZ]
577+
│ │ │ │ │ └─ ExpandDims{axis=0} [id BZ]
578578
│ │ │ │ │ └─ Elemwise{second,no_inplace} [id CA]
579579
│ │ │ │ │ ├─ Sum{acc_dtype=float64} [id CB]
580580
│ │ │ │ │ │ └─ Subtensor{int64} [id BX]

tests/test_printing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ def test_debugprint():
275275
exp_res = dedent(
276276
r"""
277277
Elemwise{Composite{(i2 + (i0 - i1))}} 4
278-
├─ InplaceDimShuffle{x,0} v={0: [0]} 3
278+
├─ ExpandDims{axis=0} v={0: [0]} 3
279279
│ └─ CGemv{inplace} d={0: [0]} 2
280280
│ ├─ AllocEmpty{dtype='float64'} 1
281281
│ │ └─ Shape_i{0} 0

0 commit comments

Comments
 (0)