Speedup Python implementation of Blockwise #1391

Merged

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -130,7 +130,7 @@ exclude = ["doc/", "pytensor/_version.py"]
docstring-code-format = true

[tool.ruff.lint]
select = ["B905", "C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC", "T20", "NPY201"]
select = ["C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC", "T20", "NPY201"]
ignore = ["C408", "C901", "E501", "E741", "RUF012", "PERF203", "ISC001"]
unfixable = [
# zip-strict: the auto-fix adds `strict=False` but we might want `strict=True` instead
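Dropping B905 (Ruff's zip-without-explicit-strict rule) from the select list is what lets the hunks below omit strict and delete the old # noqa: B905 markers without new lint errors. A rough micro-benchmark sketch (an editor's illustration, not part of this PR) of the per-call overhead the hot-loop comments refer to:

# Timing plain zip() against zip(strict=True) on a small, equal-length pair.
# Absolute numbers are machine-dependent; only the relative gap matters.
import timeit

a = list(range(10))
b = list(range(10))

t_plain = timeit.timeit(lambda: list(zip(a, b)), number=1_000_000)
t_strict = timeit.timeit(lambda: list(zip(a, b, strict=True)), number=1_000_000)
print(f"plain zip:  {t_plain:.3f} s")
print(f"strict zip: {t_strict:.3f} s")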
5 changes: 2 additions & 3 deletions pytensor/compile/builders.py
@@ -873,7 +873,6 @@ def clone(self):

def perform(self, node, inputs, outputs):
variables = self.fn(*inputs)
assert len(variables) == len(outputs)
# strict=False because asserted above
for output, variable in zip(outputs, variables, strict=False):
# zip strict not specified because we are in a hot loop
for output, variable in zip(outputs, variables):
output[0] = variable
6 changes: 4 additions & 2 deletions pytensor/compile/function/types.py
@@ -924,7 +924,8 @@ def __call__(self, *args, output_subset=None, **kwargs):

# Reinitialize each container's 'provided' counter
if trust_input:
for arg_container, arg in zip(input_storage, args, strict=False):
# zip strict not specified because we are in a hot loop
for arg_container, arg in zip(input_storage, args):
arg_container.storage[0] = arg
else:
for arg_container in input_storage:
@@ -934,7 +935,8 @@ def __call__(self, *args, output_subset=None, **kwargs):
raise TypeError("Too many parameter passed to pytensor function")

# Set positional arguments
for arg_container, arg in zip(input_storage, args, strict=False):
# zip strict not specified because we are in a hot loop
for arg_container, arg in zip(input_storage, args):
# See discussion about None as input
# https://groups.google.com/group/theano-dev/browse_thread/thread/920a5e904e8a8525/4f1b311a28fc27e5
if arg is None:
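These zip loops run on every call of a compiled pytensor function, which is why the comments treat them as hot loops. A minimal usage sketch (an editor's illustration assuming a standard pytensor install, not code from this PR) of the trust_input fast path that the first hunk touches:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.dvector("x")
f = pytensor.function([x], x * 2)

# With trust_input=True, __call__ takes the fast path above: arguments are
# written straight into the input containers without per-argument validation,
# so the caller must already provide correctly typed inputs.
f.trust_input = True
print(f(np.array([1.0, 2.0, 3.0])))  # [2. 4. 6.]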
41 changes: 27 additions & 14 deletions pytensor/graph/op.py
@@ -502,7 +502,7 @@ def make_py_thunk(
self,
node: Apply,
storage_map: StorageMapType,
compute_map: ComputeMapType,
compute_map: ComputeMapType | None,
no_recycling: list[Variable],
debug: bool = False,
) -> ThunkType:
@@ -513,25 +513,38 @@
"""
node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs]
node_compute_map = [compute_map[r] for r in node.outputs]

if debug and hasattr(self, "debug_perform"):
p = node.op.debug_perform
else:
p = node.op.perform

@is_thunk_type
def rval(
p=p,
i=node_input_storage,
o=node_output_storage,
n=node,
cm=node_compute_map,
):
r = p(n, [x[0] for x in i], o)
for entry in cm:
entry[0] = True
return r
if compute_map is None:

@is_thunk_type
def rval(
p=p,
i=node_input_storage,
o=node_output_storage,
n=node,
):
return p(n, [x[0] for x in i], o)

else:
node_compute_map = [compute_map[r] for r in node.outputs]

@is_thunk_type
def rval(
p=p,
i=node_input_storage,
o=node_output_storage,
n=node,
cm=node_compute_map,
):
r = p(n, [x[0] for x in i], o)
for entry in cm:
entry[0] = True
return r

rval.inputs = node_input_storage
rval.outputs = node_output_storage
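A minimal sketch (an editor's illustration using a hypothetical toy Op, not code from this PR) of the new compute_map=None path: when no compute map is passed, the returned thunk skips the per-output bookkeeping entirely.

import numpy as np
import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op


class Double(Op):
    """Toy Op that doubles its input; used only to exercise make_py_thunk."""

    def make_node(self, x):
        x = pt.as_tensor_variable(x)
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, output_storage):
        output_storage[0][0] = 2 * inputs[0]


x = pt.dvector("x")
node = Double()(x).owner
storage_map = {node.inputs[0]: [np.arange(3.0)], node.outputs[0]: [None]}

# With compute_map=None the thunk only calls perform; no compute-map entries
# are created or updated.
thunk = node.op.make_py_thunk(node, storage_map, compute_map=None, no_recycling=[])
thunk()
print(storage_map[node.outputs[0]][0])  # [0. 2. 4.]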
8 changes: 4 additions & 4 deletions pytensor/ifelse.py
@@ -305,8 +305,8 @@ def thunk():
if len(ls) > 0:
return ls
else:
# strict=False because we are in a hot loop
for out, t in zip(outputs, input_true_branch, strict=False):
# zip strict not specified because we are in a hot loop
for out, t in zip(outputs, input_true_branch):
compute_map[out][0] = 1
val = storage_map[t][0]
if self.as_view:
@@ -326,8 +326,8 @@ def thunk():
if len(ls) > 0:
return ls
else:
# strict=False because we are in a hot loop
for out, f in zip(outputs, inputs_false_branch, strict=False):
# zip strict not specified because we are in a hot loop
for out, f in zip(outputs, inputs_false_branch):
compute_map[out][0] = 1
# can't view both outputs unless destroyhandler
# improves
12 changes: 6 additions & 6 deletions pytensor/link/basic.py
@@ -539,14 +539,14 @@ def make_thunk(self, **kwargs):

def f():
for inputs in input_lists[1:]:
# strict=False because we are in a hot loop
for input1, input2 in zip(inputs0, inputs, strict=False):
# zip strict not specified because we are in a hot loop
for input1, input2 in zip(inputs0, inputs):
input2.storage[0] = copy(input1.storage[0])
for x in to_reset:
x[0] = None
pre(self, [input.data for input in input_lists[0]], order, thunk_groups)
# strict=False because we are in a hot loop
for i, (thunks, node) in enumerate(zip(thunk_groups, order, strict=False)):
# zip strict not specified because we are in a hot loop
for i, (thunks, node) in enumerate(zip(thunk_groups, order)):
try:
wrapper(self.fgraph, i, node, *thunks)
except Exception:
@@ -668,8 +668,8 @@ def thunk(
# since the error may come from any of them?
raise_with_op(self.fgraph, output_nodes[0], thunk)

# strict=False because we are in a hot loop
for o_storage, o_val in zip(thunk_outputs, outputs, strict=False):
# zip strict not specified because we are in a hot loop
for o_storage, o_val in zip(thunk_outputs, outputs):
o_storage[0] = o_val

thunk.inputs = thunk_inputs
16 changes: 6 additions & 10 deletions pytensor/link/c/basic.py
@@ -1988,27 +1988,23 @@ def make_thunk(self, **kwargs):
)

def f():
# strict=False because we are in a hot loop
for input1, input2 in zip(i1, i2, strict=False):
# zip strict not specified because we are in a hot loop
for input1, input2 in zip(i1, i2):
# Set the inputs to be the same in both branches.
# The copy is necessary in order for inplace ops not to
# interfere.
input2.storage[0] = copy(input1.storage[0])
for thunk1, thunk2, node1, node2 in zip(
thunks1, thunks2, order1, order2, strict=False
):
for output, storage in zip(node1.outputs, thunk1.outputs, strict=False):
for thunk1, thunk2, node1, node2 in zip(thunks1, thunks2, order1, order2):
for output, storage in zip(node1.outputs, thunk1.outputs):
if output in no_recycling:
storage[0] = None
for output, storage in zip(node2.outputs, thunk2.outputs, strict=False):
for output, storage in zip(node2.outputs, thunk2.outputs):
if output in no_recycling:
storage[0] = None
try:
thunk1()
thunk2()
for output1, output2 in zip(
thunk1.outputs, thunk2.outputs, strict=False
):
for output1, output2 in zip(thunk1.outputs, thunk2.outputs):
self.checker(output1, output2)
except Exception:
raise_with_op(fgraph, node1)
18 changes: 12 additions & 6 deletions pytensor/link/c/op.py
@@ -39,7 +39,7 @@
self,
node: Apply,
storage_map: StorageMapType,
compute_map: ComputeMapType,
compute_map: ComputeMapType | None,
no_recycling: Collection[Variable],
) -> CThunkWrapperType:
"""Create a thunk for a C implementation.
@@ -86,11 +86,17 @@
)
thunk, node_input_filters, node_output_filters = outputs

@is_cthunk_wrapper_type
def rval():
thunk()
for o in node.outputs:
compute_map[o][0] = True
if compute_map is None:
rval = is_cthunk_wrapper_type(thunk)

else:
cm_entries = [compute_map[o] for o in node.outputs]

@is_cthunk_wrapper_type
def rval(thunk=thunk, cm_entries=cm_entries):
thunk()
for entry in cm_entries:
entry[0] = True

rval.thunk = thunk
rval.cthunk = thunk.cthunk
4 changes: 2 additions & 2 deletions pytensor/link/numba/dispatch/basic.py
@@ -312,10 +312,10 @@ def py_perform_return(inputs):
else:

def py_perform_return(inputs):
# strict=False because we are in a hot loop
# zip strict not specified because we are in a hot loop
return tuple(
out_type.filter(out[0])
for out_type, out in zip(output_types, py_perform(inputs), strict=False)
for out_type, out in zip(output_types, py_perform(inputs))
)

@numba_njit
5 changes: 1 addition & 4 deletions pytensor/link/numba/dispatch/cython_support.py
@@ -166,10 +166,7 @@ def __wrapper_address__(self):
def __call__(self, *args, **kwargs):
# no strict argument because of the JIT
# TODO: check
args = [
dtype(arg)
for arg, dtype in zip(args, self._signature.arg_dtypes) # noqa: B905
]
args = [dtype(arg) for arg, dtype in zip(args, self._signature.arg_dtypes)]
if self.has_pyx_skip_dispatch():
output = self._pyfunc(*args[:-1], **kwargs)
else:
2 changes: 1 addition & 1 deletion pytensor/link/numba/dispatch/extra_ops.py
@@ -186,7 +186,7 @@ def ravelmultiindex(*inp):
new_arr = arr.T.astype(np.float64).copy()
for i, b in enumerate(new_arr):
# no strict argument to this zip because numba doesn't support it
for j, (d, v) in enumerate(zip(shape, b)): # noqa: B905
for j, (d, v) in enumerate(zip(shape, b)):
if v < 0 or v >= d:
mode_fn(new_arr, i, j, v, d)

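The noqa: B905 markers go away because the rule is no longer selected, but the underlying constraint stays: numba's nopython mode does not accept zip's strict keyword, so these loops could not use it even if they wanted to. A small sketch (an editor's illustration assuming numba is installed, not code from this PR):

import numba
import numpy as np


@numba.njit
def pairwise_sum(a, b):
    # zip(a, b, strict=True) would fail to compile under nopython mode.
    total = 0.0
    for x, y in zip(a, b):
        total += x + y
    return total


print(pairwise_sum(np.arange(3.0), np.arange(3.0)))  # 6.0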
2 changes: 1 addition & 1 deletion pytensor/link/numba/dispatch/slinalg.py
@@ -183,7 +183,7 @@ def block_diag(*arrs):

r, c = 0, 0
# no strict argument because it is incompatible with numba
for arr, shape in zip(arrs, shapes): # noqa: B905
for arr, shape in zip(arrs, shapes):
rr, cc = shape
out[r : r + rr, c : c + cc] = arr
r += rr
10 changes: 5 additions & 5 deletions pytensor/link/numba/dispatch/subtensor.py
@@ -219,7 +219,7 @@ def advanced_subtensor_multiple_vector(x, *idxs):
shape_aft = x_shape[after_last_axis:]
out_shape = (*shape_bef, *idx_shape, *shape_aft)
out_buffer = np.empty(out_shape, dtype=x.dtype)
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
for i, scalar_idxs in enumerate(zip(*vec_idxs)):
out_buffer[(*none_slices, i)] = x[(*none_slices, *scalar_idxs)]
return out_buffer

@@ -253,7 +253,7 @@ def advanced_set_subtensor_multiple_vector(x, y, *idxs):
y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:])

for outer in np.ndindex(x_shape[:first_axis]):
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
for i, scalar_idxs in enumerate(zip(*vec_idxs)):
out[(*outer, *scalar_idxs)] = y[(*outer, i)]
return out

@@ -275,7 +275,7 @@ def advanced_inc_subtensor_multiple_vector(x, y, *idxs):
y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:])

for outer in np.ndindex(x_shape[:first_axis]):
for i, scalar_idxs in enumerate(zip(*vec_idxs)): # noqa: B905
for i, scalar_idxs in enumerate(zip(*vec_idxs)):
out[(*outer, *scalar_idxs)] += y[(*outer, i)]
return out

@@ -314,7 +314,7 @@ def advancedincsubtensor1_inplace(x, vals, idxs):
if not len(idxs) == len(vals):
raise ValueError("The number of indices and values must match.")
# no strict argument because incompatible with numba
for idx, val in zip(idxs, vals): # noqa: B905
for idx, val in zip(idxs, vals):
x[idx] = val
return x
else:
@@ -342,7 +342,7 @@ def advancedincsubtensor1_inplace(x, vals, idxs):
raise ValueError("The number of indices and values must match.")
# no strict argument because unsupported by numba
# TODO: this doesn't come up in tests
for idx, val in zip(idxs, vals): # noqa: B905
for idx, val in zip(idxs, vals):
x[idx] += val
return x

4 changes: 2 additions & 2 deletions pytensor/link/utils.py
@@ -207,8 +207,8 @@ def streamline_nice_errors_f():
for x in no_recycling:
x[0] = None
try:
# strict=False because we are in a hot loop
for thunk, node in zip(thunks, order, strict=False):
# zip strict not specified because we are in a hot loop
for thunk, node in zip(thunks, order):
thunk()
except Exception:
raise_with_op(fgraph, node, thunk)
4 changes: 2 additions & 2 deletions pytensor/scalar/basic.py
@@ -4416,8 +4416,8 @@ def make_node(self, *inputs):

def perform(self, node, inputs, output_storage):
outputs = self.py_perform_fn(*inputs)
# strict=False because we are in a hot loop
for storage, out_val in zip(output_storage, outputs, strict=False):
# zip strict not specified because we are in a hot loop
for storage, out_val in zip(output_storage, outputs):
storage[0] = out_val

def grad(self, inputs, output_grads):
4 changes: 2 additions & 2 deletions pytensor/scalar/loop.py
@@ -196,8 +196,8 @@ def perform(self, node, inputs, output_storage):
for i in range(n_steps):
carry = inner_fn(*carry, *constant)

# strict=False because we are in a hot loop
for storage, out_val in zip(output_storage, carry, strict=False):
# zip strict not specified because we are in a hot loop
for storage, out_val in zip(output_storage, carry):
storage[0] = out_val

@property
4 changes: 2 additions & 2 deletions pytensor/tensor/basic.py
@@ -3589,8 +3589,8 @@ def perform(self, node, inp, out):

# Make sure the output is big enough
out_s = []
# strict=False because we are in a hot loop
for xdim, ydim in zip(x_s, y_s, strict=False):
# zip strict not specified because we are in a hot loop
for xdim, ydim in zip(x_s, y_s):
if xdim == ydim:
outdim = xdim
elif xdim == 1:
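The truncated loop above applies the usual pairwise broadcasting rule when building the output shape. A standalone sketch of that rule (an editor's illustration, not the exact code in basic.py):

def broadcast_shapes_pairwise(x_s, y_s):
    # Equal dims are kept; a dim of 1 broadcasts against the other size;
    # anything else is a shape mismatch.
    out = []
    for xdim, ydim in zip(x_s, y_s):
        if xdim == ydim:
            out.append(xdim)
        elif xdim == 1:
            out.append(ydim)
        elif ydim == 1:
            out.append(xdim)
        else:
            raise ValueError(f"Incompatible dimensions: {xdim} and {ydim}")
    return tuple(out)


print(broadcast_shapes_pairwise((3, 1, 4), (3, 5, 4)))  # (3, 5, 4)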