Skip to content

Commit dd006eb

Browse files
committed
Upgrade to ruff 0.2.0 and fix RUF017
* --show-source -> --output-format=full * renaming of some config options * removing --line-length because it is already in the pyproject file * taking care of some list quadratic summations
1 parent f2bc707 commit dd006eb

File tree

8 files changed

+41
-36
lines changed

8 files changed

+41
-36
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,11 @@ repos:
2020
)$
2121
- id: check-merge-conflict
2222
- repo: https://github.com/astral-sh/ruff-pre-commit
23-
rev: v0.1.14
23+
rev: v0.2.0
2424
hooks:
2525
- id: ruff
26-
args: ["--fix", "--show-source"]
26+
args: ["--fix", "--output-format=full"]
2727
- id: ruff-format
28-
args: ["--line-length=88"]
2928
- repo: https://github.com/pre-commit/mirrors-mypy
3029
rev: v1.8.0
3130
hooks:

pyproject.toml

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -121,23 +121,19 @@ tag_prefix = "rel-"
121121
addopts = "--durations=50"
122122
testpaths = "tests/"
123123

124-
[tool.pylint]
125-
max-line-length = 88
126-
127-
[tool.pylint.messages_control]
128-
disable = ["C0330", "C0326"]
129-
130-
131124
[tool.ruff]
125+
line-length = 88
126+
exclude = ["doc/", "pytensor/_version.py", "bin/pytensor_cache.py"]
127+
128+
[tool.ruff.lint]
132129
select = ["C", "E", "F", "I", "UP", "W", "RUF"]
133130
ignore = ["C408", "C901", "E501", "E741", "RUF012"]
134-
exclude = ["doc/", "pytensor/_version.py", "bin/pytensor_cache.py"]
135131

136132

137-
[tool.ruff.isort]
133+
[tool.ruff.lint.isort]
138134
lines-after-imports = 2
139135

140-
[tool.ruff.per-file-ignores]
136+
[tool.ruff.lint.per-file-ignores]
141137
# TODO: Get rid of these:
142138
"**/__init__.py" = ["F401", "E402", "F403"]
143139
"pytensor/tensor/linalg.py" = ["F403"]

