Skip to content

Commit 76737d7

Browse files
committed
Upgrade to ruff 0.2.0 and fix RUF017
* --show-source -> --output-format=full * renaming of some config options * removing --line-length because it is already in the pyproject file * taking care of some list quadratic summations
1 parent f2bc707 commit 76737d7

File tree

8 files changed

+30
-30
lines changed

8 files changed

+30
-30
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,11 @@ repos:
2020
)$
2121
- id: check-merge-conflict
2222
- repo: https://github.com/astral-sh/ruff-pre-commit
23-
rev: v0.1.14
23+
rev: v0.2.0
2424
hooks:
2525
- id: ruff
26-
args: ["--fix", "--show-source"]
26+
args: ["--fix", "--output-format=full"]
2727
- id: ruff-format
28-
args: ["--line-length=88"]
2928
- repo: https://github.com/pre-commit/mirrors-mypy
3029
rev: v1.8.0
3130
hooks:

pyproject.toml

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -121,20 +121,16 @@ tag_prefix = "rel-"
121121
addopts = "--durations=50"
122122
testpaths = "tests/"
123123

124-
[tool.pylint]
125-
max-line-length = 88
126-
127-
[tool.pylint.messages_control]
128-
disable = ["C0330", "C0326"]
129-
130-
131124
[tool.ruff]
125+
line-length = 88
126+
exclude = ["doc/", "pytensor/_version.py", "bin/pytensor_cache.py"]
127+
128+
[tool.ruff.lint]
132129
select = ["C", "E", "F", "I", "UP", "W", "RUF"]
133130
ignore = ["C408", "C901", "E501", "E741", "RUF012"]
134-
exclude = ["doc/", "pytensor/_version.py", "bin/pytensor_cache.py"]
135131

136132

137-
[tool.ruff.isort]
133+
[tool.ruff.lint.isort]
138134
lines-after-imports = 2
139135

140136
[tool.ruff.per-file-ignores]

pytensor/graph/destroyhandler.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -234,9 +234,11 @@ def fast_inplace_check(fgraph, inputs):
234234
"""
235235
Supervisor = pytensor.compile.function.types.Supervisor
236236
protected_inputs = [
237-
f.protected for f in fgraph._features if isinstance(f, Supervisor)
237+
inpt
238+
for f in fgraph._features
239+
if isinstance(f, Supervisor)
240+
for inpt in f.protected
238241
]
239-
protected_inputs = sum(protected_inputs, []) # flatten the list
240242
protected_inputs.extend(fgraph.outputs)
241243

242244
inputs = [

pytensor/scalar/basic.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4096,10 +4096,11 @@ def c_code_cache_version(self):
40964096
return tuple(rval)
40974097

40984098
def c_header_dirs(self, **kwargs):
4099-
rval = sum(
4100-
(subnode.op.c_header_dirs(**kwargs) for subnode in self.fgraph.toposort()),
4101-
[],
4102-
)
4099+
rval = [
4100+
x
4101+
for subnode in self.fgraph.toposort()
4102+
for x in subnode.op.c_header_dirs(**kwargs)
4103+
]
41034104
return rval
41044105

41054106
def c_support_code(self, **kwargs):

pytensor/scan/rewriting.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1067,7 +1067,9 @@ def apply(self, fgraph):
10671067
if client.op.destroy_map:
10681068
# This flattens the content of destroy_map.values()
10691069
# which is a list of lists
1070-
inplace_inp_indices = sum(client.op.destroy_map.values(), [])
1070+
inplace_inp_indices = [
1071+
x for l in client.op.destroy_map.values() for x in l
1072+
]
10711073

10721074
inplace_inps = [client.inputs[i] for i in inplace_inp_indices]
10731075
if original_node.inputs[inp_idx] in inplace_inps:
@@ -1860,8 +1862,8 @@ def merge(self, nodes):
18601862
# Clone the inner graph of each node independently
18611863
for idx, nd in enumerate(nodes):
18621864
# concatenate all inner_ins and inner_outs of nd
1863-
flat_inner_ins = sum(inner_ins[idx], [])
1864-
flat_inner_outs = sum(inner_outs[idx], [])
1865+
flat_inner_ins = [x for l in inner_ins[idx] for x in l]
1866+
flat_inner_outs = [x for l in inner_outs[idx] for x in l]
18651867
# clone
18661868
flat_inner_ins, flat_inner_outs = reconstruct_graph(
18671869
flat_inner_ins, flat_inner_outs

pytensor/scan/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -765,8 +765,8 @@ def field_names(self):
765765
def inner_inputs(self):
766766
return (
767767
self.inner_in_seqs
768-
+ sum(self.inner_in_mit_mot, [])
769-
+ sum(self.inner_in_mit_sot, [])
768+
+ [x for l in self.inner_in_mit_mot for x in l]
769+
+ [x for l in self.inner_in_mit_sot for x in l]
770770
+ self.inner_in_sit_sot
771771
+ self.inner_in_shared
772772
+ self.inner_in_non_seqs
@@ -788,7 +788,7 @@ def outer_inputs(self):
788788
@property
789789
def inner_outputs(self):
790790
return (
791-
sum(self.inner_out_mit_mot, [])
791+
[x for l in self.inner_out_mit_mot for x in l]
792792
+ self.inner_out_mit_sot
793793
+ self.inner_out_sit_sot
794794
+ self.inner_out_nit_sot

pytensor/tensor/rewriting/elemwise.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,11 +145,11 @@ def apply(self, fgraph):
145145
update_outs = []
146146

147147
protected_inputs = [
148-
f.protected
148+
inpt
149149
for f in fgraph._features
150150
if isinstance(f, pytensor.compile.function.types.Supervisor)
151+
for inpt in f.protected
151152
]
152-
protected_inputs = sum(protected_inputs, []) # flatten the list
153153
protected_inputs.extend(fgraph.outputs)
154154
for node in list(io_toposort(fgraph.inputs, fgraph.outputs)):
155155
op = node.op

tests/graph/test_fg.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,7 @@ def test_remove_in_and_out(self):
574574
assert fg.outputs == [op1_out]
575575
assert op3_out not in fg.clients
576576
assert not any(
577-
op3_out.owner in clients for clients in sum(fg.clients.values(), [])
577+
op3_out.owner in clients for l in fg.clients.values() for clients in l
578578
)
579579

580580
# Remove an input
@@ -585,7 +585,7 @@ def test_remove_in_and_out(self):
585585
assert fg.inputs == [var2]
586586
assert fg.outputs == []
587587
assert not any(
588-
op1_out.owner in clients for clients in sum(fg.clients.values(), [])
588+
op1_out.owner in clients for l in fg.clients.values() for clients in l
589589
)
590590

591591
def test_remove_duplicates(self):
@@ -622,10 +622,10 @@ def test_remove_output_empty(self):
622622
assert not fg.apply_nodes
623623
assert op1_out not in fg.clients
624624
assert not any(
625-
op1_out.owner in clients for clients in sum(fg.clients.values(), [])
625+
op1_out.owner in clients for l in fg.clients.values() for clients in l
626626
)
627627
assert not any(
628-
op3_out.owner in clients for clients in sum(fg.clients.values(), [])
628+
op3_out.owner in clients for l in fg.clients.values() for clients in l
629629
)
630630

631631
def test_remove_node_multi_out(self):

0 commit comments

Comments
 (0)