Optimize while scans when only last state is needed #216

Merged (3 commits, Feb 24, 2023)
3 changes: 1 addition & 2 deletions pytensor/scan/op.py
@@ -674,7 +674,6 @@ def __init__(
typeConstructor: Optional[TensorConstructorType] = None,
truncate_gradient: int = -1,
name: Optional[str] = None,
as_while: bool = False,
Member Author:
This is not used anywhere. The information is contained in info
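A small aside (not part of the diff): the flag remains reachable through the op's ScanInfo, e.g.

    # hypothetical usage; scan_op stands for any Scan op instance
    if scan_op.info.as_while:
        ...  # handle the do-while case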

profile: Optional[Union[str, bool]] = None,
allow_gc: bool = True,
strict: bool = True,
@@ -1183,7 +1182,7 @@ def make_node(self, *inputs):
# these are states that do not feed anything back in the recurrent
# computation, and hence they do not have an initial state. The scan
# node however receives an input for each such argument, the input
# in this case is just a int saying how many steps of this output we
# in this case is just an int saying how many steps of this output we
# need to store. This input does not have the same dtype, nor is it the same
# type of tensor as the output, it is always a scalar int.
new_inputs += [as_tensor_variable(ons) for ons in self.outer_nitsot(inputs)]
165 changes: 139 additions & 26 deletions pytensor/scan/rewriting.py
@@ -28,10 +28,18 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import compute_test_value
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
from pytensor.graph.rewriting.basic import (
GraphRewriter,
copy_stack_trace,
in2out,
node_rewriter,
)
from pytensor.graph.rewriting.db import EquilibriumDB, SequenceDB
from pytensor.graph.rewriting.utils import get_clients_at_depth
from pytensor.graph.type import HasShape
from pytensor.graph.utils import InconsistencyError
from pytensor.raise_op import Assert
from pytensor.scalar import ScalarConstant
from pytensor.scan.op import Scan, ScanInfo
from pytensor.scan.utils import (
ScanArgs,
@@ -1103,6 +1111,71 @@ def sanitize(x):
return at.as_tensor_variable(x)


@node_rewriter([Scan])
def while_scan_merge_subtensor_last_element(fgraph, scan_node):
"""
Replace while_scan_out[abs(min(tap)):][-1] by while_scan_out[-1], for
recurring outputs, asserting that at least one step occurs.
Only the first step can be ensured by the inputs alone (i.e., `n_steps > 0`),
as the while scan could abort earlier anytime after that. This means it is
not possible to replace while_scan_out[abs(min(tap)):][-i]
by while_scan_out[-i], for -i != -1.
"""
op = scan_node.op

if not op.info.as_while:
return None

# Optimization is not implemented for mit-mot
recurrent_outputs = op.outer_mitsot_outs(scan_node.outputs) + op.outer_sitsot_outs(
scan_node.outputs
)
recurrent_outputs_taps_slices = (
op.info.mit_sot_in_slices + op.info.sit_sot_in_slices
)

n_steps = scan_node.inputs[0]
non_zero_steps_cond = n_steps > 0
assert_non_zero_steps_op = Assert("n_steps > 0")

subtensor_merge_replacements = {}

# Iterate over all nodes that are two computations below the while scan
for node2 in get_clients_at_depth(fgraph, scan_node, depth=2):
if not isinstance(node2.op, Subtensor):
continue

node1 = node2.inputs[0].owner
if not (node1 and isinstance(node1.op, Subtensor)):
continue

x = node1.inputs[0]
if x not in recurrent_outputs:
continue

slice1 = get_idx_list(node1.inputs, node1.op.idx_list)
slice2 = get_idx_list(node2.inputs, node2.op.idx_list)

min_tap = abs(min(recurrent_outputs_taps_slices[recurrent_outputs.index(x)]))

if (
len(slice1) == 1
and isinstance(slice1[0], slice)
and isinstance(slice1[0].start, aes.ScalarConstant)
and slice1[0].start.data == min_tap
and slice1[0].stop is None
and slice1[0].step is None
and len(slice2) == 1
and isinstance(slice2[0], aes.ScalarConstant)
and slice2[0].data == -1
):
out = assert_non_zero_steps_op(x[-1], non_zero_steps_cond)
copy_stack_trace([node2.outputs[0], node2.inputs[0]], out)
subtensor_merge_replacements[node2.outputs[0]] = out

return subtensor_merge_replacements
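
For illustration, a minimal sketch of the pattern this rewrite targets (not part of the diff; it assumes the usual pytensor.scan/until API and uses hypothetical variable names):

    import pytensor
    import pytensor.tensor as at
    from pytensor.scan.utils import until

    x0 = at.scalar("x0")

    def step(x_prev):
        x_next = x_prev + 1
        return x_next, until(x_next > 10)  # do-while scan: may stop before n_steps

    xs, _ = pytensor.scan(step, outputs_info=[x0], n_steps=100)

    # xs includes the initial state, so user code typically takes xs[1:][-1] for
    # the final value.  This rewrite turns that into xs[-1] guarded by
    # Assert(n_steps > 0), since only the first iteration is guaranteed to run.
    last = xs[1:][-1]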


@node_rewriter([Scan])
def save_mem_new_scan(fgraph, node):
r"""Graph optimizer that reduces scan memory consumption.
@@ -1124,6 +1197,17 @@ def save_mem_new_scan(fgraph, node):
that SITSOT output. Only the most recently computed timestep ever needs to
be kept in memory.

There are two ways in which the Scan buffer size is controlled:
1. Each recurring output is saved in an empty input tensor x with the initial
state written at x[:abs(min(taps))]. The remaining x[abs(min(taps)):]
positions determine how many intermediate results should be stored.
This rewrite shortens x[abs(min(taps)):] to the smallest possible size.
2. Each non-recurrent output (nit-sot) is associated with a scalar integer
input that determines how many steps should be saved in the perform method.
This rewrite reduces this number to the smallest possible.

The scan perform implementation takes the output sizes into consideration,
saving the newest results over the oldest ones whenever the buffer is filled.
"""
if not isinstance(node.op, Scan):
return False
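
As a rough illustration of case 1 from the docstring above (a sketch, not from the diff; it assumes the standard pytensor.scan API):

    import pytensor
    import pytensor.tensor as at

    y0 = at.scalar("y0")
    ys, _ = pytensor.scan(lambda y_prev: y_prev * 2, outputs_info=[y0], n_steps=10)

    # Only ys[-1] is used downstream.  Without this rewrite the sit-sot buffer
    # holds the initial state plus all 10 steps; with it, the buffer is shortened
    # to the minimum the taps require and perform() overwrites the oldest
    # entries in place.
    f = pytensor.function([y0], ys[-1])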
@@ -1172,13 +1256,16 @@ def save_mem_new_scan(fgraph, node):
# index(step) for any output scan actually needs to compute
# In other words n_steps should be equal to this maximal !
# Note: if we have a shared variable that gets updated at every step
# of the loop, reducing the number of steps will affect the the
# value of the shared variable after the loop so we need not to
# of the loop, reducing the number of steps will affect the
# value of the shared variable after the loop so we cannot
# change the number of steps in that case. To do this we set
# global_nsteps to None which is seen as a flag that nothing needs
# to be done
# to be done.
# Note: For simplicity, while Scans also have global_nsteps set to None.
# All step optimizations require knowing the shape of the output, which
# cannot be determined from the inputs alone.
assert len(node.outputs) >= c_outs
if len(node.outputs) == c_outs:
if len(node.outputs) == c_outs and not op.info.as_while:
global_nsteps = {"real": -1, "sym": []}
else:
global_nsteps = None
@@ -1257,7 +1344,7 @@ def save_mem_new_scan(fgraph, node):
else:
# there is a **gotcha** here ! Namely, scan returns an
# array that contains the initial state of the output
# as well. Which means that if have a initial state of
# as well. Which means that if y has an initial state of
# length 3, and you look for 5 steps you get an output
# y of length 8. If you only use y[:5], this does not
# mean that you only need to loop for 5 steps but
@@ -1285,9 +1372,9 @@ def save_mem_new_scan(fgraph, node):

# 2.3. Analyze global_nsteps to figure out for how many steps scan
# needs to iterate
if global_nsteps is not None:
if global_nsteps is None:
nw_steps = node.inputs[0]

else:
# there are some symbolic tensors that limit the number of
# steps
if len(global_nsteps["sym"]) == 0:
@@ -1303,16 +1390,14 @@ def save_mem_new_scan(fgraph, node):
real_steps = None
nw_steps = select_min(select_max(sym_steps, real_steps), node.inputs[0])

# FIXME: This is not correct. Scan with 0 steps seems to be supported
# Make sure the ScanSaveMem optimization never makes the new
# number of steps to be 0 (this could happen, for instance, if
# the optimization detects that the outputs of the Scan go through
# subtensor nodes that end up taking no elements) because Scan with
# 0 iterations are not supported. Make sure the new number of steps
# is at least 1.
nw_steps = select_max(nw_steps, 1)
else:
nw_steps = node.inputs[0]
global_nsteps = None

# 2.4 Loop over the clients again now looking just to see how many
# intermediate steps to store
@@ -1335,19 +1420,33 @@ def save_mem_new_scan(fgraph, node):
store_steps[i] = 0
break

if i > op_info.n_mit_mot:
length = node.inputs[0] + init_l[i]
# Special case for recurrent outputs where only the last result
# is requested. This is needed for this rewrite to apply to
# do-while Scans at all. Otherwise, `get_canonical_form_slice` in
# the `else` branch would reintroduce a shape dependency on the
# original Scan which would lead this rewrite to abort in the end.
Comment on lines +1423 to +1427
Member:
I have no idea what is happening here, but maybe this will be clearer when I understand the relationships between this rewrite and the new one…

Member Author (@ricardoV94, Feb 21, 2023):
Same issue as before: the old default branch would go through get_canonical_form_slice, which references the length of x, basically converting x[-1] to x[len(x)-1]. That creates a dependency on the outputs of the old while scan, leading the rewrite to abort in the end.

Here we add a special case for x[-1] so that this does not happen.
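
In graph terms, roughly (an illustration, not code from the PR):

    import pytensor.tensor as at

    x = at.vector("x")                  # stands in for the old while-scan output
    last = x[-1]                        # needs no length information
    last_canonical = x[x.shape[0] - 1]  # roughly what the default branch produces

    # The second form references x.shape[0], i.e. the length of the old scan
    # output, which is exactly the dependency that later makes the rewrite abort.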

if (
i <= op.info.n_mit_mot
and isinstance(this_slice[0], ScalarConstant)
and this_slice[0].value == -1
):
start = nw_steps - 1
else:
try:
length = shape_of[out][0]
except KeyError:
length = out.shape[0]
cf_slice = get_canonical_form_slice(this_slice[0], length)
if i <= op.info.n_mit_mot:
try:
length = shape_of[out][0]
except KeyError:
length = out.shape[0]
else:
length = node.inputs[0] + init_l[i]

cf_slice = get_canonical_form_slice(this_slice[0], length)

if isinstance(cf_slice[0], slice):
start = at.extract_constant(cf_slice[0].start)
else:
start = at.extract_constant(cf_slice[0])

if isinstance(cf_slice[0], slice):
start = at.extract_constant(cf_slice[0].start)
else:
start = at.extract_constant(cf_slice[0])
if start == 0 or store_steps[i] == 0:
store_steps[i] = 0
else:
@@ -1498,6 +1597,7 @@ def save_mem_new_scan(fgraph, node):
nw_input = expand_empty(_nw_input, nw_steps)
nw_inputs[in_idx] = nw_input
else:
# FIXME: This is never used
nw_input = nw_inputs[in_idx][: (initl + nw_steps)]

elif (
@@ -1554,8 +1654,8 @@
)
else:
fslice = sanitize(cnf_slice[0])

nw_slice = (fslice,) + tuple(old_slices[1:])

nw_pos = inv_compress_map[idx]

subtens = Subtensor(nw_slice)
@@ -1604,9 +1704,16 @@ def save_mem_new_scan(fgraph, node):
) + tuple(old_slices[1:])

else:
position = (
cnf_slice[0] - nw_steps - init_l[pos] + store_steps[pos]
)
# Special case when only last value is requested
if (
isinstance(old_slices[0], ScalarConstant)
and old_slices[0].value == -1
):
position = old_slices[0]
else:
position = (
cnf_slice[0] - nw_steps - init_l[pos] + store_steps[pos]
)

nw_slice = (sanitize(position),) + tuple(old_slices[1:])
subtens = Subtensor(nw_slice)
@@ -2403,6 +2510,12 @@ def push_out_dot1_scan(fgraph, node):
position=5,
)

scan_eqopt2.register(
"while_scan_merge_subtensor_last_element",
in2out(while_scan_merge_subtensor_last_element, ignore_newtrees=True),
"fast_run",
"scan",
)

scan_eqopt2.register(
"constant_folding_for_scan2",
11 changes: 11 additions & 0 deletions pytensor/tensor/rewriting/subtensor.py
@@ -479,6 +479,7 @@ def local_subtensor_merge(fgraph, node):
expresses all slices in a canonical form, and then merges them together.

"""
from pytensor.scan.op import Scan

if isinstance(node.op, Subtensor):
u = node.inputs[0]
@@ -489,6 +490,16 @@ def local_subtensor_merge(fgraph, node):
# slices of the first applied subtensor
slices1 = get_idx_list(u.owner.inputs, u.owner.op.idx_list)
slices2 = get_idx_list(node.inputs, node.op.idx_list)

# Don't try to do the optimization on do-while scan outputs,
# as it will create a dependency on the shape of the outputs
if (
x.owner is not None
and isinstance(x.owner.op, Scan)
and x.owner.op.info.as_while
):
return None

# Get the shapes of the vectors !
try:
# try not to introduce new shape into the graph
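
For context, a sketch (assumed, not from the diff) of the chained-slice pattern that local_subtensor_merge would normally fuse, and which the new guard leaves alone for do-while scan outputs:

    import pytensor.tensor as at

    x = at.vector("x")   # imagine this is a do-while Scan output
    y = x[1:][-1]        # two stacked Subtensor ops

    # Merging these into one Subtensor expresses the index through the length of
    # x (roughly x[x.shape[0] - 1]).  For do-while outputs that length is only
    # known at runtime, so the merge returns None here and leaves the pattern to
    # while_scan_merge_subtensor_last_element.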
2 changes: 1 addition & 1 deletion pytensor/tensor/subtensor.py
@@ -85,7 +85,7 @@
def indices_from_subtensor(
op_indices: Iterable[ScalarConstant],
idx_list: Optional[List[Union[Type, slice, Variable]]],
) -> Union[slice, Variable]:
) -> Tuple[Union[slice, Variable], ...]:
"""Recreate the index tuple from which a ``*Subtensor**`` ``Op`` was created.

Parameters