from pytensor.graph.rewriting.db import EquilibriumDB, SequenceDB
from pytensor.graph.type import HasShape
from pytensor.graph.utils import InconsistencyError
+from pytensor.scalar import ScalarConstant
from pytensor.scan.op import Scan, ScanInfo
from pytensor.scan.utils import (
    ScanArgs,
@@ -1184,13 +1185,16 @@ def save_mem_new_scan(fgraph, node):
    # index(step) for any output scan actually needs to compute
    # In other words n_steps should be equal to this maximal !
    # Note: if we have a shared variable that gets updated at every step
-    # of the loop, reducing the number of steps will affect the the
-    # value of the shared variable after the loop so we need not to
+    # of the loop, reducing the number of steps will affect the
+    # value of the shared variable after the loop so we cannot
    # change the number of steps in that case. To do this we set
    # global_nsteps to None which is seen as a flag that nothing needs
-    # to be done
+    # to be done.
+    # Note: For simplicity while Scans also have global_nsteps set to None.
+    # All step optimizations require knowing the shape of the output, which
+    # cannot be determined from the inputs alone.
    assert len(node.outputs) >= c_outs
-    if len(node.outputs) == c_outs:
+    if len(node.outputs) == c_outs and not op.info.as_while:
        global_nsteps = {"real": -1, "sym": []}
    else:
        global_nsteps = None
@@ -1298,9 +1302,9 @@ def save_mem_new_scan(fgraph, node):
    # 2.3. Analyze global_nsteps to figure out for how many steps scan
    # needs to iterate
-    if global_nsteps is not None:
+    if global_nsteps is None:
        nw_steps = node.inputs[0]
-
+    else:
        # there are some symbolic tensors that limit the number of
        # steps
        if len(global_nsteps["sym"]) == 0:
@@ -1316,16 +1320,15 @@ def save_mem_new_scan(fgraph, node):
            real_steps = None
        nw_steps = select_min(select_max(sym_steps, real_steps), node.inputs[0])

+        # FIXME: This is not correct. Scan with 0 steps seem to be supported again
        # Make sure the ScanSaveMem optimization never makes the new
        # number of steps to be 0 (this could happen, for instance, if
        # the optimization detects that the outputs of the Scan go through
        # subtensor nodes that end up taking no elements) because Scan with
        # 0 iterations are not supported. Make sure the new number of steps
        # is at least 1.
        nw_steps = select_max(nw_steps, 1)
-    else:
-        nw_steps = node.inputs[0]
-        global_nsteps = None
+

    # 2.4 Loop over the clients again now looking just to see how many
    # intermediate steps to store
@@ -1348,19 +1351,26 @@ def save_mem_new_scan(fgraph, node):
                    store_steps[i] = 0
                    break

-                if i > op_info.n_mit_mot:
-                    length = node.inputs[0] + init_l[i]
+                if (
+                    isinstance(this_slice[0], ScalarConstant)
+                    and this_slice[0].value == -1
+                ):
+                    start = nw_steps
                else:
-                    try:
-                        length = shape_of[out][0]
-                    except KeyError:
-                        length = out.shape[0]
-                cf_slice = get_canonical_form_slice(this_slice[0], length)
+                    if i > op_info.n_mit_mot:
+                        length = node.inputs[0] + init_l[i]
+                    else:
+                        try:
+                            length = shape_of[out][0]
+                        except KeyError:
+                            length = out.shape[0]
+                    cf_slice = get_canonical_form_slice(this_slice[0], length)
+
+                    if isinstance(cf_slice[0], slice):
+                        start = at.extract_constant(cf_slice[0].start)
+                    else:
+                        start = at.extract_constant(cf_slice[0])

-                if isinstance(cf_slice[0], slice):
-                    start = at.extract_constant(cf_slice[0].start)
-                else:
-                    start = at.extract_constant(cf_slice[0])
                if start == 0 or store_steps[i] == 0:
                    store_steps[i] = 0
                else:
@@ -1514,6 +1524,7 @@ def save_mem_new_scan(fgraph, node):
                nw_input = expand_empty(_nw_input, nw_steps)
                nw_inputs[in_idx] = nw_input
            else:
+                # FIXME: This is never used
                nw_input = nw_inputs[in_idx][: (initl + nw_steps)]

        elif (
@@ -1569,9 +1580,16 @@ def save_mem_new_scan(fgraph, node):
                        sanitize(cnf_slice[0].step),
                    )
                else:
-                    fslice = sanitize(cnf_slice[0])
+                    if (
+                        isinstance(old_slices[0], ScalarConstant)
+                        and this_slice[0].value == -1
+                    ):
+                        fslice = old_slices[0]
+                    else:
+                        fslice = sanitize(cnf_slice[0])
+
+                nw_slice = (fslice,) + tuple(old_slices[1:])

-                nw_slice = (fslice,) + tuple(old_slices[1:])
                nw_pos = inv_compress_map[idx]

                subtens = Subtensor(nw_slice)
@@ -1620,9 +1638,15 @@ def save_mem_new_scan(fgraph, node):
                    ) + tuple(old_slices[1:])

                else:
-                    position = (
-                        cnf_slice[0] - nw_steps - init_l[pos] + store_steps[pos]
-                    )
+                    if (
+                        isinstance(old_slices[0], ScalarConstant)
+                        and this_slice[0].value == -1
+                    ):
+                        position = old_slices[0]
+                    else:
+                        position = (
+                            cnf_slice[0] - nw_steps - init_l[pos] + store_steps[pos]
+                        )

                    nw_slice = (sanitize(position),) + tuple(old_slices[1:])
                subtens = Subtensor(nw_slice)
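
For context, the case this change targets can be reproduced with a small, self-contained sketch like the one below. It is illustrative only and not part of the commit: it assumes the public pytensor.scan / pytensor.function API, and the doubling recurrence and the n_steps value are arbitrary choices. It builds a Scan whose only client indexes the output with the constant -1, which is the subtensor pattern the rewrite now special-cases.

import pytensor
import pytensor.tensor as pt

x0 = pt.scalar("x0")
ys, _ = pytensor.scan(
    fn=lambda prev: prev * 2,  # arbitrary recurrence, for illustration only
    outputs_info=[x0],
    n_steps=10,
)

# Only the last step is used downstream, so the ScanSaveMem rewrite can keep
# a buffer holding a single step instead of all 10 intermediate values.
f = pytensor.function([x0], ys[-1])
print(f(1.0))  # 1.0 * 2**10 == 1024.0

Note that, per the as_while check added above, Scans with a termination condition keep global_nsteps set to None, so this rewrite never reduces the number of steps they run.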