Commit 1f46fa7

Optimize while scans when only last state is needed
1 parent eaedaef commit 1f46fa7
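
This commit teaches the scan save-memory machinery to handle do-while Scans
(those built with an `until` condition) when only the last state is requested.
A minimal sketch of the pattern being optimized, adapted from the tests added
in this commit (the Fibonacci example and names come from the test suite):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt
    from pytensor.scan.utils import until

    n_steps = pt.scalar("n_steps", dtype="int64")
    x0 = pt.vector("x0")

    # Do-while scan: iterate the Fibonacci recurrence until the last value
    # reaches 34, or until n_steps iterations elapse, whichever comes first
    ys, _ = pytensor.scan(
        lambda xtm2, xtm1: (xtm1 + xtm2, {}, until(xtm1 >= 34)),
        outputs_info=[{"initial": x0, "taps": [-2, -1]}],
        n_steps=n_steps,
    )

    # Only the last state is used: after this commit the rewrites can shrink
    # the internal buffer to the two taps plus one working entry
    y = ys[-1]
    f = pytensor.function([n_steps, x0], y)
    print(f(1000, np.array([1.0, 1.0])))  # 55.0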

File tree

3 files changed: +242 -26 lines changed

pytensor/scan/rewriting.py
pytensor/tensor/rewriting/subtensor.py
tests/scan/test_rewriting.py

pytensor/scan/rewriting.py

Lines changed: 139 additions & 26 deletions
@@ -28,10 +28,18 @@
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import compute_test_value
 from pytensor.graph.replace import clone_replace
-from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
+from pytensor.graph.rewriting.basic import (
+    GraphRewriter,
+    copy_stack_trace,
+    in2out,
+    node_rewriter,
+)
 from pytensor.graph.rewriting.db import EquilibriumDB, SequenceDB
+from pytensor.graph.rewriting.utils import get_clients_at_depth
 from pytensor.graph.type import HasShape
 from pytensor.graph.utils import InconsistencyError
+from pytensor.raise_op import Assert
+from pytensor.scalar import ScalarConstant
 from pytensor.scan.op import Scan, ScanInfo
 from pytensor.scan.utils import (
     ScanArgs,
@@ -1115,6 +1123,71 @@ def sanitize(x):
         return at.as_tensor_variable(x)


+@node_rewriter([Scan])
+def while_scan_merge_subtensor_last_element(fgraph, scan_node):
+    """
+    Replace while_scan_out[abs(min(tap)):][-1] by while_scan_out[-1], for
+    recurring outputs, asserting that at least one step occurs.
+    Only the first step can be ensured by the inputs alone (i.e., `n_steps > 0`),
+    as the while scan could abort at any point after that. This means it is
+    not possible to replace while_scan_out[abs(min(tap)):][-i]
+    by while_scan_out[-i], for -i != -1.
+    """
+    op = scan_node.op
+
+    if not op.info.as_while:
+        return None
+
+    # Optimization is not implemented for mit-mot
+    recurrent_outputs = op.outer_mitsot_outs(scan_node.outputs) + op.outer_sitsot_outs(
+        scan_node.outputs
+    )
+    recurrent_outputs_taps_slices = (
+        op.info.mit_sot_in_slices + op.info.sit_sot_in_slices
+    )
+
+    n_steps = scan_node.inputs[0]
+    non_zero_steps_cond = n_steps > 0
+    assert_non_zero_steps_op = Assert("n_steps > 0")
+
+    subtensor_merge_replacements = {}
+
+    # Iterate over all nodes that are two computations below the while scan
+    for node in get_clients_at_depth(fgraph, scan_node, depth=2):
+        if not isinstance(node.op, Subtensor):
+            continue
+
+        u = node.inputs[0]
+        if not (u.owner and isinstance(u.owner.op, Subtensor)):
+            continue
+
+        x = u.owner.inputs[0]
+        if x not in recurrent_outputs:
+            continue
+
+        slice1 = get_idx_list(u.owner.inputs, u.owner.op.idx_list)
+        slice2 = get_idx_list(node.inputs, node.op.idx_list)
+
+        min_tap = abs(min(recurrent_outputs_taps_slices[recurrent_outputs.index(x)]))
+
+        if (
+            len(slice1) == 1
+            and isinstance(slice1[0], slice)
+            and isinstance(slice1[0].start, aes.ScalarConstant)
+            and slice1[0].start.data == min_tap
+            and slice1[0].stop is None
+            and slice1[0].step is None
+            and len(slice2) == 1
+            and isinstance(slice2[0], aes.ScalarConstant)
+            and slice2[0].data == -1
+        ):
+            out = assert_non_zero_steps_op(x[-1], non_zero_steps_cond)
+            copy_stack_trace([node.outputs[0], node.inputs[0]], out)
+            subtensor_merge_replacements[node.outputs[0]] = out
+
+    return subtensor_merge_replacements
+
+
 @node_rewriter([Scan])
 def save_mem_new_scan(fgraph, node):
     r"""Graph optimizer that reduces scan memory consumption.
@@ -1136,6 +1209,17 @@ def save_mem_new_scan(fgraph, node):
     that SITSOT output. Only the most recently computed timestep ever needs to
     be kept in memory.

+    There are two ways in which the Scan buffer size is controlled:
+    1. Each recurring output is saved in an input empty tensor x with the initial
+       state written at x[0]. The remaining x[1:] positions determine how many
+       intermediate results should be stored.
+       This rewrite shortens x[1:] to the smallest possible size.
+    2. Each non-recurrent output (nit-sot) is associated with a scalar integer
+       input that determines how many steps should be saved in the perform method.
+       This rewrite reduces this number to the smallest possible.
+
+    The scan perform implementation takes the output sizes into consideration,
+    saving the newest results over the oldest ones whenever the buffer is filled.
     """
     if not isinstance(node.op, Scan):
         return False
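
The two buffer-control mechanisms described in the docstring above can be
observed on a compiled graph: each recurring output gets an empty tensor input
whose leading dimension is the buffer size, and each nit-sot gets a scalar
length input. A hedged sketch of how to inspect them (this mirrors what the
new tests below do; printing the raw inputs avoids assuming their exact order):

    import pytensor
    import pytensor.tensor as pt
    from pytensor.scan.op import Scan

    xs = pt.vector("xs")
    # A simple map: ys is a nit-sot (non-recurrent) output
    ys, _ = pytensor.scan(lambda x: x + 1, sequences=[xs])
    f = pytensor.function([xs], ys[-1])  # only the last value is used

    [scan_node] = [n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan)]
    # One of the scan inputs is the scalar that says how many steps to store;
    # after the save-memory rewrite it should be 1 for this graph
    print(scan_node.inputs)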
@@ -1184,13 +1268,16 @@ def save_mem_new_scan(fgraph, node):
    # index(step) for any output scan actually needs to compute
    # In other words n_steps should be equal to this maximal !
    # Note: if we have a shared variable that gets updated at every step
-    # of the loop, reducing the number of steps will affect the the
-    # value of the shared variable after the loop so we need not to
+    # of the loop, reducing the number of steps will affect the
+    # value of the shared variable after the loop so we cannot
    # change the number of steps in that case. To do this we set
    # global_nsteps to None which is seen as a flag that nothing needs
-    # to be done
+    # to be done.
+    # Note: For simplicity while Scans also have global_nsteps set to None.
+    # All step optimizations require knowing the shape of the output, which
+    # cannot be determined from the inputs alone.
    assert len(node.outputs) >= c_outs
-    if len(node.outputs) == c_outs:
+    if len(node.outputs) == c_outs and not op.info.as_while:
        global_nsteps = {"real": -1, "sym": []}
    else:
        global_nsteps = None
@@ -1270,7 +1357,7 @@ def save_mem_new_scan(fgraph, node):
            else:
                # there is a **gotcha** here ! Namely, scan returns an
                # array that contains the initial state of the output
-                # as well. Which means that if have a initial state of
+                # as well. Which means that if y has an initial state of
                # length 3, and you look for 5 steps you get an output
                # y of length 8. If you only use y[:5], this does not
                # mean that you only need to loop for 5 steps but
@@ -1298,9 +1385,9 @@ def save_mem_new_scan(fgraph, node):

    # 2.3. Analyze global_nsteps to figure out for how many steps scan
    # needs to iterate
-    if global_nsteps is not None:
+    if global_nsteps is None:
        nw_steps = node.inputs[0]
-
+    else:
        # there are some symbolic tensors that limit the number of
        # steps
        if len(global_nsteps["sym"]) == 0:
@@ -1316,16 +1403,14 @@ def save_mem_new_scan(fgraph, node):
            real_steps = None
        nw_steps = select_min(select_max(sym_steps, real_steps), node.inputs[0])

+        # FIXME: This is not correct. Scan with 0 steps seems to be supported
        # Make sure the ScanSaveMem optimization never makes the new
        # number of steps to be 0 (this could happen, for instance, if
        # the optimization detects that the outputs of the Scan go through
        # subtensor nodes that end up taking no elements) because Scan with
        # 0 iterations are not supported. Make sure the new number of steps
        # is at least 1.
        nw_steps = select_max(nw_steps, 1)
-    else:
-        nw_steps = node.inputs[0]
-        global_nsteps = None

    # 2.4 Loop over the clients again now looking just to see how many
    # intermediate steps to store
@@ -1348,19 +1433,33 @@ def save_mem_new_scan(fgraph, node):
                    store_steps[i] = 0
                    break

-            if i > op_info.n_mit_mot:
-                length = node.inputs[0] + init_l[i]
+            # Special case for recurrent outputs where only the last result
+            # is requested. This is needed for this rewrite to apply to
+            # do-while Scans at all. Otherwise, `get_canonical_form_slice` in
+            # the `else` branch would reintroduce a shape dependency on the
+            # original Scan which would lead this rewrite to abort in the end.
+            if (
+                i <= op.info.n_mit_mot
+                and isinstance(this_slice[0], ScalarConstant)
+                and this_slice[0].value == -1
+            ):
+                start = nw_steps - 1
            else:
-                try:
-                    length = shape_of[out][0]
-                except KeyError:
-                    length = out.shape[0]
-            cf_slice = get_canonical_form_slice(this_slice[0], length)
+                if i <= op.info.n_mit_mot:
+                    try:
+                        length = shape_of[out][0]
+                    except KeyError:
+                        length = out.shape[0]
+                else:
+                    length = node.inputs[0] + init_l[i]
+
+                cf_slice = get_canonical_form_slice(this_slice[0], length)
+
+                if isinstance(cf_slice[0], slice):
+                    start = at.extract_constant(cf_slice[0].start)
+                else:
+                    start = at.extract_constant(cf_slice[0])

-            if isinstance(cf_slice[0], slice):
-                start = at.extract_constant(cf_slice[0].start)
-            else:
-                start = at.extract_constant(cf_slice[0])
            if start == 0 or store_steps[i] == 0:
                store_steps[i] = 0
            else:
@@ -1514,6 +1613,7 @@ def save_mem_new_scan(fgraph, node):
                nw_input = expand_empty(_nw_input, nw_steps)
                nw_inputs[in_idx] = nw_input
            else:
+                # FIXME: This is never used
                nw_input = nw_inputs[in_idx][: (initl + nw_steps)]

        elif (
@@ -1570,8 +1670,8 @@ def save_mem_new_scan(fgraph, node):
                    )
                else:
                    fslice = sanitize(cnf_slice[0])
-
                nw_slice = (fslice,) + tuple(old_slices[1:])
+
                nw_pos = inv_compress_map[idx]

                subtens = Subtensor(nw_slice)
@@ -1620,9 +1720,16 @@ def save_mem_new_scan(fgraph, node):
                    ) + tuple(old_slices[1:])

                else:
-                    position = (
-                        cnf_slice[0] - nw_steps - init_l[pos] + store_steps[pos]
-                    )
+                    # Special case when only last value is requested
+                    if (
+                        isinstance(old_slices[0], ScalarConstant)
+                        and old_slices[0].value == -1
+                    ):
+                        position = old_slices[0]
+                    else:
+                        position = (
+                            cnf_slice[0] - nw_steps - init_l[pos] + store_steps[pos]
+                        )

                nw_slice = (sanitize(position),) + tuple(old_slices[1:])
                subtens = Subtensor(nw_slice)
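
The `position` arithmetic in the `else` branch above maps an index into the
full-length output onto the shortened buffer, which keeps only the newest
`store_steps[pos]` entries. A toy check of the offset logic (plain Python;
the concrete numbers are made up):

    # A sit-sot with one initial state (init_l = 1), nw_steps = 10 steps, and
    # a buffer keeping only store_steps = 2 entries. The full output would
    # have length init_l + nw_steps = 11, with the last two values stored at
    # buffer positions 0 and 1.
    init_l, nw_steps, store_steps = 1, 10, 2
    for idx, expected in [(9, 0), (10, 1)]:
        position = idx - nw_steps - init_l + store_steps  # idx - 11 + 2
        assert position == expected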
@@ -2424,6 +2531,12 @@ def push_out_dot1_scan(fgraph, node):
    position=5,
 )

+scan_eqopt2.register(
+    "while_scan_merge_subtensor_last_element",
+    in2out(while_scan_merge_subtensor_last_element, ignore_newtrees=True),
+    "fast_run",
+    "scan",
+)

 scan_eqopt2.register(
     "constant_folding_for_scan2",

pytensor/tensor/rewriting/subtensor.py

Lines changed: 11 additions & 0 deletions
@@ -479,6 +479,7 @@ def local_subtensor_merge(fgraph, node):
    expresses all slices in a canonical form, and then merges them together.

    """
+    from pytensor.scan.op import Scan

    if isinstance(node.op, Subtensor):
        u = node.inputs[0]
@@ -489,6 +490,16 @@ def local_subtensor_merge(fgraph, node):
            # slices of the first applied subtensor
            slices1 = get_idx_list(u.owner.inputs, u.owner.op.idx_list)
            slices2 = get_idx_list(node.inputs, node.op.idx_list)
+
+            # Don't try to do the optimization on do-while scan outputs,
+            # as it will create a dependency on the shape of the outputs
+            if (
+                x.owner is not None
+                and isinstance(x.owner.op, Scan)
+                and x.owner.op.info.as_while
+            ):
+                return None
+
            # Get the shapes of the vectors !
            try:
                # try not to introduce new shape into the graph
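
The guard added above is needed because merging chained subtensors generally
requires the length of the intermediate result: a negative index can only be
folded into the outer slice if `shape[0]` is known, and for a do-while scan
that shape is only determined at runtime. A small illustration with a
hypothetical helper (not the library routine):

    import numpy as np

    def merge_tail_index(length: int, start: int, idx: int) -> int:
        # Fold x[start:][idx] into a single index on x, given len(x).
        # A negative idx is resolved relative to len(x[start:]) == length - start,
        # so the merged index depends on the shape of x.
        if idx < 0:
            return length + idx  # start + (length - start) + idx
        return start + idx

    x = np.arange(10)
    assert x[2:][-1] == x[merge_tail_index(len(x), 2, -1)]  # x[9]
    assert x[2:][3] == x[merge_tail_index(len(x), 2, 3)]    # x[5]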

tests/scan/test_rewriting.py

Lines changed: 92 additions & 0 deletions
@@ -1397,6 +1397,98 @@ def f_pow2(x_tm1):
        rng = np.random.default_rng(utt.fetch_seed())
        my_f(rng.uniform(size=(3,)), 4, np.int64([2, 2, 3]))

+    def test_while_scan_taps(self):
+        n_steps = scalar("n_steps", dtype="int64")
+        x0 = vector("x0")
+
+        ys, _ = pytensor.scan(
+            # Fibonacci Sequence
+            lambda xtm2, xtm1: (xtm1 + xtm2, {}, until(xtm1 >= 34)),
+            outputs_info=[{"initial": x0, "taps": [-2, -1]}],
+            n_steps=n_steps,
+        )
+        # Save memory is triggered by choosing only last value
+        y = ys[-1]
+
+        f = pytensor.function(
+            [n_steps, x0], y, mode=get_default_mode().including("scan")
+        )
+
+        np.testing.assert_equal(f(n_steps=1000, x0=[1, 1]), 55)
+        np.testing.assert_equal(f(n_steps=1, x0=[1, 1]), 2)
+        with pytest.raises(AssertionError, match="n_steps > 0"):
+            f(n_steps=0, x0=[1, 1])
+
+        # ys_trace is an Alloc that controls the size of the inner buffer,
+        # it should have shape[0] == 3, with two entries for the taps and one
+        # entry for the intermediate output
+        [scan_node] = (n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan))
+        _, ys_trace = scan_node.inputs
+        debug_fn = pytensor.function(
+            [n_steps, x0], ys_trace.shape[0], accept_inplace=True
+        )
+        assert debug_fn(n_steps=1000, x0=[1, 1]) == 3
+
+    def test_while_scan_map(self):
+        xs = vector("xs")
+        ys, _ = pytensor.scan(
+            lambda x: (x + 1, {}, until(x + 1 >= 10)),
+            outputs_info=[None],
+            sequences=[xs],
+        )
+        # Save memory is triggered by choosing only last value
+        y = ys[-1]
+
+        f = pytensor.function([xs], y, mode=get_default_mode().including("scan"))
+        np.testing.assert_equal(f(xs=np.arange(100)), 10)
+        np.testing.assert_equal(f(xs=[0]), 1)
+        with pytest.raises(IndexError):
+            f(xs=[])
+
+        # len_ys is a numerical input that controls the shape of the inner buffer
+        # It should be 1, as only the last output is needed
+        [scan_node] = (n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan))
+        _, _, len_ys = scan_node.inputs
+        debug_fn = pytensor.function([xs], len_ys, accept_inplace=True)
+        assert debug_fn(xs=np.arange(100)) == 1
+
+    def test_while_scan_taps_and_map(self):
+        x0 = scalar("x0")
+        seq = vector("seq")
+        n_steps = scalar("n_steps", dtype="int64")
+
+        # while loop
+        (ys, zs), _ = pytensor.scan(
+            lambda s, xtm1: ((xtm1 + 1, xtm1 + 1 + s), {}, until(xtm1 >= 99)),
+            sequences=[seq],
+            outputs_info=[x0, None],
+            n_steps=n_steps,
+        )
+        # Save memory is triggered by choosing only last value
+        y = ys[-1]
+        z = zs[-1]
+
+        f = pytensor.function(
+            [x0, seq, n_steps], [y, z], mode=get_default_mode().including("scan")
+        )
+        test_seq = np.zeros(200, dtype=config.floatX)
+        np.testing.assert_allclose(f(x0=0, seq=test_seq, n_steps=200), 100)
+        np.testing.assert_allclose(f(x0=1, seq=test_seq, n_steps=20), 21)
+        np.testing.assert_allclose(f(x0=np.e, seq=test_seq, n_steps=1), np.e + 1)
+        with pytest.raises(AssertionError, match="n_steps > 0"):
+            f(x0=0, seq=test_seq, n_steps=0)
+
+        # Evaluate the shape of ys_trace and len_zs to confirm the rewrite worked correctly.
+        # If a MissingInputError is raised, it means the rewrite failed
+        [scan_node] = (n for n in f.maker.fgraph.apply_nodes if isinstance(n.op, Scan))
+        _, _, ys_trace, len_zs = scan_node.inputs
+        debug_fn = pytensor.function(
+            [n_steps], [ys_trace.shape[0], len_zs], accept_inplace=True
+        )
+        stored_ys_steps, stored_zs_steps = debug_fn(n_steps=200)
+        assert stored_ys_steps == 2
+        assert stored_zs_steps == 1
+

 def test_inner_replace_dot():
     """
