Skip to content

Commit 9a49994

Browse files
committed
Don't specify zip strict kwarg in hot loops
It seems to add a non-trivial 100ns
1 parent a4709c8 commit 9a49994

File tree

20 files changed

+60
-72
lines changed

20 files changed

+60
-72
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ exclude = ["doc/", "pytensor/_version.py"]
130130
docstring-code-format = true
131131

132132
[tool.ruff.lint]
133-
select = ["B905", "C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC", "T20", "NPY201"]
133+
select = ["C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC", "T20", "NPY201"]
134134
ignore = ["C408", "C901", "E501", "E741", "RUF012", "PERF203", "ISC001"]
135135
unfixable = [
136136
# zip-strict: the auto-fix adds `strict=False` but we might want `strict=True` instead

pytensor/compile/builders.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -873,7 +873,6 @@ def clone(self):
873873

874874
def perform(self, node, inputs, outputs):
875875
variables = self.fn(*inputs)
876-
assert len(variables) == len(outputs)
877-
# strict=False because asserted above
878-
for output, variable in zip(outputs, variables, strict=False):
876+
# zip strict not specified because we are in a hot loop
877+
for output, variable in zip(outputs, variables):
879878
output[0] = variable

pytensor/compile/function/types.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -924,7 +924,8 @@ def __call__(self, *args, output_subset=None, **kwargs):
924924

925925
# Reinitialize each container's 'provided' counter
926926
if trust_input:
927-
for arg_container, arg in zip(input_storage, args, strict=False):
927+
# zip strict not specified because we are in a hot loop
928+
for arg_container, arg in zip(input_storage, args):
928929
arg_container.storage[0] = arg
929930
else:
930931
for arg_container in input_storage:
@@ -934,7 +935,8 @@ def __call__(self, *args, output_subset=None, **kwargs):
934935
raise TypeError("Too many parameter passed to pytensor function")
935936

936937
# Set positional arguments
937-
for arg_container, arg in zip(input_storage, args, strict=False):
938+
# zip strict not specified because we are in a hot loop
939+
for arg_container, arg in zip(input_storage, args):
938940
# See discussion about None as input
939941
# https://groups.google.com/group/theano-dev/browse_thread/thread/920a5e904e8a8525/4f1b311a28fc27e5
940942
if arg is None:

pytensor/ifelse.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -305,8 +305,8 @@ def thunk():
305305
if len(ls) > 0:
306306
return ls
307307
else:
308-
# strict=False because we are in a hot loop
309-
for out, t in zip(outputs, input_true_branch, strict=False):
308+
# zip strict not specified because we are in a hot loop
309+
for out, t in zip(outputs, input_true_branch):
310310
compute_map[out][0] = 1
311311
val = storage_map[t][0]
312312
if self.as_view:
@@ -326,8 +326,8 @@ def thunk():
326326
if len(ls) > 0:
327327
return ls
328328
else:
329-
# strict=False because we are in a hot loop
330-
for out, f in zip(outputs, inputs_false_branch, strict=False):
329+
# zip strict not specified because we are in a hot loop
330+
for out, f in zip(outputs, inputs_false_branch):
331331
compute_map[out][0] = 1
332332
# can't view both outputs unless destroyhandler
333333
# improves

pytensor/link/basic.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -539,14 +539,14 @@ def make_thunk(self, **kwargs):
539539

540540
def f():
541541
for inputs in input_lists[1:]:
542-
# strict=False because we are in a hot loop
543-
for input1, input2 in zip(inputs0, inputs, strict=False):
542+
# zip strict not specified because we are in a hot loop
543+
for input1, input2 in zip(inputs0, inputs):
544544
input2.storage[0] = copy(input1.storage[0])
545545
for x in to_reset:
546546
x[0] = None
547547
pre(self, [input.data for input in input_lists[0]], order, thunk_groups)
548-
# strict=False because we are in a hot loop
549-
for i, (thunks, node) in enumerate(zip(thunk_groups, order, strict=False)):
548+
# zip strict not specified because we are in a hot loop
549+
for i, (thunks, node) in enumerate(zip(thunk_groups, order)):
550550
try:
551551
wrapper(self.fgraph, i, node, *thunks)
552552
except Exception:
@@ -668,8 +668,8 @@ def thunk(
668668
# since the error may come from any of them?
669669
raise_with_op(self.fgraph, output_nodes[0], thunk)
670670

671-
# strict=False because we are in a hot loop
672-
for o_storage, o_val in zip(thunk_outputs, outputs, strict=False):
671+
# zip strict not specified because we are in a hot loop
672+
for o_storage, o_val in zip(thunk_outputs, outputs):
673673
o_storage[0] = o_val
674674

675675
thunk.inputs = thunk_inputs

pytensor/link/c/basic.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1988,27 +1988,23 @@ def make_thunk(self, **kwargs):
19881988
)
19891989

19901990
def f():
1991-
# strict=False because we are in a hot loop
1992-
for input1, input2 in zip(i1, i2, strict=False):
1991+
# zip strict not specified because we are in a hot loop
1992+
for input1, input2 in zip(i1, i2):
19931993
# Set the inputs to be the same in both branches.
19941994
# The copy is necessary in order for inplace ops not to
19951995
# interfere.
19961996
input2.storage[0] = copy(input1.storage[0])
1997-
for thunk1, thunk2, node1, node2 in zip(
1998-
thunks1, thunks2, order1, order2, strict=False
1999-
):
2000-
for output, storage in zip(node1.outputs, thunk1.outputs, strict=False):
1997+
for thunk1, thunk2, node1, node2 in zip(thunks1, thunks2, order1, order2):
1998+
for output, storage in zip(node1.outputs, thunk1.outputs):
20011999
if output in no_recycling:
20022000
storage[0] = None
2003-
for output, storage in zip(node2.outputs, thunk2.outputs, strict=False):
2001+
for output, storage in zip(node2.outputs, thunk2.outputs):
20042002
if output in no_recycling:
20052003
storage[0] = None
20062004
try:
20072005
thunk1()
20082006
thunk2()
2009-
for output1, output2 in zip(
2010-
thunk1.outputs, thunk2.outputs, strict=False
2011-
):
2007+
for output1, output2 in zip(thunk1.outputs, thunk2.outputs):
20122008
self.checker(output1, output2)
20132009
except Exception:
20142010
raise_with_op(fgraph, node1)

pytensor/link/numba/dispatch/basic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -312,10 +312,10 @@ def py_perform_return(inputs):
312312
else:
313313

314314
def py_perform_return(inputs):
315-
# strict=False because we are in a hot loop
315+
# zip strict not specified because we are in a hot loop
316316
return tuple(
317317
out_type.filter(out[0])
318-
for out_type, out in zip(output_types, py_perform(inputs), strict=False)
318+
for out_type, out in zip(output_types, py_perform(inputs))
319319
)
320320

321321
@numba_njit

pytensor/link/numba/dispatch/cython_support.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -166,10 +166,7 @@ def __wrapper_address__(self):
166166
def __call__(self, *args, **kwargs):
167167
# no strict argument because of the JIT
168168
# TODO: check
169-
args = [
170-
dtype(arg)
171-
for arg, dtype in zip(args, self._signature.arg_dtypes) # noqa: B905
172-
]
169+
args = [dtype(arg) for arg, dtype in zip(args, self._signature.arg_dtypes)]
173170
if self.has_pyx_skip_dispatch():
174171
output = self._pyfunc(*args[:-1], **kwargs)
175172
else:

pytensor/link/numba/dispatch/extra_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def ravelmultiindex(*inp):
186186
new_arr = arr.T.astype(np.float64).copy()
187187
for i, b in enumerate(new_arr):
188188
# no strict argument to this zip because numba doesn't support it
189-
for j, (d, v) in enumerate(zip(shape, b)): # noqa: B905
189+
for j, (d, v) in enumerate(zip(shape, b)):
190190
if v < 0 or v >= d:
191191
mode_fn(new_arr, i, j, v, d)
192192

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def block_diag(*arrs):
183183

184184
r, c = 0, 0
185185
# no strict argument because it is incompatible with numba
186-
for arr, shape in zip(arrs, shapes): # noqa: B905
186+
for arr, shape in zip(arrs, shapes):
187187
rr, cc = shape
188188
out[r : r + rr, c : c + cc] = arr
189189
r += rr

0 commit comments

Comments
 (0)