Activate the RUF lints #618

Merged · 9 commits · Feb 5, 2024

5 changes: 2 additions & 3 deletions .pre-commit-config.yaml
@@ -20,12 +20,11 @@ repos:
         )$
       - id: check-merge-conflict
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.1.14
+    rev: v0.2.0
     hooks:
       - id: ruff
-        args: ["--fix", "--show-source"]
+        args: ["--fix", "--output-format=full"]
       - id: ruff-format
-        args: ["--line-length=88"]
   - repo: https://github.com/pre-commit/mirrors-mypy
     rev: v1.8.0
     hooks:
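
Note: Ruff v0.2.0 deprecated the `--show-source` CLI flag in favor of `--output-format=full`, so the hook arguments are updated together with the rev bump. The explicit `--line-length=88` argument to ruff-format is dropped because 88 is both Ruff's default and the line length configured in pyproject.toml below.
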
18 changes: 7 additions & 11 deletions pyproject.toml
@@ -121,23 +121,19 @@ tag_prefix = "rel-"
 addopts = "--durations=50"
 testpaths = "tests/"
 
-[tool.pylint]
-max-line-length = 88
-
-[tool.pylint.messages_control]
-disable = ["C0330", "C0326"]
-
-
 [tool.ruff]
-select = ["C", "E", "F", "I", "UP", "W"]
-ignore = ["C408", "C901", "E501", "E741"]
 line-length = 88
 exclude = ["doc/", "pytensor/_version.py", "bin/pytensor_cache.py"]
 
+[tool.ruff.lint]
+select = ["C", "E", "F", "I", "UP", "W", "RUF"]
+ignore = ["C408", "C901", "E501", "E741", "RUF012"]
+
 
-[tool.ruff.isort]
+[tool.ruff.lint.isort]
 lines-after-imports = 2
 
-[tool.ruff.per-file-ignores]
+[tool.ruff.lint.per-file-ignores]
 # TODO: Get rid of these:
 "**/__init__.py" = ["F401", "E402", "F403"]
 "pytensor/tensor/linalg.py" = ["F403"]
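
Note: moving select/ignore/isort/per-file-ignores under [tool.ruff.lint] follows Ruff v0.2.0, which deprecated the top-level spellings of these lint settings; the pylint tables become redundant and are dropped at the same time. Among the newly activated rules, RUF012 (mutable-class-default) is ignored: it asks for ClassVar annotations on mutable class attributes, which would touch a lot of code. A minimal sketch (hypothetical class) of what RUF012 flags:

    from typing import ClassVar


    class Example:
        bad: dict = {}  # RUF012: mutable class attribute without a ClassVar annotation

        good: ClassVar[dict] = {}  # the annotation RUF012 asks for
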
4 changes: 2 additions & 2 deletions pytensor/breakpoint.py
@@ -92,7 +92,7 @@ def make_node(self, condition, *monitored_vars):
             new_op.inp_types.append(monitored_vars[i].type)
 
         # Build the Apply node
-        inputs = [condition] + list(monitored_vars)
+        inputs = [condition, *monitored_vars]
         outputs = [inp.type() for inp in monitored_vars]
         return Apply(op=new_op, inputs=inputs, outputs=outputs)
 
@@ -142,7 +142,7 @@ def perform(self, node, inputs, output_storage):
             output_storage[i][0] = inputs[i + 1]
 
     def grad(self, inputs, output_gradients):
-        return [DisconnectedType()()] + output_gradients
+        return [DisconnectedType()(), *output_gradients]
 
     def infer_shape(self, fgraph, inputs, input_shapes):
         # Return the shape of every input but the condition (first input)
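
Note: most of the Python changes in this PR are this same mechanical rewrite, RUF005 (collection-literal-concatenation): concatenation onto a literal becomes in-literal unpacking, skipping the throwaway intermediate list. A quick equivalence check with made-up values:

    condition, monitored_vars = True, (1, 2, 3)

    old = [condition] + list(monitored_vars)  # temporary list, then concatenation
    new = [condition, *monitored_vars]  # unpack straight into the literal
    assert old == new
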
4 changes: 3 additions & 1 deletion pytensor/compile/compilelock.py
@@ -46,7 +46,9 @@ def force_unlock(lock_dir: os.PathLike):
 
 @contextmanager
 def lock_ctx(
-    lock_dir: Union[str, os.PathLike] = None, *, timeout: Optional[float] = None
+    lock_dir: Optional[Union[str, os.PathLike]] = None,
+    *,
+    timeout: Optional[float] = None,
 ):
     """Context manager that wraps around FileLock and SoftFileLock from filelock package.
 
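
Note: besides reflowing the signature, this hunk fixes what Ruff reports as RUF013 (implicit-optional): a parameter that defaults to None must say so in its annotation. A minimal sketch with hypothetical function names:

    import os
    from typing import Optional, Union


    def before(lock_dir: Union[str, os.PathLike] = None):  # implicit Optional: flagged
        ...


    def after(lock_dir: Optional[Union[str, os.PathLike]] = None):  # explicit: clean
        ...
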
4 changes: 2 additions & 2 deletions pytensor/compile/debugmode.py
@@ -892,9 +892,9 @@ def _get_preallocated_maps(
 
         # Use the same step on all dimensions before the last check_ndim.
         if all(s == 1 for s in out_shape[:-check_ndim]):
-            step_signs_list = [(1,)] + step_signs_list
+            step_signs_list = [(1,), *step_signs_list]
         else:
-            step_signs_list = [(-1, 1)] + step_signs_list
+            step_signs_list = [(-1, 1), *step_signs_list]
 
         for step_signs in itertools_product(*step_signs_list):
             for step_size in (1, 2):
2 changes: 1 addition & 1 deletion pytensor/compile/sharedvalue.py
@@ -209,7 +209,7 @@ def shared(value, name=None, strict=False, allow_downcast=None, **kwargs):
         add_tag_trace(var)
         return var
     except MemoryError as e:
-        e.args = e.args + ("Consider using `pytensor.shared(..., borrow=True)`",)
+        e.args = (*e.args, "Consider using `pytensor.shared(..., borrow=True)`")
         raise
 
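
Note: the same unpacking rewrite covers tuples, here in the common pattern of appending context to an exception's args before re-raising. A self-contained sketch of the pattern:

    try:
        try:
            raise MemoryError("allocation failed")
        except MemoryError as e:
            # Append a hint to the exception's args, then re-raise it unchanged.
            e.args = (*e.args, "Consider using `pytensor.shared(..., borrow=True)`")
            raise
    except MemoryError as e:
        assert e.args == (
            "allocation failed",
            "Consider using `pytensor.shared(..., borrow=True)`",
        )
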
9 changes: 6 additions & 3 deletions pytensor/configdefaults.py
@@ -1382,7 +1382,8 @@ def add_caching_dir_configvars():
     "fft_tiling",
     "winograd",
     "winograd_non_fused",
-) + SUPPORTED_DNN_CONV_ALGO_RUNTIME
+    *SUPPORTED_DNN_CONV_ALGO_RUNTIME,
+)
 
 SUPPORTED_DNN_CONV_ALGO_BWD_DATA = (
     "none",
@@ -1391,7 +1392,8 @@ def add_caching_dir_configvars():
     "fft_tiling",
     "winograd",
     "winograd_non_fused",
-) + SUPPORTED_DNN_CONV_ALGO_RUNTIME
+    *SUPPORTED_DNN_CONV_ALGO_RUNTIME,
+)
 
 SUPPORTED_DNN_CONV_ALGO_BWD_FILTER = (
     "none",
@@ -1400,7 +1402,8 @@ def add_caching_dir_configvars():
     "small",
     "winograd_non_fused",
     "fft_tiling",
-) + SUPPORTED_DNN_CONV_ALGO_RUNTIME
+    *SUPPORTED_DNN_CONV_ALGO_RUNTIME,
+)
 
 SUPPORTED_DNN_CONV_PRECISION = (
     "as_input_f32",
2 changes: 1 addition & 1 deletion pytensor/gradient.py
@@ -1972,7 +1972,7 @@ def inner_function(*args):
     jacobs, updates = pytensor.scan(
         inner_function,
         sequences=pytensor.tensor.arange(expression.shape[0]),
-        non_sequences=[expression] + wrt,
+        non_sequences=[expression, *wrt],
     )
     assert not updates, "Scan has returned a list of updates; this should not happen."
     return as_list_or_tuple(using_list, using_tuple, jacobs)
4 changes: 2 additions & 2 deletions pytensor/graph/basic.py
@@ -728,7 +728,7 @@ def __hash__(self):
         return hash((type(self), self.id, self.type))
 
     def __repr__(self):
-        return f"{type(self).__name__}({repr(self.id)}, {repr(self.type)})"
+        return f"{type(self).__name__}({self.id!r}, {self.type!r})"
 
     def signature(self) -> tuple[_TypeType, _IdType]:
         return (self.type, self.id)
@@ -774,7 +774,7 @@ def __repr__(self):
         data_str = repr(self.data)
         if len(data_str) > 20:
             data_str = data_str[:10].strip() + " ... " + data_str[-10:].strip()
-        return f"{type(self).__name__}({repr(self.type)}, data={data_str})"
+        return f"{type(self).__name__}({self.type!r}, data={data_str})"
 
     def clone(self, **kwargs):
         return self
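
Note: these f-string edits are RUF010 (explicit-f-string-type-conversion): inside an f-string, the !r conversion flag replaces an explicit repr() call. Both spellings produce the same string:

    value = "x"
    assert f"{repr(value)}" == f"{value!r}" == "'x'"
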
9 changes: 5 additions & 4 deletions pytensor/graph/destroyhandler.py
@@ -233,10 +233,11 @@ def fast_inplace_check(fgraph, inputs):
 
     """
     Supervisor = pytensor.compile.function.types.Supervisor
-    protected_inputs = [
-        f.protected for f in fgraph._features if isinstance(f, Supervisor)
-    ]
-    protected_inputs = sum(protected_inputs, [])  # flatten the list
+    protected_inputs = list(
+        itertools.chain.from_iterable(
+            f.protected for f in fgraph._features if isinstance(f, Supervisor)
+        )
+    )
     protected_inputs.extend(fgraph.outputs)
 
     inputs = [
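
Note: this one is RUF017 (quadratic-list-summation). sum(lists, []) re-copies the growing accumulator on every addition, so flattening n elements costs O(n^2); itertools.chain.from_iterable makes a single linear pass. A quick check:

    import itertools

    nested = [[1, 2], [3], [4, 5]]
    flat_quadratic = sum(nested, [])  # copies the accumulator at each step
    flat_linear = list(itertools.chain.from_iterable(nested))  # one pass
    assert flat_quadratic == flat_linear == [1, 2, 3, 4, 5]
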
12 changes: 7 additions & 5 deletions pytensor/graph/features.py
@@ -508,11 +508,13 @@ def consistent_(self, fgraph):
 
 
 class ReplaceValidate(History, Validator):
-    pickle_rm_attr = (
-        ["replace_validate", "replace_all_validate", "replace_all_validate_remove"]
-        + History.pickle_rm_attr
-        + Validator.pickle_rm_attr
-    )
+    pickle_rm_attr = [
+        "replace_validate",
+        "replace_all_validate",
+        "replace_all_validate_remove",
+        *History.pickle_rm_attr,
+        *Validator.pickle_rm_attr,
+    ]
 
     def on_attach(self, fgraph):
         for attr in (
7 changes: 4 additions & 3 deletions pytensor/graph/rewriting/basic.py
@@ -405,7 +405,7 @@ def print_profile(cls, stream, prof, level=0):
             else:
                 name = rewrite.name
             idx = rewrites.index(rewrite)
-            ll.append((name, rewrite.__class__.__name__, idx) + nb_n)
+            ll.append((name, rewrite.__class__.__name__, idx, *nb_n))
         lll = sorted(zip(prof, ll), key=lambda a: a[0])
 
         for t, rewrite in lll[::-1]:
@@ -1091,7 +1091,7 @@ def __str__(self):
         return getattr(self, "__name__", repr(self))
 
     def __repr__(self):
-        return f"FromFunctionNodeRewriter({repr(self.fn)}, {repr(self._tracks)}, {repr(self.requirements)})"
+        return f"FromFunctionNodeRewriter({self.fn!r}, {self._tracks!r}, {self.requirements!r})"
 
     def print_summary(self, stream=sys.stdout, level=0, depth=-1):
         print(f"{' ' * level}{self.transform} id={id(self)}", file=stream)
@@ -1138,7 +1138,8 @@ def decorator(f):
         req = requirements
         if inplace:
             dh_handler = dh.DestroyHandler
-            req = tuple(requirements) + (
+            req = (
+                *requirements,
                 lambda fgraph: fgraph.attach_feature(dh_handler()),
             )
         rval = FromFunctionNodeRewriter(f, tracks, req)
2 changes: 1 addition & 1 deletion pytensor/graph/rewriting/db.py
@@ -73,7 +73,7 @@ def register(
 
         if use_db_name_as_tag:
             if self.name is not None:
-                tags = tags + (self.name,)
+                tags = (*tags, self.name)
 
         rewriter.name = name
         # This restriction is there because in many place we suppose that
4 changes: 2 additions & 2 deletions pytensor/graph/rewriting/unify.py
@@ -69,7 +69,7 @@ def __str__(self):
         return f"~{self.token} [{self.constraint}]"
 
     def __repr__(self):
-        return f"{type(self).__name__}({repr(self.constraint)}, {self.token})"
+        return f"{type(self).__name__}({self.constraint!r}, {self.token})"
 
 
 def car_Variable(x):
@@ -283,7 +283,7 @@ def _convert(y):
             var_map[pattern] = v
             return v
         elif isinstance(y, tuple):
-            return etuple(*tuple(_convert(e) for e in y))
+            return etuple(*(_convert(e) for e in y))
         elif isinstance(y, (Number, np.ndarray)):
             from pytensor.tensor import as_tensor_variable
 
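
Note: the second change here drops a redundant materialization: * unpacking accepts any iterable, so wrapping the generator in tuple() first only builds an extra temporary. A sketch with a hypothetical helper:

    def collect(*args):
        return args


    values = (1, 2, 3)
    assert collect(*tuple(v + 1 for v in values)) == collect(*(v + 1 for v in values))
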
2 changes: 1 addition & 1 deletion pytensor/graph/utils.py
@@ -171,7 +171,7 @@ def __init__(self, *args, **kwargs):
         assert list(kwargs.keys()) == ["variable"]
         error_msg = get_variable_trace_string(kwargs["variable"])
         if error_msg:
-            args = args + (error_msg,)
+            args = (*args, error_msg)
         s = "\n".join(args)  # Needed to have the new line print correctly
         super().__init__(s)
 
20 changes: 10 additions & 10 deletions pytensor/ifelse.py
@@ -227,7 +227,7 @@ def make_node(self, condition: "TensorLike", *true_false_branches: Any):
 
         return Apply(
             self,
-            [condition] + new_inputs_true_branch + new_inputs_false_branch,
+            [condition, *new_inputs_true_branch, *new_inputs_false_branch],
             output_vars,
         )
 
@@ -275,11 +275,11 @@ def grad(self, ins, grads):
         # condition + epsilon always triggers the same branch as condition
        condition_grad = condition.zeros_like().astype(config.floatX)
 
-        return (
-            [condition_grad]
-            + if_true_op(*inputs_true_grad, return_list=True)
-            + if_false_op(*inputs_false_grad, return_list=True)
-        )
+        return [
+            condition_grad,
+            *if_true_op(*inputs_true_grad, return_list=True),
+            *if_false_op(*inputs_false_grad, return_list=True),
+        ]
 
     def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
         cond = node.inputs[0]
@@ -397,7 +397,7 @@ def ifelse(
 
     new_ifelse = IfElse(n_outs=len(then_branch), as_view=False, name=name)
 
-    ins = [condition] + list(then_branch) + list(else_branch)
+    ins = [condition, *then_branch, *else_branch]
     rval = new_ifelse(*ins, return_list=True)
 
     if rval_type is None:
@@ -611,7 +611,7 @@ def apply(self, fgraph):
             mn_fs = merging_node.inputs[1:][merging_node.op.n_outs :]
             pl_ts = proposal.inputs[1:][: proposal.op.n_outs]
             pl_fs = proposal.inputs[1:][proposal.op.n_outs :]
-            new_ins = [merging_node.inputs[0]] + mn_ts + pl_ts + mn_fs + pl_fs
+            new_ins = [merging_node.inputs[0], *mn_ts, *pl_ts, *mn_fs, *pl_fs]
             mn_name = "?"
             if merging_node.op.name:
                 mn_name = merging_node.op.name
@@ -673,7 +673,7 @@ def cond_remove_identical(fgraph, node):
 
     new_ifelse = IfElse(n_outs=len(nw_ts), as_view=op.as_view, name=op.name)
 
-    new_ins = [node.inputs[0]] + nw_ts + nw_fs
+    new_ins = [node.inputs[0], *nw_ts, *nw_fs]
     new_outs = new_ifelse(*new_ins, return_list=True)
 
     rval = []
@@ -711,7 +711,7 @@ def cond_merge_random_op(fgraph, main_node):
     mn_fs = merging_node.inputs[1:][merging_node.op.n_outs :]
     pl_ts = proposal.inputs[1:][: proposal.op.n_outs]
     pl_fs = proposal.inputs[1:][proposal.op.n_outs :]
-    new_ins = [merging_node.inputs[0]] + mn_ts + pl_ts + mn_fs + pl_fs
+    new_ins = [merging_node.inputs[0], *mn_ts, *pl_ts, *mn_fs, *pl_fs]
     mn_name = "?"
     if merging_node.op.name:
         mn_name = merging_node.op.name
2 changes: 1 addition & 1 deletion pytensor/link/basic.py
@@ -106,7 +106,7 @@ def __set__(self, value: Any) -> None:
                 self.storage[0] = self.type.filter(value, **kwargs)
 
         except Exception as e:
-            e.args = e.args + (f'Container name "{self.name}"',)
+            e.args = (*e.args, f'Container name "{self.name}"')
             raise
 
     data = property(__get__, __set__)
4 changes: 2 additions & 2 deletions pytensor/link/c/cmodule.py
@@ -1038,7 +1038,7 @@ def unpickle_failure():
                 _logger.info(f"deleting ModuleCache entry {entry}")
                 key_data.delete_keys_from(self.entry_from_key)
                 del self.module_hash_to_key_data[module_hash]
-                if key_data.keys and list(key_data.keys)[0][0]:
+                if key_data.keys and next(iter(key_data.keys))[0]:
                     # this is a versioned entry, so should have been on
                     # disk. Something weird happened to cause this, so we
                     # are responding by printing a warning, removing
@@ -1890,7 +1890,7 @@ def _try_compile_tmp(
             os.close(fd)
             fd = None
             out, err, p_ret = output_subprocess_Popen(
-                [compiler] + args + [path, "-o", exe_path] + flags
+                [compiler, *args, path, "-o", exe_path, *flags]
             )
             if p_ret != 0:
                 compilation_ok = False
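
Note: the next(iter(...)) rewrite is RUF015 (unnecessary-iterable-allocation-for-first-element): list(x)[0] copies the whole container just to read one element, while next(iter(x)) reads only the first. Since dicts preserve insertion order:

    keys = {("1.2", "key-a"): None, ("1.3", "key-b"): None}
    assert list(keys)[0] == next(iter(keys)) == ("1.2", "key-a")
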
6 changes: 3 additions & 3 deletions pytensor/link/c/params_type.py
@@ -293,7 +293,7 @@ def __hash__(self):
             .signature()
             for i in range(self.__params_type__.length)
         )
-        return hash((type(self), self.__params_type__) + self.__signatures__)
+        return hash((type(self), self.__params_type__, *self.__signatures__))
 
     def __eq__(self, other):
         return (
@@ -437,7 +437,7 @@ def __eq__(self, other):
         )
 
     def __hash__(self):
-        return hash((type(self),) + self.fields + self.types)
+        return hash((type(self), *self.fields, *self.types))
 
     def generate_struct_name(self):
         # This method tries to generate an unique name for the current instance.
@@ -807,7 +807,7 @@ def c_support_code(self, **kwargs):
             )
         )
 
-        return sorted(c_support_code_set) + [final_struct_code]
+        return [*sorted(c_support_code_set), final_struct_code]
 
     def c_code_cache_version(self):
         return ((3,), tuple(t.c_code_cache_version() for t in self.types))