from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import compute_test_value
from pytensor.graph.replace import clone_replace
-from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
+from pytensor.graph.rewriting.basic import (
+    GraphRewriter,
+    copy_stack_trace,
+    in2out,
+    node_rewriter,
+)
from pytensor.graph.rewriting.db import EquilibriumDB, SequenceDB
+from pytensor.graph.rewriting.utils import get_clients_at_depth
from pytensor.graph.type import HasShape
from pytensor.graph.utils import InconsistencyError
+from pytensor.raise_op import Assert
+from pytensor.scalar import ScalarConstant
from pytensor.scan.op import Scan, ScanInfo
from pytensor.scan.utils import (
    ScanArgs,
@@ -1115,6 +1123,71 @@ def sanitize(x):
    return at.as_tensor_variable(x)


+@node_rewriter([Scan])
+def while_scan_merge_subtensor_last_element(fgraph, scan_node):
+    """
+    Replace while_scan_out[abs(min(tap)):][-1] by while_scan_out[-1], for
+    recurring outputs, asserting that at least one step occurs.
+    Only the first step can be ensured by the inputs alone (i.e., `n_steps > 0`),
+    as the while scan could abort earlier anytime after that. This means it is
+    not possible to replace while_scan_out[abs(min(tap)):][-i]
+    by while_scan_out[-i], for -i != -1.
+    """
+    op = scan_node.op
+
+    if not op.info.as_while:
+        return None
+
+    # Optimization is not implemented for mit-mot
+    recurrent_outputs = op.outer_mitsot_outs(scan_node.outputs) + op.outer_sitsot_outs(
+        scan_node.outputs
+    )
+    recurrent_outputs_taps_slices = (
+        op.info.mit_sot_in_slices + op.info.sit_sot_in_slices
+    )
+
+    n_steps = scan_node.inputs[0]
+    non_zero_steps_cond = n_steps > 0
+    assert_non_zero_steps_op = Assert("n_steps > 0")
+
+    subtensor_merge_replacements = {}
+
+    # Iterate over all nodes that are two computations below the while scan
+    for node in get_clients_at_depth(fgraph, scan_node, depth=2):
+        if not isinstance(node.op, Subtensor):
+            continue
+
+        u = node.inputs[0]
+        if not (u.owner and isinstance(u.owner.op, Subtensor)):
+            continue
+
+        x = u.owner.inputs[0]
+        if x not in recurrent_outputs:
+            continue
+
+        slice1 = get_idx_list(u.owner.inputs, u.owner.op.idx_list)
+        slice2 = get_idx_list(node.inputs, node.op.idx_list)
+
+        min_tap = abs(min(recurrent_outputs_taps_slices[recurrent_outputs.index(x)]))
+
+        if (
+            len(slice1) == 1
+            and isinstance(slice1[0], slice)
+            and isinstance(slice1[0].start, aes.ScalarConstant)
+            and slice1[0].start.data == min_tap
+            and slice1[0].stop is None
+            and slice1[0].step is None
+            and len(slice2) == 1
+            and isinstance(slice2[0], aes.ScalarConstant)
+            and slice2[0].data == -1
+        ):
+            out = assert_non_zero_steps_op(x[-1], non_zero_steps_cond)
+            copy_stack_trace([node.outputs[0], node.inputs[0]], out)
+            subtensor_merge_replacements[node.outputs[0]] = out
+
+    return subtensor_merge_replacements
+
+
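Editor's note, not part of the diff: a minimal sketch of the user-level pattern this rewrite targets. Scan returns each recurrent output as its internal buffer sliced by the minimum tap (`buffer[1:]` for a simple sit-sot), so asking for the last value produces exactly the nested `x[abs(min(tap)):][-1]` Subtensor pair matched above. The toy recurrence is illustrative only.

import pytensor.tensor as at
from pytensor import scan
from pytensor.scan.utils import until

x0 = at.scalar("x0")
# Do-while scan: doubles x until it exceeds 100, for at most 1000 steps.
xs, _ = scan(
    fn=lambda x: (x * 2, until(x * 2 > 100)),
    outputs_info=[x0],
    n_steps=1000,
)
# `xs` is internally buffer[1:]; indexing it with [-1] creates
# buffer[1:][-1], which this rewrite replaces by buffer[-1] guarded
# by Assert(n_steps > 0).
last = xs[-1]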
@node_rewriter([Scan])
def save_mem_new_scan(fgraph, node):
    r"""Graph optimizer that reduces scan memory consumption.
@@ -1136,6 +1209,17 @@ def save_mem_new_scan(fgraph, node):
    that SITSOT output. Only the most recently computed timestep ever needs to
    be kept in memory.

+    There are two ways in which the Scan buffer size is controlled:
+    1. Each recurring output is saved in an empty input tensor x with the initial
+       state written at x[0]. The remaining x[1:] positions determine how many
+       intermediate results should be stored.
+       This rewrite shortens x[1:] to the smallest possible size.
+    2. Each non-recurrent output (nit-sot) is associated with a scalar integer
+       input that determines how many steps should be saved in the perform method.
+       This rewrite reduces this number to the smallest possible.
+
+    The scan perform implementation takes the output sizes into consideration,
+    saving the newest results over the oldest ones whenever the buffer is filled.
    """
    if not isinstance(node.op, Scan):
        return False
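An editorial illustration of the buffer-size mechanics described in the docstring above (a sketch, not part of the diff): when downstream code only consumes the final timestep, this rewrite shrinks the sit-sot buffer so memory no longer scales with `n_steps`.

import pytensor
import pytensor.tensor as at
from pytensor import scan

x0 = at.scalar("x0")
# A 1000-step counting loop whose intermediate values are never used.
ys, _ = scan(fn=lambda x: x + 1.0, outputs_info=[x0], n_steps=1000)
# Only the last value is requested, so after this rewrite the internal
# buffer keeps the initial state plus a single rotating slot instead
# of all 1001 entries.
f = pytensor.function([x0], ys[-1])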
@@ -1184,13 +1268,16 @@ def save_mem_new_scan(fgraph, node):
        # index(step) for any output scan actually needs to compute
        # In other words n_steps should be equal to this maximal !
        # Note: if we have a shared variable that gets updated at every step
-        # of the loop, reducing the number of steps will affect the the
-        # value of the shared variable after the loop so we need not to
+        # of the loop, reducing the number of steps will affect the
+        # value of the shared variable after the loop so we cannot
        # change the number of steps in that case. To do this we set
        # global_nsteps to None which is seen as a flag that nothing needs
-        # to be done
+        # to be done.
+        # Note: For simplicity, while Scans also have global_nsteps set to None.
+        # All step optimizations require knowing the shape of the output, which
+        # cannot be determined from the inputs alone.
        assert len(node.outputs) >= c_outs
-        if len(node.outputs) == c_outs:
+        if len(node.outputs) == c_outs and not op.info.as_while:
            global_nsteps = {"real": -1, "sym": []}
        else:
            global_nsteps = None
@@ -1270,7 +1357,7 @@ def save_mem_new_scan(fgraph, node):
                else:
                    # there is a **gotcha** here ! Namely, scan returns an
                    # array that contains the initial state of the output
-                    # as well. Which means that if have a initial state of
+                    # as well. Which means that if y has an initial state of
                    # length 3, and you look for 5 steps you get an output
                    # y of length 8. If you only use y[:5], this does not
                    # mean that you only need to loop for 5 steps but
@@ -1298,9 +1385,9 @@ def save_mem_new_scan(fgraph, node):

        # 2.3. Analyze global_nsteps to figure out for how many steps scan
        # needs to iterate
-        if global_nsteps is not None:
+        if global_nsteps is None:
            nw_steps = node.inputs[0]
-
+        else:
            # there are some symbolic tensors that limit the number of
            # steps
            if len(global_nsteps["sym"]) == 0:
@@ -1316,16 +1403,14 @@ def save_mem_new_scan(fgraph, node):
                real_steps = None
            nw_steps = select_min(select_max(sym_steps, real_steps), node.inputs[0])

+            # FIXME: This is not correct. Scan with 0 steps seems to be supported
            # Make sure the ScanSaveMem optimization never makes the new
            # number of steps to be 0 (this could happen, for instance, if
            # the optimization detects that the outputs of the Scan go through
            # subtensor nodes that end up taking no elements) because Scan with
            # 0 iterations are not supported. Make sure the new number of steps
            # is at least 1.
            nw_steps = select_max(nw_steps, 1)
-        else:
-            nw_steps = node.inputs[0]
-            global_nsteps = None

        # 2.4 Loop over the clients again now looking just to see how many
        # intermediate steps to store
@@ -1348,19 +1433,33 @@ def save_mem_new_scan(fgraph, node):
                        store_steps[i] = 0
                        break

-                    if i > op_info.n_mit_mot:
-                        length = node.inputs[0] + init_l[i]
+                    # Special case for recurrent outputs where only the last result
+                    # is requested. This is needed for this rewrite to apply to
+                    # do-while Scans at all. Otherwise, `get_canonical_form_slice` in
+                    # the `else` branch would reintroduce a shape dependency on the
+                    # original Scan which would lead this rewrite to abort in the end.
+                    if (
+                        i <= op.info.n_mit_mot
+                        and isinstance(this_slice[0], ScalarConstant)
+                        and this_slice[0].value == -1
+                    ):
+                        start = nw_steps - 1
                    else:
-                        try:
-                            length = shape_of[out][0]
-                        except KeyError:
-                            length = out.shape[0]
-                    cf_slice = get_canonical_form_slice(this_slice[0], length)
+                        if i <= op.info.n_mit_mot:
+                            try:
+                                length = shape_of[out][0]
+                            except KeyError:
+                                length = out.shape[0]
+                        else:
+                            length = node.inputs[0] + init_l[i]
+
+                        cf_slice = get_canonical_form_slice(this_slice[0], length)
+
+                        if isinstance(cf_slice[0], slice):
+                            start = at.extract_constant(cf_slice[0].start)
+                        else:
+                            start = at.extract_constant(cf_slice[0])

-                    if isinstance(cf_slice[0], slice):
-                        start = at.extract_constant(cf_slice[0].start)
-                    else:
-                        start = at.extract_constant(cf_slice[0])
                    if start == 0 or store_steps[i] == 0:
                        store_steps[i] = 0
                    else:
@@ -1514,6 +1613,7 @@ def save_mem_new_scan(fgraph, node):
                nw_input = expand_empty(_nw_input, nw_steps)
                nw_inputs[in_idx] = nw_input
            else:
+                # FIXME: This is never used
                nw_input = nw_inputs[in_idx][: (initl + nw_steps)]

        elif (
@@ -1570,8 +1670,8 @@ def save_mem_new_scan(fgraph, node):
                    )
                else:
                    fslice = sanitize(cnf_slice[0])
-
                nw_slice = (fslice,) + tuple(old_slices[1:])
+
                nw_pos = inv_compress_map[idx]

                subtens = Subtensor(nw_slice)
@@ -1620,9 +1720,16 @@ def save_mem_new_scan(fgraph, node):
                    ) + tuple(old_slices[1:])

                else:
-                    position = (
-                        cnf_slice[0] - nw_steps - init_l[pos] + store_steps[pos]
-                    )
+                    # Special case when only last value is requested
+                    if (
+                        isinstance(old_slices[0], ScalarConstant)
+                        and old_slices[0].value == -1
+                    ):
+                        position = old_slices[0]
+                    else:
+                        position = (
+                            cnf_slice[0] - nw_steps - init_l[pos] + store_steps[pos]
+                        )

                nw_slice = (sanitize(position),) + tuple(old_slices[1:])
                subtens = Subtensor(nw_slice)
@@ -2424,6 +2531,12 @@ def push_out_dot1_scan(fgraph, node):
    position=5,
)

+scan_eqopt2.register(
+    "while_scan_merge_subtensor_last_element",
+    in2out(while_scan_merge_subtensor_last_element, ignore_newtrees=True),
+    "fast_run",
+    "scan",
+)

scan_eqopt2.register(
    "constant_folding_for_scan2",