Commit f665170

Optimize while scans when only last state is needed
1 parent eaedaef commit f665170
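To make the intent of the commit concrete, here is a minimal, hypothetical sketch of the pattern it optimizes (the variable names and the step function are illustrative, not taken from this commit): a while Scan whose recurrent output is only consumed through its last element, so the full per-step trace never needs to be stored.

import pytensor
import pytensor.tensor as at
from pytensor.scan.utils import until

x0 = at.scalar("x0")
n_steps = at.scalar("n_steps", dtype="int64")

# A while Scan: the state counts up from x0 and the loop stops via `until`
# once the state reaches 99, or after n_steps iterations, whichever is first.
ys, _ = pytensor.scan(
    lambda xtm1: (xtm1 + 1, {}, until(xtm1 >= 99)),
    outputs_info=[x0],
    n_steps=n_steps,
)

# Only the last state is requested. With this commit the Scan buffer can be
# shrunk, because the trailing Subtensor only ever reads the final entry.
f = pytensor.function([x0, n_steps], ys[-1])
print(f(0.0, 200))  # 100.0, mirroring the new test added below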

File tree

3 files changed: +187 -26 lines changed

pytensor/scan/rewriting.py

Lines changed: 135 additions & 26 deletions
@@ -28,10 +28,18 @@
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import compute_test_value
 from pytensor.graph.replace import clone_replace
-from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
+from pytensor.graph.rewriting.basic import (
+    GraphRewriter,
+    copy_stack_trace,
+    in2out,
+    node_rewriter,
+)
 from pytensor.graph.rewriting.db import EquilibriumDB, SequenceDB
+from pytensor.graph.rewriting.utils import get_clients_at_depth
 from pytensor.graph.type import HasShape
 from pytensor.graph.utils import InconsistencyError
+from pytensor.raise_op import Assert
+from pytensor.scalar import ScalarConstant
 from pytensor.scan.op import Scan, ScanInfo
 from pytensor.scan.utils import (
     ScanArgs,
@@ -1115,6 +1123,61 @@ def sanitize(x):
         return at.as_tensor_variable(x)
 
 
+@node_rewriter([Scan])
+def merge_while_scan_subtensor_last_element(fgraph, scan_node):
+    """
+    Replace while_scan_out[1:][-1] by while_scan_out[-1], for recurrent outputs,
+    asserting that at least one step will happen. Only the first step can be ensured
+    by the inputs alone (i.e., `n_steps > 0` and non-empty sequences), as
+    the while scan could abort at any point after that. This means it is not
+    generally safe to replace while_scan[1:][-i] by while_scan[-i] for -i != -1.
+    """
+    if not scan_node.op.info.as_while:
+        return None
+
+    recurrent_outputs = scan_node.outputs[: scan_node.op.n_outs]
+
+    n_steps = scan_node.inputs[0]
+    sequences = scan_node.inputs[1 : 1 + scan_node.op.info.n_seqs]
+    non_zero_steps_cond = at.all([n_steps > 0] + [seq.shape[0] for seq in sequences])
+    assert_non_zero_steps_op = Assert("n_steps > 0 and all(len(sequences) > 0)")
+
+    subtensor_merge_replacements = {}
+
+    # Iterate over all nodes that are two computations below the while scan
+    for node in get_clients_at_depth(fgraph, scan_node, depth=2):
+        if not isinstance(node.op, Subtensor):
+            continue
+
+        u = node.inputs[0]
+        if not (u.owner and isinstance(u.owner.op, Subtensor)):
+            continue
+
+        x = u.owner.inputs[0]
+        if x not in recurrent_outputs:
+            continue
+
+        slice1 = get_idx_list(u.owner.inputs, u.owner.op.idx_list)
+        slice2 = get_idx_list(node.inputs, node.op.idx_list)
+
+        if (
+            len(slice1) == 1
+            and isinstance(slice1[0], slice)
+            and isinstance(slice1[0].start, aes.ScalarConstant)
+            and slice1[0].start.data == 1
+            and slice1[0].stop is None
+            and slice1[0].step is None
+            and len(slice2) == 1
+            and isinstance(slice2[0], aes.ScalarConstant)
+            and slice2[0].data == -1
+        ):
+            out = assert_non_zero_steps_op(x[-1], non_zero_steps_cond)
+            copy_stack_trace([node.outputs[0], node.inputs[0]], out)
+            subtensor_merge_replacements[node.outputs[0]] = out
+
+    return subtensor_merge_replacements
+
+
 @node_rewriter([Scan])
 def save_mem_new_scan(fgraph, node):
     r"""Graph optimizer that reduces scan memory consumption.
@@ -1136,6 +1199,17 @@ def save_mem_new_scan(fgraph, node):
     that SITSOT output. Only the most recently computed timestep ever needs to
     be kept in memory.
 
+    There are two ways in which the Scan buffer size is controlled:
+    1. Each recurrent output is saved in an input empty tensor x with the initial
+       state written at x[0]. The remaining x[1:] positions determine how many
+       intermediate results should be stored.
+       This rewrite shortens x[1:] to the smallest possible size.
+    2. Each non-recurrent output (nit-sot) is associated with a scalar integer
+       input that determines how many steps should be saved in the perform method.
+       This rewrite reduces this number to the smallest possible.
+
+    The scan perform implementation takes the output sizes into consideration,
+    saving the newest results over the oldest ones whenever the buffer is filled.
     """
     if not isinstance(node.op, Scan):
         return False
@@ -1184,13 +1258,16 @@ def save_mem_new_scan(fgraph, node):
     # index(step) for any output scan actually needs to compute
     # In other words n_steps should be equal to this maximal !
     # Note: if we have a shared variable that gets updated at every step
-    # of the loop, reducing the number of steps will affect the the
-    # value of the shared variable after the loop so we need not to
+    # of the loop, reducing the number of steps will affect the
+    # value of the shared variable after the loop so we cannot
     # change the number of steps in that case. To do this we set
     # global_nsteps to None which is seen as a flag that nothing needs
-    # to be done
+    # to be done.
+    # Note: For simplicity while Scans also have global_nsteps set to None.
+    # All step optimizations require knowing the shape of the output, which
+    # cannot be determined from the inputs alone.
     assert len(node.outputs) >= c_outs
-    if len(node.outputs) == c_outs:
+    if len(node.outputs) == c_outs and not op.info.as_while:
         global_nsteps = {"real": -1, "sym": []}
     else:
         global_nsteps = None
@@ -1298,9 +1375,9 @@ def save_mem_new_scan(fgraph, node):
 
     # 2.3. Analyze global_nsteps to figure out for how many steps scan
     # needs to iterate
-    if global_nsteps is not None:
+    if global_nsteps is None:
         nw_steps = node.inputs[0]
-
+    else:
         # there are some symbolic tensors that limit the number of
         # steps
         if len(global_nsteps["sym"]) == 0:
@@ -1316,16 +1393,14 @@ def save_mem_new_scan(fgraph, node):
             real_steps = None
         nw_steps = select_min(select_max(sym_steps, real_steps), node.inputs[0])
 
+        # FIXME: This is not correct. Scan with 0 steps seems to be supported
         # Make sure the ScanSaveMem optimization never makes the new
         # number of steps to be 0 (this could happen, for instance, if
         # the optimization detects that the outputs of the Scan go through
         # subtensor nodes that end up taking no elements) because Scan with
         # 0 iterations are not supported. Make sure the new number of steps
         # is at least 1.
         nw_steps = select_max(nw_steps, 1)
-    else:
-        nw_steps = node.inputs[0]
-        global_nsteps = None
 
     # 2.4 Loop over the clients again now looking just to see how many
     # intermediate steps to store
@@ -1348,19 +1423,33 @@ def save_mem_new_scan(fgraph, node):
                     store_steps[i] = 0
                     break
 
-                    if i > op_info.n_mit_mot:
-                        length = node.inputs[0] + init_l[i]
+                    # Special case for recurrent outputs where only the last result
+                    # is requested. This is needed for this rewrite to apply to
+                    # While Scans at all. Otherwise, `get_canonical_form_slice` in
+                    # the `else` branch would reintroduce a shape dependency on the
+                    # original While Scan which would lead this rewrite to abort.
+                    if (
+                        i <= op.info.n_mit_mot
+                        and isinstance(this_slice[0], ScalarConstant)
+                        and this_slice[0].value == -1
+                    ):
+                        start = nw_steps
                     else:
-                        try:
-                            length = shape_of[out][0]
-                        except KeyError:
-                            length = out.shape[0]
-                    cf_slice = get_canonical_form_slice(this_slice[0], length)
+                        if i <= op.info.n_mit_mot:
+                            try:
+                                length = shape_of[out][0]
+                            except KeyError:
+                                length = out.shape[0]
+                        else:
+                            length = node.inputs[0] + init_l[i]
+
+                        cf_slice = get_canonical_form_slice(this_slice[0], length)
+
+                        if isinstance(cf_slice[0], slice):
+                            start = at.extract_constant(cf_slice[0].start)
+                        else:
+                            start = at.extract_constant(cf_slice[0])
 
-                    if isinstance(cf_slice[0], slice):
-                        start = at.extract_constant(cf_slice[0].start)
-                    else:
-                        start = at.extract_constant(cf_slice[0])
                     if start == 0 or store_steps[i] == 0:
                         store_steps[i] = 0
                     else:
@@ -1514,6 +1603,7 @@ def save_mem_new_scan(fgraph, node):
                     nw_input = expand_empty(_nw_input, nw_steps)
                     nw_inputs[in_idx] = nw_input
                 else:
+                    # FIXME: This is never used
                     nw_input = nw_inputs[in_idx][: (initl + nw_steps)]
 
             elif (
@@ -1569,9 +1659,16 @@ def save_mem_new_scan(fgraph, node):
                         sanitize(cnf_slice[0].step),
                     )
                 else:
-                    fslice = sanitize(cnf_slice[0])
+                    if (
+                        isinstance(old_slices[0], ScalarConstant)
+                        and this_slice[0].value == -1
+                    ):
+                        fslice = old_slices[0]
+                    else:
+                        fslice = sanitize(cnf_slice[0])
+
+                nw_slice = (fslice,) + tuple(old_slices[1:])
 
-                nw_slice = (fslice,) + tuple(old_slices[1:])
                 nw_pos = inv_compress_map[idx]
 
                 subtens = Subtensor(nw_slice)
@@ -1620,9 +1717,15 @@ def save_mem_new_scan(fgraph, node):
                     ) + tuple(old_slices[1:])
 
                 else:
-                    position = (
-                        cnf_slice[0] - nw_steps - init_l[pos] + store_steps[pos]
-                    )
+                    if (
+                        isinstance(old_slices[0], ScalarConstant)
+                        and this_slice[0].value == -1
+                    ):
+                        position = old_slices[0]
+                    else:
+                        position = (
+                            cnf_slice[0] - nw_steps - init_l[pos] + store_steps[pos]
+                        )
 
                 nw_slice = (sanitize(position),) + tuple(old_slices[1:])
                 subtens = Subtensor(nw_slice)
@@ -2424,6 +2527,12 @@ def push_out_dot1_scan(fgraph, node):
     position=5,
 )
 
+scan_eqopt2.register(
+    "merge_while_scan_subtensor_last_element",
+    in2out(merge_while_scan_subtensor_last_element, ignore_newtrees=True),
+    "fast_run",
+    "scan",
+)
 
 scan_eqopt2.register(
     "constant_folding_for_scan2",

pytensor/tensor/rewriting/subtensor.py

Lines changed: 11 additions & 0 deletions
@@ -479,6 +479,7 @@ def local_subtensor_merge(fgraph, node):
     expresses all slices in a canonical form, and then merges them together.
 
     """
+    from pytensor.scan.op import Scan
 
     if isinstance(node.op, Subtensor):
         u = node.inputs[0]
@@ -489,6 +490,16 @@ def local_subtensor_merge(fgraph, node):
             # slices of the first applied subtensor
             slices1 = get_idx_list(u.owner.inputs, u.owner.op.idx_list)
             slices2 = get_idx_list(node.inputs, node.op.idx_list)
+
+            # Don't try to do the optimization on While scan nodes,
+            # as it will create a dependency on the shape of the outputs
+            if (
+                x.owner is not None
+                and isinstance(x.owner.op, Scan)
+                and x.owner.op.info.as_while
+            ):
+                return None
+
             # Get the shapes of the vectors !
             try:
                 # try not to introduce new shape into the graph
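The guard added to `local_subtensor_merge` can be motivated with a small sketch (illustrative only, not code from the commit): merging `x[1:][-1]` into a single index goes through `get_canonical_form_slice`, which expresses the index in terms of the length of `x`; for a while Scan that length is only known once the loop has finished, so the merged graph would keep a dependency on the full output shape and defeat the memory-saving rewrite above.

import pytensor.tensor as at

x = at.vector("x")  # stands in for a while-Scan output trace (hypothetical)

# Roughly what a naive merge of x[1:][-1] amounts to after canonicalization:
# the index is written through x.shape[0], i.e. through the output's length.
merged = x[x.shape[0] - 1]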

tests/scan/test_rewriting.py

Lines changed: 41 additions & 0 deletions
@@ -1397,6 +1397,47 @@ def f_pow2(x_tm1):
         rng = np.random.default_rng(utt.fetch_seed())
         my_f(rng.uniform(size=(3,)), 4, np.int64([2, 2, 3]))
 
+    def test_while_scan(self):
+        x0 = scalar("x0")
+        seq = vector("seq")
+        n_steps = scalar("n_steps", dtype="int64")
+
+        # while loop
+        (ys, zs), _, = pytensor.scan(
+            lambda s, xtm1: ((xtm1 + 1, xtm1 + 1 + s), {}, until(xtm1 >= 99)),
+            sequences=[seq],
+            outputs_info=[x0, None],
+            n_steps=n_steps,
+            strict=True,
+        )
+        # Save memory is triggered by choosing only the last value
+        y = ys[-1]
+        z = zs[-1]
+
+        f = pytensor.function([x0, seq, n_steps], [y, z])
+
+        [scan_node] = (n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan))
+        print(scan_node.inputs)
+        _, _, ys_trace, len_zs = scan_node.inputs
+
+        # Evaluate the shape of ys_trace and len_zs to confirm the rewrite worked correctly.
+        debug_fn = pytensor.function(
+            [n_steps], [ys_trace.shape[0], len_zs], accept_inplace=True
+        )
+        stored_ys_steps, stored_zs_steps = debug_fn(n_steps=200)
+        assert stored_ys_steps == 2
+        assert stored_zs_steps == 1
+
+        test_seq = np.zeros(200)
+        np.testing.assert_allclose(f(x0=0, seq=test_seq, n_steps=200), 100)
+        np.testing.assert_allclose(f(x0=1, seq=test_seq, n_steps=20), 21)
+        np.testing.assert_allclose(f(x0=np.e, seq=test_seq, n_steps=1), np.e + 1)
+        with pytest.raises(AssertionError, match="n_steps > 0 and all"):
+            f(x0=0, seq=test_seq, n_steps=0)
+        # This fails too early inside Scan due to https://github.com/pymc-devs/pytensor/issues/215
+        # with pytest.raises(AssertionError, match="n_steps > 0 and all"):
+        #     f(x0=0, seq=[], n_steps=200)
+
 
 def test_inner_replace_dot():
     """
