-
Notifications
You must be signed in to change notification settings - Fork 134
Optimize while scans when only last state is needed #216
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,10 +28,18 @@ | |
from pytensor.graph.fg import FunctionGraph | ||
from pytensor.graph.op import compute_test_value | ||
from pytensor.graph.replace import clone_replace | ||
from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter | ||
from pytensor.graph.rewriting.basic import ( | ||
GraphRewriter, | ||
copy_stack_trace, | ||
in2out, | ||
node_rewriter, | ||
) | ||
from pytensor.graph.rewriting.db import EquilibriumDB, SequenceDB | ||
from pytensor.graph.rewriting.utils import get_clients_at_depth | ||
from pytensor.graph.type import HasShape | ||
from pytensor.graph.utils import InconsistencyError | ||
from pytensor.raise_op import Assert | ||
from pytensor.scalar import ScalarConstant | ||
from pytensor.scan.op import Scan, ScanInfo | ||
from pytensor.scan.utils import ( | ||
ScanArgs, | ||
|
@@ -1103,6 +1111,71 @@ def sanitize(x): | |
return at.as_tensor_variable(x) | ||
|
||
|
||
@node_rewriter([Scan]) | ||
def while_scan_merge_subtensor_last_element(fgraph, scan_node): | ||
Armavica marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
Replace while_scan_out[abs(min(tap)):][-1] by while_scan_out[-1], for | ||
recurring outputs, asserting that at least one step occurs. | ||
Only the first step can be ensured by the inputs alone (i.e., `n_steps > 0`), | ||
as the while scan could abort earlier anytime after that. This means it is | ||
not possible to replace while_scan_out[abs(min(tap)):][-i] | ||
by while_scan_out[-i], for -i != -1. | ||
""" | ||
op = scan_node.op | ||
|
||
if not op.info.as_while: | ||
return None | ||
|
||
# Optimization is not implemented for mit-mot | ||
recurrent_outputs = op.outer_mitsot_outs(scan_node.outputs) + op.outer_sitsot_outs( | ||
scan_node.outputs | ||
) | ||
Armavica marked this conversation as resolved.
Show resolved
Hide resolved
|
||
recurrent_outputs_taps_slices = ( | ||
op.info.mit_sot_in_slices + op.info.sit_sot_in_slices | ||
) | ||
|
||
n_steps = scan_node.inputs[0] | ||
non_zero_steps_cond = n_steps > 0 | ||
assert_non_zero_steps_op = Assert("n_steps > 0") | ||
|
||
subtensor_merge_replacements = {} | ||
|
||
# Iterate over all nodes that are two computations below the while scan | ||
for node2 in get_clients_at_depth(fgraph, scan_node, depth=2): | ||
if not isinstance(node2.op, Subtensor): | ||
continue | ||
|
||
node1 = node2.inputs[0].owner | ||
if not (node1 and isinstance(node1.op, Subtensor)): | ||
continue | ||
|
||
x = node1.inputs[0] | ||
if x not in recurrent_outputs: | ||
continue | ||
|
||
slice1 = get_idx_list(node1.inputs, node1.op.idx_list) | ||
slice2 = get_idx_list(node2.inputs, node2.op.idx_list) | ||
|
||
min_tap = abs(min(recurrent_outputs_taps_slices[recurrent_outputs.index(x)])) | ||
Armavica marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if ( | ||
len(slice1) == 1 | ||
and isinstance(slice1[0], slice) | ||
and isinstance(slice1[0].start, aes.ScalarConstant) | ||
and slice1[0].start.data == min_tap | ||
and slice1[0].stop is None | ||
and slice1[0].step is None | ||
and len(slice2) == 1 | ||
and isinstance(slice2[0], aes.ScalarConstant) | ||
and slice2[0].data == -1 | ||
): | ||
out = assert_non_zero_steps_op(x[-1], non_zero_steps_cond) | ||
copy_stack_trace([node2.outputs[0], node2.inputs[0]], out) | ||
subtensor_merge_replacements[node2.outputs[0]] = out | ||
|
||
return subtensor_merge_replacements | ||
|
||
|
||
@node_rewriter([Scan]) | ||
def save_mem_new_scan(fgraph, node): | ||
r"""Graph optimizer that reduces scan memory consumption. | ||
|
@@ -1124,6 +1197,17 @@ def save_mem_new_scan(fgraph, node): | |
that SITSOT output. Only the most recently computed timestep ever needs to | ||
be kept in memory. | ||
|
||
There are two ways in which the Scan buffer size is controlled: | ||
1. Each recurring output is saved in an input empty tensor x with the initial | ||
state written at x[:abs(min(taps))]. The remaining x[abs(min(taps)):] | ||
positions determine how many intermediate results should be stored. | ||
This rewrite shortens x[abs(min(taps)):] to the smallest possible size. | ||
2. Each non-recurrent output (nit-sot) is associated with a scalar integer | ||
input that determines how many steps should be saved in the perform method. | ||
This rewrite reduces this number to the smallest possible. | ||
|
||
The scan perform implementation takes the output sizes into consideration, | ||
saving the newest results over the oldest ones whenever the buffer is filled. | ||
""" | ||
if not isinstance(node.op, Scan): | ||
return False | ||
|
@@ -1172,13 +1256,16 @@ def save_mem_new_scan(fgraph, node): | |
# index(step) for any output scan actually needs to compute | ||
# In other words n_steps should be equal to this maximal ! | ||
# Note: if we have a shared variable that gets updated at every step | ||
# of the loop, reducing the number of steps will affect the the | ||
# value of the shared variable after the loop so we need not to | ||
# of the loop, reducing the number of steps will affect the | ||
# value of the shared variable after the loop so we cannot | ||
# change the number of steps in that case. To do this we set | ||
# global_nsteps to None which is seen as a flag that nothing needs | ||
# to be done | ||
# to be done. | ||
# Note: For simplicity while Scans also have global_nsteps set to None. | ||
# All step optimizations require knowing the shape of the output, which | ||
# cannot be determined from the inputs alone. | ||
assert len(node.outputs) >= c_outs | ||
if len(node.outputs) == c_outs: | ||
if len(node.outputs) == c_outs and not op.info.as_while: | ||
global_nsteps = {"real": -1, "sym": []} | ||
else: | ||
global_nsteps = None | ||
|
@@ -1257,7 +1344,7 @@ def save_mem_new_scan(fgraph, node): | |
else: | ||
# there is a **gotcha** here ! Namely, scan returns an | ||
# array that contains the initial state of the output | ||
# as well. Which means that if have a initial state of | ||
# as well. Which means that if y has an initial state of | ||
# length 3, and you look for 5 steps you get an output | ||
# y of length 8. If you only use y[:5], this does not | ||
# mean that you only need to loop for 5 steps but | ||
|
@@ -1285,9 +1372,9 @@ def save_mem_new_scan(fgraph, node): | |
|
||
# 2.3. Analyze global_nsteps to figure out for how many steps scan | ||
# needs to iterate | ||
if global_nsteps is not None: | ||
if global_nsteps is None: | ||
nw_steps = node.inputs[0] | ||
|
||
else: | ||
# there are some symbolic tensors that limit the number of | ||
# steps | ||
if len(global_nsteps["sym"]) == 0: | ||
|
@@ -1303,16 +1390,14 @@ def save_mem_new_scan(fgraph, node): | |
real_steps = None | ||
nw_steps = select_min(select_max(sym_steps, real_steps), node.inputs[0]) | ||
|
||
# FIXME: This is not correct. Scan with 0 steps seems to be supported | ||
# Make sure the ScanSaveMem optimization never makes the new | ||
# number of steps to be 0 (this could happen, for instance, if | ||
# the optimization detects that the outputs of the Scan go through | ||
# subtensor nodes that end up taking no elements) because Scan with | ||
# 0 iterations are not supported. Make sure the new number of steps | ||
# is at least 1. | ||
nw_steps = select_max(nw_steps, 1) | ||
else: | ||
nw_steps = node.inputs[0] | ||
global_nsteps = None | ||
|
||
# 2.4 Loop over the clients again now looking just to see how many | ||
# intermediate steps to store | ||
|
@@ -1335,19 +1420,33 @@ def save_mem_new_scan(fgraph, node): | |
store_steps[i] = 0 | ||
break | ||
|
||
if i > op_info.n_mit_mot: | ||
length = node.inputs[0] + init_l[i] | ||
# Special case for recurrent outputs where only the last result | ||
# is requested. This is needed for this rewrite to apply to | ||
# do-while Scans at all. Otherwise, `get_canonical_form_slice` in | ||
# the `else` branch would reintroduce a shape dependency on the | ||
# original Scan which would lead this rewrite to abort in the end. | ||
Comment on lines +1423 to +1427
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I have no idea what is happening here, but maybe this will be clearer when I understand the relationships between this rewrite and the new one… There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same issue as before, the old default branch would go through a Here we add a special case for |
||
if ( | ||
i <= op.info.n_mit_mot | ||
and isinstance(this_slice[0], ScalarConstant) | ||
and this_slice[0].value == -1 | ||
): | ||
start = nw_steps - 1 | ||
else: | ||
try: | ||
length = shape_of[out][0] | ||
except KeyError: | ||
length = out.shape[0] | ||
cf_slice = get_canonical_form_slice(this_slice[0], length) | ||
if i <= op.info.n_mit_mot: | ||
try: | ||
length = shape_of[out][0] | ||
except KeyError: | ||
length = out.shape[0] | ||
else: | ||
length = node.inputs[0] + init_l[i] | ||
|
||
cf_slice = get_canonical_form_slice(this_slice[0], length) | ||
|
||
if isinstance(cf_slice[0], slice): | ||
start = at.extract_constant(cf_slice[0].start) | ||
else: | ||
start = at.extract_constant(cf_slice[0]) | ||
|
||
if isinstance(cf_slice[0], slice): | ||
start = at.extract_constant(cf_slice[0].start) | ||
else: | ||
start = at.extract_constant(cf_slice[0]) | ||
if start == 0 or store_steps[i] == 0: | ||
store_steps[i] = 0 | ||
else: | ||
|
@@ -1498,6 +1597,7 @@ def save_mem_new_scan(fgraph, node): | |
nw_input = expand_empty(_nw_input, nw_steps) | ||
nw_inputs[in_idx] = nw_input | ||
else: | ||
# FIXME: This is never used | ||
nw_input = nw_inputs[in_idx][: (initl + nw_steps)] | ||
|
||
elif ( | ||
|
@@ -1554,8 +1654,8 @@ def save_mem_new_scan(fgraph, node): | |
) | ||
else: | ||
fslice = sanitize(cnf_slice[0]) | ||
|
||
nw_slice = (fslice,) + tuple(old_slices[1:]) | ||
|
||
nw_pos = inv_compress_map[idx] | ||
|
||
subtens = Subtensor(nw_slice) | ||
|
@@ -1604,9 +1704,16 @@ def save_mem_new_scan(fgraph, node): | |
) + tuple(old_slices[1:]) | ||
|
||
else: | ||
position = ( | ||
cnf_slice[0] - nw_steps - init_l[pos] + store_steps[pos] | ||
) | ||
# Special case when only last value is requested | ||
if ( | ||
isinstance(old_slices[0], ScalarConstant) | ||
and old_slices[0].value == -1 | ||
): | ||
position = old_slices[0] | ||
else: | ||
position = ( | ||
cnf_slice[0] - nw_steps - init_l[pos] + store_steps[pos] | ||
) | ||
|
||
nw_slice = (sanitize(position),) + tuple(old_slices[1:]) | ||
subtens = Subtensor(nw_slice) | ||
|
@@ -2403,6 +2510,12 @@ def push_out_dot1_scan(fgraph, node): | |
position=5, | ||
) | ||
|
||
scan_eqopt2.register( | ||
"while_scan_merge_subtensor_last_element", | ||
in2out(while_scan_merge_subtensor_last_element, ignore_newtrees=True), | ||
"fast_run", | ||
"scan", | ||
) | ||
|
||
scan_eqopt2.register( | ||
"constant_folding_for_scan2", | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
This is not used anywhere. The information is contained in `info`.