28
28
from pytensor .graph .fg import FunctionGraph
29
29
from pytensor .graph .op import compute_test_value
30
30
from pytensor .graph .replace import clone_replace
31
- from pytensor .graph .rewriting .basic import GraphRewriter , in2out , node_rewriter
31
+ from pytensor .graph .rewriting .basic import (
32
+ GraphRewriter ,
33
+ copy_stack_trace ,
34
+ in2out ,
35
+ node_rewriter ,
36
+ )
32
37
from pytensor .graph .rewriting .db import EquilibriumDB , SequenceDB
38
+ from pytensor .graph .rewriting .utils import get_clients_at_depth
33
39
from pytensor .graph .type import HasShape
34
40
from pytensor .graph .utils import InconsistencyError
41
+ from pytensor .raise_op import Assert
42
+ from pytensor .scalar import ScalarConstant
35
43
from pytensor .scan .op import Scan , ScanInfo
36
44
from pytensor .scan .utils import (
37
45
ScanArgs ,
@@ -1115,6 +1123,61 @@ def sanitize(x):
1115
1123
return at .as_tensor_variable (x )
1116
1124
1117
1125
1126
@node_rewriter([Scan])
def merge_while_scan_subtensor_last_element(fgraph, scan_node):
    """Replace ``while_scan_out[1:][-1]`` by ``while_scan_out[-1]`` for recurring outputs.

    The replacement is guarded by an `Assert` that at least one step will
    happen. Only the first step can be ensured by the inputs alone (i.e.,
    ``n_steps > 0`` and non-empty sequences), as the while scan could abort
    anytime after that. This means it is not generally safe to replace
    ``while_scan[1:][-i]`` by ``while_scan[-i]`` for ``-i != -1``.

    Parameters
    ----------
    fgraph
        The `FunctionGraph` being rewritten.
    scan_node
        An apply node of a `Scan` `Op`; ignored unless it is a while scan.

    Returns
    -------
    dict or None
        A mapping from old `Subtensor` outputs to their guarded replacements,
        or ``None`` when the node is not a while scan.
    """
    if not scan_node.op.info.as_while:
        return None

    recurrent_outputs = scan_node.outputs[: scan_node.op.n_outs]

    n_steps = scan_node.inputs[0]
    sequences = scan_node.inputs[1 : 1 + scan_node.op.info.n_seqs]
    # A sequence with shape[0] == 0 (falsy) would mean zero steps, so folding
    # the raw lengths into `all` together with `n_steps > 0` is sufficient.
    non_zero_steps_cond = at.all([n_steps > 0] + [seq.shape[0] for seq in sequences])
    # NOTE: fixed unbalanced parenthesis in the original assert message.
    assert_non_zero_steps_op = Assert("n_steps > 0 and all(len(sequences) > 0)")

    subtensor_merge_replacements = {}

    # Iterate over all nodes that are two computations below the while scan
    for node in get_clients_at_depth(fgraph, scan_node, depth=2):
        if not isinstance(node.op, Subtensor):
            continue

        u = node.inputs[0]
        if not (u.owner and isinstance(u.owner.op, Subtensor)):
            continue

        x = u.owner.inputs[0]
        if x not in recurrent_outputs:
            continue

        slice1 = get_idx_list(u.owner.inputs, u.owner.op.idx_list)
        slice2 = get_idx_list(node.inputs, node.op.idx_list)

        # Match exactly x[1:][-1]: outer slice is `1:` (constant start of 1,
        # open stop/step) and inner index is the constant -1.
        if (
            len(slice1) == 1
            and isinstance(slice1[0], slice)
            and isinstance(slice1[0].start, aes.ScalarConstant)
            and slice1[0].start.data == 1
            and slice1[0].stop is None
            and slice1[0].step is None
            and len(slice2) == 1
            and isinstance(slice2[0], aes.ScalarConstant)
            and slice2[0].data == -1
        ):
            # Guard the merged subtensor with the non-empty-steps assertion.
            out = assert_non_zero_steps_op(x[-1], non_zero_steps_cond)
            copy_stack_trace([node.outputs[0], node.inputs[0]], out)
            subtensor_merge_replacements[node.outputs[0]] = out

    return subtensor_merge_replacements
1179
+
1180
+
1118
1181
@node_rewriter ([Scan ])
1119
1182
def save_mem_new_scan (fgraph , node ):
1120
1183
r"""Graph optimizer that reduces scan memory consumption.
@@ -1136,6 +1199,17 @@ def save_mem_new_scan(fgraph, node):
1136
1199
that SITSOT output. Only the most recently computed timestep ever needs to
1137
1200
be kept in memory.
1138
1201
1202
+ There are two ways in which the Scan buffer size is controlled:
1203
+ 1. Each recurring output is saved in an input empty tensor x with the initial
1204
+ state written at x[0]. The remaining x[1:] positions determine how many
1205
+ intermediate results should be stored.
1206
+ This rewrite shortens x[1:] to the smallest possible size.
1207
+ 2. Each non-recurrent output (nit-sot) is associated with a scalar integer
1208
+ input that determines how many steps should be saved in the perform method.
1209
+ This rewrite reduces this number to the smallest possible.
1210
+
1211
+ The scan perform implementation takes the output sizes into consideration,
1212
+ saving the newest results over the oldest ones whenever the buffer is filled.
1139
1213
"""
1140
1214
if not isinstance (node .op , Scan ):
1141
1215
return False
@@ -1184,13 +1258,16 @@ def save_mem_new_scan(fgraph, node):
1184
1258
# index(step) for any output scan actually needs to compute
1185
1259
# In other words n_steps should be equal to this maximal !
1186
1260
# Note: if we have a shared variable that gets updated at every step
1187
- # of the loop, reducing the number of steps will affect the the
1188
- # value of the shared variable after the loop so we need not to
1261
+ # of the loop, reducing the number of steps will affect the
1262
+ # value of the shared variable after the loop so we cannot
1189
1263
# change the number of steps in that case. To do this we set
1190
1264
# global_nsteps to None which is seen as a flag that nothing needs
1191
- # to be done
1265
+ # to be done.
1266
+ # Note: For simplicity while Scans also have global_nsteps set to None.
1267
+ # All step optimizations require knowing the shape of the output, which
1268
+ # cannot be determined from the inputs alone.
1192
1269
assert len (node .outputs ) >= c_outs
1193
- if len (node .outputs ) == c_outs :
1270
+ if len (node .outputs ) == c_outs and not op . info . as_while :
1194
1271
global_nsteps = {"real" : - 1 , "sym" : []}
1195
1272
else :
1196
1273
global_nsteps = None
@@ -1298,9 +1375,9 @@ def save_mem_new_scan(fgraph, node):
1298
1375
1299
1376
# 2.3. Analyze global_nsteps to figure out for how many steps scan
1300
1377
# needs to iterate
1301
- if global_nsteps is not None :
1378
+ if global_nsteps is None :
1302
1379
nw_steps = node .inputs [0 ]
1303
-
1380
+ else :
1304
1381
# there are some symbolic tensors that limit the number of
1305
1382
# steps
1306
1383
if len (global_nsteps ["sym" ]) == 0 :
@@ -1316,16 +1393,14 @@ def save_mem_new_scan(fgraph, node):
1316
1393
real_steps = None
1317
1394
nw_steps = select_min (select_max (sym_steps , real_steps ), node .inputs [0 ])
1318
1395
1396
+ # FIXME: This is not correct. Scan with 0 steps seems to be supported
1319
1397
# Make sure the ScanSaveMem optimization never makes the new
1320
1398
# number of steps to be 0 (this could happen, for instance, if
1321
1399
# the optimization detects that the outputs of the Scan go through
1322
1400
# subtensor nodes that end up taking no elements) because Scan with
1323
1401
# 0 iterations are not supported. Make sure the new number of steps
1324
1402
# is at least 1.
1325
1403
nw_steps = select_max (nw_steps , 1 )
1326
- else :
1327
- nw_steps = node .inputs [0 ]
1328
- global_nsteps = None
1329
1404
1330
1405
# 2.4 Loop over the clients again now looking just to see how many
1331
1406
# intermediate steps to store
@@ -1348,19 +1423,33 @@ def save_mem_new_scan(fgraph, node):
1348
1423
store_steps [i ] = 0
1349
1424
break
1350
1425
1351
- if i > op_info .n_mit_mot :
1352
- length = node .inputs [0 ] + init_l [i ]
1426
+ # Special case for recurrent outputs where only the last result
1427
+ # is requested. This is needed for this rewrite to apply to
1428
+ # While Scans at all. Otherwise, `get_canonical_form_slice` in
1429
+ # the `else` branch would reintroduce a shape dependency on the
1430
+ # original While Scan which would lead this rewrite to abort.
1431
+ if (
1432
+ i <= op .info .n_mit_mot
1433
+ and isinstance (this_slice [0 ], ScalarConstant )
1434
+ and this_slice [0 ].value == - 1
1435
+ ):
1436
+ start = nw_steps
1353
1437
else :
1354
- try :
1355
- length = shape_of [out ][0 ]
1356
- except KeyError :
1357
- length = out .shape [0 ]
1358
- cf_slice = get_canonical_form_slice (this_slice [0 ], length )
1438
+ if i <= op .info .n_mit_mot :
1439
+ try :
1440
+ length = shape_of [out ][0 ]
1441
+ except KeyError :
1442
+ length = out .shape [0 ]
1443
+ else :
1444
+ length = node .inputs [0 ] + init_l [i ]
1445
+
1446
+ cf_slice = get_canonical_form_slice (this_slice [0 ], length )
1447
+
1448
+ if isinstance (cf_slice [0 ], slice ):
1449
+ start = at .extract_constant (cf_slice [0 ].start )
1450
+ else :
1451
+ start = at .extract_constant (cf_slice [0 ])
1359
1452
1360
- if isinstance (cf_slice [0 ], slice ):
1361
- start = at .extract_constant (cf_slice [0 ].start )
1362
- else :
1363
- start = at .extract_constant (cf_slice [0 ])
1364
1453
if start == 0 or store_steps [i ] == 0 :
1365
1454
store_steps [i ] = 0
1366
1455
else :
@@ -1514,6 +1603,7 @@ def save_mem_new_scan(fgraph, node):
1514
1603
nw_input = expand_empty (_nw_input , nw_steps )
1515
1604
nw_inputs [in_idx ] = nw_input
1516
1605
else :
1606
+ # FIXME: This is never used
1517
1607
nw_input = nw_inputs [in_idx ][: (initl + nw_steps )]
1518
1608
1519
1609
elif (
@@ -1569,9 +1659,16 @@ def save_mem_new_scan(fgraph, node):
1569
1659
sanitize (cnf_slice [0 ].step ),
1570
1660
)
1571
1661
else :
1572
- fslice = sanitize (cnf_slice [0 ])
1662
+ if (
1663
+ isinstance (old_slices [0 ], ScalarConstant )
1664
+ and this_slice [0 ].value == - 1
1665
+ ):
1666
+ fslice = old_slices [0 ]
1667
+ else :
1668
+ fslice = sanitize (cnf_slice [0 ])
1669
+
1670
+ nw_slice = (fslice ,) + tuple (old_slices [1 :])
1573
1671
1574
- nw_slice = (fslice ,) + tuple (old_slices [1 :])
1575
1672
nw_pos = inv_compress_map [idx ]
1576
1673
1577
1674
subtens = Subtensor (nw_slice )
@@ -1620,9 +1717,15 @@ def save_mem_new_scan(fgraph, node):
1620
1717
) + tuple (old_slices [1 :])
1621
1718
1622
1719
else :
1623
- position = (
1624
- cnf_slice [0 ] - nw_steps - init_l [pos ] + store_steps [pos ]
1625
- )
1720
+ if (
1721
+ isinstance (old_slices [0 ], ScalarConstant )
1722
+ and this_slice [0 ].value == - 1
1723
+ ):
1724
+ position = old_slices [0 ]
1725
+ else :
1726
+ position = (
1727
+ cnf_slice [0 ] - nw_steps - init_l [pos ] + store_steps [pos ]
1728
+ )
1626
1729
1627
1730
nw_slice = (sanitize (position ),) + tuple (old_slices [1 :])
1628
1731
subtens = Subtensor (nw_slice )
@@ -2424,6 +2527,12 @@ def push_out_dot1_scan(fgraph, node):
2424
2527
position = 5 ,
2425
2528
)
2426
2529
2530
+ scan_eqopt2 .register (
2531
+ "merge_while_scan_subtensor_last_element" ,
2532
+ in2out (merge_while_scan_subtensor_last_element , ignore_newtrees = True ),
2533
+ "fast_run" ,
2534
+ "scan" ,
2535
+ )
2427
2536
2428
2537
scan_eqopt2 .register (
2429
2538
"constant_folding_for_scan2" ,
0 commit comments