
Commit d242d47

Add special optimization for While Scan where only last state is used

1 parent 5521d82 · commit d242d47

5 files changed: +91 / -26 lines

pytensor/scan/op.py

Lines changed: 0 additions & 1 deletion
@@ -677,7 +677,6 @@ def __init__(
         typeConstructor: Optional[TensorConstructorType] = None,
         truncate_gradient: int = -1,
         name: Optional[str] = None,
-        as_while: bool = False,
         profile: Optional[Union[str, bool]] = None,
         allow_gc: bool = True,
         strict: bool = True,
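
The deleted `as_while` keyword was redundant: whether a `Scan` is a while loop is already carried by its `ScanInfo` and exposed as `op.info.as_while`, which the rewrite below relies on. At the user level, while-ness comes from returning an `until` condition from the step function. A minimal sketch, mirroring the test added in this commit:

import pytensor
from pytensor.scan.utils import until
from pytensor.tensor import scalar

x = scalar("x")
# Returning an `until` condition marks the Scan as a while loop;
# no separate constructor flag is needed.
ys, _ = pytensor.scan(
    lambda xtm1: (xtm1 + 1, {}, until(xtm1 >= 100)),
    outputs_info=[x],
    n_steps=100,  # upper bound on the number of iterations
)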

pytensor/scan/rewriting.py

Lines changed: 49 additions & 25 deletions
@@ -32,6 +32,7 @@
 from pytensor.graph.rewriting.db import EquilibriumDB, SequenceDB
 from pytensor.graph.type import HasShape
 from pytensor.graph.utils import InconsistencyError
+from pytensor.scalar import ScalarConstant
 from pytensor.scan.op import Scan, ScanInfo
 from pytensor.scan.utils import (
     ScanArgs,
@@ -1184,13 +1185,16 @@ def save_mem_new_scan(fgraph, node):
     # index(step) for any output scan actually needs to compute
     # In other words n_steps should be equal to this maximal !
     # Note: if we have a shared variable that gets updated at every step
-    # of the loop, reducing the number of steps will affect the the
-    # value of the shared variable after the loop so we need not to
+    # of the loop, reducing the number of steps will affect the
+    # value of the shared variable after the loop so we cannot
     # change the number of steps in that case. To do this we set
     # global_nsteps to None which is seen as a flag that nothing needs
-    # to be done
+    # to be done.
+    # Note: For simplicity while Scans also have global_nsteps set to None.
+    # All step optimizations require knowing the shape of the output, which
+    # cannot be determined from the inputs alone.
     assert len(node.outputs) >= c_outs
-    if len(node.outputs) == c_outs:
+    if len(node.outputs) == c_outs and not op.info.as_while:
         global_nsteps = {"real": -1, "sym": []}
     else:
         global_nsteps = None
@@ -1298,9 +1302,9 @@ def save_mem_new_scan(fgraph, node):
 
     # 2.3. Analyze global_nsteps to figure out for how many steps scan
     # needs to iterate
-    if global_nsteps is not None:
+    if global_nsteps is None:
         nw_steps = node.inputs[0]
-
+    else:
         # there are some symbolic tensors that limit the number of
         # steps
         if len(global_nsteps["sym"]) == 0:
@@ -1316,16 +1320,15 @@ def save_mem_new_scan(fgraph, node):
             real_steps = None
         nw_steps = select_min(select_max(sym_steps, real_steps), node.inputs[0])
 
+        # FIXME: This is not correct. Scan with 0 steps seem to be supported again
         # Make sure the ScanSaveMem optimization never makes the new
         # number of steps to be 0 (this could happen, for instance, if
         # the optimization detects that the outputs of the Scan go through
         # subtensor nodes that end up taking no elements) because Scan with
         # 0 iterations are not supported. Make sure the new number of steps
         # is at least 1.
         nw_steps = select_max(nw_steps, 1)
-    else:
-        nw_steps = node.inputs[0]
-        global_nsteps = None
+
 
     # 2.4 Loop over the clients again now looking just to see how many
     # intermediate steps to store
@@ -1348,19 +1351,26 @@ def save_mem_new_scan(fgraph, node):
                     store_steps[i] = 0
                     break
 
-                if i > op_info.n_mit_mot:
-                    length = node.inputs[0] + init_l[i]
+                if (
+                    isinstance(this_slice[0], ScalarConstant)
+                    and this_slice[0].value == -1
+                ):
+                    start = nw_steps
                 else:
-                    try:
-                        length = shape_of[out][0]
-                    except KeyError:
-                        length = out.shape[0]
-                cf_slice = get_canonical_form_slice(this_slice[0], length)
+                    if i > op_info.n_mit_mot:
+                        length = node.inputs[0] + init_l[i]
+                    else:
+                        try:
+                            length = shape_of[out][0]
+                        except KeyError:
+                            length = out.shape[0]
+                    cf_slice = get_canonical_form_slice(this_slice[0], length)
+
+                    if isinstance(cf_slice[0], slice):
+                        start = at.extract_constant(cf_slice[0].start)
+                    else:
+                        start = at.extract_constant(cf_slice[0])
 
-                if isinstance(cf_slice[0], slice):
-                    start = at.extract_constant(cf_slice[0].start)
-                else:
-                    start = at.extract_constant(cf_slice[0])
                 if start == 0 or store_steps[i] == 0:
                     store_steps[i] = 0
                 else:
@@ -1514,6 +1524,7 @@ def save_mem_new_scan(fgraph, node):
                     nw_input = expand_empty(_nw_input, nw_steps)
                     nw_inputs[in_idx] = nw_input
                 else:
+                    # FIXME: This is never used
                     nw_input = nw_inputs[in_idx][: (initl + nw_steps)]
 
             elif (
@@ -1569,9 +1580,16 @@ def save_mem_new_scan(fgraph, node):
                             sanitize(cnf_slice[0].step),
                         )
                     else:
-                        fslice = sanitize(cnf_slice[0])
+                        if (
+                            isinstance(old_slices[0], ScalarConstant)
+                            and this_slice[0].value == -1
+                        ):
+                            fslice = old_slices[0]
+                        else:
+                            fslice = sanitize(cnf_slice[0])
+
+                    nw_slice = (fslice,) + tuple(old_slices[1:])
 
-                    nw_slice = (fslice,) + tuple(old_slices[1:])
                     nw_pos = inv_compress_map[idx]
 
                     subtens = Subtensor(nw_slice)
@@ -1620,9 +1638,15 @@ def save_mem_new_scan(fgraph, node):
                         ) + tuple(old_slices[1:])
 
                     else:
-                        position = (
-                            cnf_slice[0] - nw_steps - init_l[pos] + store_steps[pos]
-                        )
+                        if (
+                            isinstance(old_slices[0], ScalarConstant)
+                            and this_slice[0].value == -1
+                        ):
+                            position = old_slices[0]
+                        else:
+                            position = (
+                                cnf_slice[0] - nw_steps - init_l[pos] + store_steps[pos]
+                            )
 
                     nw_slice = (sanitize(position),) + tuple(old_slices[1:])
                     subtens = Subtensor(nw_slice)
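
The three `ScalarConstant`/`-1` branches above all exist for the same reason: a while loop can exit before `n_steps`, so the true trace length is unknown at rewrite time, and a constant `-1` index must be kept as-is rather than canonicalized into an offset from the front of the buffer. A plain NumPy sketch of the pitfall (illustration only, not part of the commit):

import numpy as np

n_steps = 100          # compile-time upper bound on the iterations
trace = np.arange(37)  # suppose the until() condition fired after 37 steps

# Indexing from the back is correct no matter where the loop stopped:
assert trace[-1] == 36

# The for-loop canonicalization would turn [-1] into [n_steps - 1],
# which is out of bounds here:
# trace[n_steps - 1]  # IndexError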

pytensor/tensor/rewriting/subtensor.py

Lines changed: 20 additions & 0 deletions
@@ -479,6 +479,7 @@ def local_subtensor_merge(fgraph, node):
     expresses all slices in a canonical form, and then merges them together.

     """
+    from pytensor.scan.op import Scan

     if isinstance(node.op, Subtensor):
         u = node.inputs[0]
@@ -489,6 +490,25 @@ def local_subtensor_merge(fgraph, node):
             # slices of the first applied subtensor
             slices1 = get_idx_list(u.owner.inputs, u.owner.op.idx_list)
             slices2 = get_idx_list(node.inputs, node.op.idx_list)
+
+            # Special case for scan[1:][-1] = scan[-1]
+            # FIXME: This assumes scan[1] always exists which is not True,
+            # because Scans can have 0 steps or 0-length sequences.
+            # We can fix it by adding an assert that n_steps is
+            # not zero and no sequence is empty. This generalizes for
+            # any negative scalar index, although -1 is the most common case.
+            # TODO: Check that slices1 is indeed [1:]
+            if (
+                isinstance(x.owner.op, Scan)
+                and isinstance(slices2, tuple)
+                and len(slices2) == 1
+                and isinstance(slices2[0], aes.ScalarConstant)
+                and slices2[0].data == -1
+            ):
+                out = x[-1]
+                copy_stack_trace([node.outputs[0], node.inputs[0]], out)
+                return [out]
+
             # Get the shapes of the vectors !
             try:
                 # try not to introduce new shape into the graph
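
This is the piece that lets the optimization fire: user code indexes the scan output as `ys[-1]`, but `ys` itself is already a `[1:]` slice of the internal buffer (the initial state is stripped), so the graph contains `scan(...)[1:][-1]`, which `local_subtensor_merge` now collapses to `scan(...)[-1]`. The identity, and the empty-buffer corner case the FIXME warns about, in plain NumPy (illustration only):

import numpy as np

x = np.arange(10)          # stands in for a scan output buffer
assert x[1:][-1] == x[-1]  # the merged form; holds whenever len(x) >= 2

y = np.arange(1)           # buffer of a 0-step scan: only the initial state
# y[1:][-1] raises IndexError while y[-1] returns the initial state,
# which is exactly the corner case the FIXME above points out.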

pytensor/tensor/subtensor.py

Lines changed: 1 addition & 0 deletions
@@ -82,6 +82,7 @@
 )


+# FIXME: This type hint is wrong, it returns tuples
 def indices_from_subtensor(
     op_indices: Iterable[ScalarConstant],
     idx_list: Optional[List[Union[Type, slice, Variable]]],
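
For context on the FIXME: `indices_from_subtensor` reassembles the concrete index tuple of a `Subtensor` from the node's symbolic index inputs and its `idx_list` template, so it returns a tuple rather than what the annotation claims. A hedged usage sketch (variable names are illustrative):

import pytensor.tensor as at
from pytensor.tensor.subtensor import indices_from_subtensor

x = at.vector("x")
node = x[1:].owner  # a Subtensor Apply node
# First input is the tensor being indexed; the rest are the index inputs.
idx = indices_from_subtensor(node.inputs[1:], node.op.idx_list)
assert isinstance(idx, tuple)  # a 1-tuple holding the rebuilt slice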

tests/scan/test_rewriting.py

Lines changed: 21 additions & 0 deletions
@@ -1397,6 +1397,27 @@ def f_pow2(x_tm1):
         rng = np.random.default_rng(utt.fetch_seed())
         my_f(rng.uniform(size=(3,)), 4, np.int64([2, 2, 3]))

+    def test_while_scan(self):
+        x = scalar("x")
+
+        ys, _, = pytensor.scan(
+            # lambda xtm1: xtm1 + 1,  # for loop
+            lambda xtm1: (xtm1 + 1, {}, until(xtm1 >= 100)),  # while loop
+            outputs_info=[x],
+            n_steps=100,
+            strict=True,
+        )
+        # Save memory is triggered by choosing only last value
+        y = ys[-1]
+
+        f = pytensor.function([x], y)
+        assert f(0) == 100
+
+        [scan_node] = (n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan))
+        _, scan_trace = scan_node.inputs
+        # This means scan is only saving the last 2 states
+        assert scan_trace.type.shape == (2,)
+

 def test_inner_replace_dot():
     """