pytensor/graph/destroyhandler.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -233,10 +233,11 @@ def fast_inplace_check(fgraph, inputs):
233233
234234
"""
235235
Supervisor = pytensor.compile.function.types.Supervisor
236-
protected_inputs = [
237-
f.protected for f in fgraph._features if isinstance(f, Supervisor)
238-
]
239-
protected_inputs = sum(protected_inputs, []) # flatten the list
236+
protected_inputs = list(
237+
itertools.chain.from_iterable(
238+
f.protected for f in fgraph._features if isinstance(f, Supervisor)
239+
)
240+
)
240241
protected_inputs.extend(fgraph.outputs)
241242

242243
inputs = [

pytensor/scalar/basic.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4096,9 +4096,10 @@ def c_code_cache_version(self):
40964096
return tuple(rval)
40974097

40984098
def c_header_dirs(self, **kwargs):
4099-
rval = sum(
4100-
(subnode.op.c_header_dirs(**kwargs) for subnode in self.fgraph.toposort()),
4101-
[],
4099+
rval = list(
4100+
chain.from_iterable(
4101+
subnode.op.c_header_dirs(**kwargs) for subnode in self.fgraph.toposort()
4102+
)
41024103
)
41034104
return rval
41044105

pytensor/scan/rewriting.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1067,7 +1067,9 @@ def apply(self, fgraph):
10671067
if client.op.destroy_map:
10681068
# This flattens the content of destroy_map.values()
10691069
# which is a list of lists
1070-
inplace_inp_indices = sum(client.op.destroy_map.values(), [])
1070+
inplace_inp_indices = chain.from_iterable(
1071+
client.op.destroy_map.values()
1072+
)
10711073

10721074
inplace_inps = [client.inputs[i] for i in inplace_inp_indices]
10731075
if original_node.inputs[inp_idx] in inplace_inps:
@@ -1860,8 +1862,8 @@ def merge(self, nodes):
18601862
# Clone the inner graph of each node independently
18611863
for idx, nd in enumerate(nodes):
18621864
# concatenate all inner_ins and inner_outs of nd
1863-
flat_inner_ins = sum(inner_ins[idx], [])
1864-
flat_inner_outs = sum(inner_outs[idx], [])
1865+
flat_inner_ins = list(chain.from_iterable(inner_ins[idx]))
1866+
flat_inner_outs = list(chain.from_iterable(inner_outs[idx]))
18651867
# clone
18661868
flat_inner_ins, flat_inner_outs = reconstruct_graph(
18671869
flat_inner_ins, flat_inner_outs

pytensor/scan/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -765,8 +765,8 @@ def field_names(self):
765765
def inner_inputs(self):
766766
return (
767767
self.inner_in_seqs
768-
+ sum(self.inner_in_mit_mot, [])
769-
+ sum(self.inner_in_mit_sot, [])
768+
+ list(chain.from_iterable(self.inner_in_mit_mot))
769+
+ list(chain.from_iterable(self.inner_in_mit_sot))
770770
+ self.inner_in_sit_sot
771771
+ self.inner_in_shared
772772
+ self.inner_in_non_seqs
@@ -788,7 +788,7 @@ def outer_inputs(self):
788788
@property
789789
def inner_outputs(self):
790790
return (
791-
sum(self.inner_out_mit_mot, [])
791+
list(chain.from_iterable(self.inner_out_mit_mot))
792792
+ self.inner_out_mit_sot
793793
+ self.inner_out_sit_sot
794794
+ self.inner_out_nit_sot

pytensor/tensor/rewriting/elemwise.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import itertools
12
import sys
23
from collections import defaultdict, deque
34
from collections.abc import Generator
@@ -144,12 +145,12 @@ def apply(self, fgraph):
144145
else:
145146
update_outs = []
146147

147-
protected_inputs = [
148-
f.protected
149-
for f in fgraph._features
150-
if isinstance(f, pytensor.compile.function.types.Supervisor)
151-
]
152-
protected_inputs = sum(protected_inputs, []) # flatten the list
148+
Supervisor = pytensor.compile.function.types.Supervisor
149+
protected_inputs = list(
150+
itertools.chain.from_iterable(
151+
f.protected for f in fgraph._features if isinstance(f, Supervisor)
152+
)
153+
)
153154
protected_inputs.extend(fgraph.outputs)
154155
for node in list(io_toposort(fgraph.inputs, fgraph.outputs)):
155156
op = node.op

tests/graph/test_fg.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import itertools
12
import pickle
23

34
import numpy as np
@@ -574,7 +575,8 @@ def test_remove_in_and_out(self):
574575
assert fg.outputs == [op1_out]
575576
assert op3_out not in fg.clients
576577
assert not any(
577-
op3_out.owner in clients for clients in sum(fg.clients.values(), [])
578+
op3_out.owner in clients
579+
for clients in itertools.chain.from_iterable(fg.clients.values())
578580
)
579581

580582
# Remove an input
@@ -585,7 +587,8 @@ def test_remove_in_and_out(self):
585587
assert fg.inputs == [var2]
586588
assert fg.outputs == []
587589
assert not any(
588-
op1_out.owner in clients for clients in sum(fg.clients.values(), [])
590+
op1_out.owner in clients
591+
for clients in itertools.chain.from_iterable(fg.clients.values())
589592
)
590593

591594
def test_remove_duplicates(self):
@@ -622,10 +625,12 @@ def test_remove_output_empty(self):
622625
assert not fg.apply_nodes
623626
assert op1_out not in fg.clients
624627
assert not any(
625-
op1_out.owner in clients for clients in sum(fg.clients.values(), [])
628+
op1_out.owner in clients
629+
for clients in itertools.chain.from_iterable(fg.clients.values())
626630
)
627631
assert not any(
628-
op3_out.owner in clients for clients in sum(fg.clients.values(), [])
632+
op3_out.owner in clients
633+
for clients in itertools.chain.from_iterable(fg.clients.values())
629634
)
630635

631636
def test_remove_node_multi_out(self):

0 commit comments

Comments
 (0)