Skip to content

Slight cleanup of the ruff configuration #616

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,6 @@ omit = [
"pytensor/graph/unify.py",
"pytensor/link/jax/jax_linker.py",
"pytensor/link/jax/jax_dispatch.py",
"pytensor/graph/toolbox.py",
"pytensor/scalar/basic_scipy.py",
]
branch = true
relative_files = true
Expand Down Expand Up @@ -132,7 +130,7 @@ disable = ["C0330", "C0326"]

[tool.ruff]
select = ["C", "E", "F", "I", "UP", "W"]
ignore = ["C408", "C901", "E501", "E741", "UP031"]
ignore = ["C408", "C901", "E501", "E741"]
exclude = ["doc/", "pytensor/_version.py", "bin/pytensor_cache.py"]


Expand All @@ -142,9 +140,7 @@ lines-after-imports = 2
[tool.ruff.per-file-ignores]
# TODO: Get rid of these:
"**/__init__.py" = ["F401", "E402", "F403"]
"pytensor/tensor/linalg.py" = ["F401", "F403"]
"pytensor/scalar/basic_scipy.py" = ["E402"]
"pytensor/graph/toolbox.py" = ["E402"]
"pytensor/tensor/linalg.py" = ["F403"]
# For the tests we skip because `pytest.importorskip` is used:
"tests/link/jax/test_scalar.py" = ["E402"]
"tests/link/jax/test_tensor_basic.py" = ["E402"]
Expand Down
2 changes: 1 addition & 1 deletion pytensor/compile/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ def __hash__(self):
def __str__(self):
name = self.__class__.__name__ if self.name is None else self.name
is_inline = self.is_inline
return "%(name)s{inline=%(is_inline)s}" % locals()
return "{name}{{inline={is_inline}}}".format(**locals())

@config.change_flags(compute_test_value="off")
def _recompute_lop_op(self):
Expand Down
2 changes: 1 addition & 1 deletion pytensor/graph/destroyhandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def fast_inplace_check(fgraph, inputs):
return inputs


class DestroyHandler(Bookkeeper): # noqa
class DestroyHandler(Bookkeeper):
"""
The DestroyHandler class detects when a graph is impossible to evaluate
because of aliasing and destructive operations.
Expand Down
120 changes: 57 additions & 63 deletions pytensor/link/c/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,50 +250,46 @@ def struct_gen(args, struct_builders, blocks, sub):
# that holds the type, the value and the traceback. After storing
# the error, we return the failure code so we know which code
# block failed.
do_return = (
"""
if (%(failure_var)s) {
do_return = """
if ({failure_var}) {{
// When there is a failure, this code puts the exception
// in __ERROR.
PyObject* err_type = NULL;
PyObject* err_msg = NULL;
PyObject* err_traceback = NULL;
PyErr_Fetch(&err_type, &err_msg, &err_traceback);
if (!err_type) {err_type = Py_None;Py_INCREF(Py_None);}
if (!err_msg) {err_msg = Py_None; Py_INCREF(Py_None);}
if (!err_traceback) {err_traceback = Py_None; Py_INCREF(Py_None);}
if (!err_type) {{err_type = Py_None;Py_INCREF(Py_None);}}
if (!err_msg) {{err_msg = Py_None; Py_INCREF(Py_None);}}
if (!err_traceback) {{err_traceback = Py_None; Py_INCREF(Py_None);}}
PyObject* old_err_type = PyList_GET_ITEM(__ERROR, 0);
PyObject* old_err_msg = PyList_GET_ITEM(__ERROR, 1);
PyObject* old_err_traceback = PyList_GET_ITEM(__ERROR, 2);
PyList_SET_ITEM(__ERROR, 0, err_type);
PyList_SET_ITEM(__ERROR, 1, err_msg);
PyList_SET_ITEM(__ERROR, 2, err_traceback);
{Py_XDECREF(old_err_type);}
{Py_XDECREF(old_err_msg);}
{Py_XDECREF(old_err_traceback);}
}
{{Py_XDECREF(old_err_type);}}
{{Py_XDECREF(old_err_msg);}}
{{Py_XDECREF(old_err_traceback);}}
}}
// The failure code is returned to index what code block failed.
return %(failure_var)s;
"""
% sub
)
return {failure_var};
""".format(**sub)

sub = dict(sub)
sub.update(locals())

# TODO: add some error checking to make sure storage_<x> are
# 1-element lists and __ERROR is a 3-elements list.

struct_code = (
"""
namespace {
struct %(name)s {
struct_code = """
namespace {{
struct {name} {{
PyObject* __ERROR;

%(storage_decl)s
%(struct_decl)s
{storage_decl}
{struct_decl}

%(name)s() {
{name}() {{
// This is only somewhat safe because we:
// 1) Are not a virtual class
// 2) Do not use any virtual classes in the members
Expand All @@ -306,32 +302,30 @@ def struct_gen(args, struct_builders, blocks, sub):
#ifndef PYTENSOR_DONT_MEMSET_STRUCT
memset(this, 0, sizeof(*this));
#endif
}
~%(name)s(void) {
}}
~{name}(void) {{
cleanup();
}
}}

int init(PyObject* __ERROR, %(args_decl)s) {
%(storage_incref)s
%(storage_set)s
%(struct_init_head)s
int init(PyObject* __ERROR, {args_decl}) {{
{storage_incref}
{storage_set}
{struct_init_head}
this->__ERROR = __ERROR;
return 0;
}
void cleanup(void) {
%(struct_cleanup)s
%(storage_decref)s
}
int run(void) {
int %(failure_var)s = 0;
%(behavior)s
%(do_return)s
}
};
}
"""
% sub
)
}}
void cleanup(void) {{
{struct_cleanup}
{storage_decref}
}}
int run(void) {{
int {failure_var} = 0;
{behavior}
{do_return}
}}
}};
}}
""".format(**sub)

return struct_code

Expand Down Expand Up @@ -380,9 +374,9 @@ def get_c_init(fgraph, r, name, sub):
pre = (
""
"""
py_%(name)s = Py_None;
{Py_XINCREF(py_%(name)s);}
""" % locals()
py_{name} = Py_None;
{{Py_XINCREF(py_{name});}}
""".format(**locals())
)
return pre + r.type.c_init(name, sub)

Expand Down Expand Up @@ -418,9 +412,9 @@ def get_c_extract(fgraph, r, name, sub):
c_extract = r.type.c_extract(name, sub, False)

pre = """
py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0);
{Py_XINCREF(py_%(name)s);}
""" % locals()
py_{name} = PyList_GET_ITEM(storage_{name}, 0);
{{Py_XINCREF(py_{name});}}
""".format(**locals())
return pre + c_extract


Expand All @@ -447,9 +441,9 @@ def get_c_extract_out(fgraph, r, name, sub):
c_extract = r.type.c_extract_out(name, sub, check_input, check_broadcast=False)

pre = """
py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0);
{Py_XINCREF(py_%(name)s);}
""" % locals()
py_{name} = PyList_GET_ITEM(storage_{name}, 0);
{{Py_XINCREF(py_{name});}}
""".format(**locals())
return pre + c_extract


Expand All @@ -459,8 +453,8 @@ def get_c_cleanup(fgraph, r, name, sub):

"""
post = """
{Py_XDECREF(py_%(name)s);}
""" % locals()
{{Py_XDECREF(py_{name});}}
""".format(**locals())
return r.type.c_cleanup(name, sub) + post


Expand All @@ -470,14 +464,14 @@ def get_c_sync(fgraph, r, name, sub):

"""
return """
if (!%(failure_var)s) {
%(sync)s
PyObject* old = PyList_GET_ITEM(storage_%(name)s, 0);
{Py_XINCREF(py_%(name)s);}
PyList_SET_ITEM(storage_%(name)s, 0, py_%(name)s);
{Py_XDECREF(old);}
}
""" % dict(sync=r.type.c_sync(name, sub), name=name, **sub)
if (!{failure_var}) {{
{sync}
PyObject* old = PyList_GET_ITEM(storage_{name}, 0);
{{Py_XINCREF(py_{name});}}
PyList_SET_ITEM(storage_{name}, 0, py_{name});
{{Py_XDECREF(old);}}
}}
""".format(**dict(sync=r.type.c_sync(name, sub), name=name, **sub))


def apply_policy(fgraph, policy, r, name, sub):
Expand Down Expand Up @@ -1724,7 +1718,7 @@ class _CThunk:

def __init__(self, cthunk, init_tasks, tasks, error_storage, module):
# Lazy import to avoid compilation when importing pytensor.
from pytensor.link.c.cutils import run_cthunk # noqa
from pytensor.link.c.cutils import run_cthunk

self.run_cthunk = run_cthunk
self.cthunk = cthunk
Expand Down
11 changes: 5 additions & 6 deletions pytensor/link/c/cmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1950,14 +1950,13 @@ def _try_flags(

code = (
"""
%(preamble)s
{preamble}
int main(int argc, char** argv)
{
%(body)s
{{
{body}
return 0;
}
"""
% locals()
}}
""".format(**locals())
).encode()
return cls._try_compile_tmp(
code,
Expand Down
24 changes: 13 additions & 11 deletions pytensor/link/c/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,18 +558,20 @@ def c_extract_out(

"""
return """
if (py_%(name)s == Py_None)
{
%(c_init_code)s
}
if (py_{name} == Py_None)
{{
{c_init_code}
}}
else
{
%(c_extract_code)s
}
""" % dict(
name=name,
c_init_code=self.c_init(name, sub),
c_extract_code=self.c_extract(name, sub, check_input),
{{
{c_extract_code}
}}
""".format(
**dict(
name=name,
c_init_code=self.c_init(name, sub),
c_extract_code=self.c_extract(name, sub, check_input),
)
)

def c_cleanup(self, name: str, sub: dict[str, str]) -> str:
Expand Down
4 changes: 2 additions & 2 deletions pytensor/link/c/lazylinker_c.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
def try_import():
global lazylinker_ext
sys.path[0:0] = [config.compiledir]
import lazylinker_ext # noqa
import lazylinker_ext

del sys.path[0]

Expand Down Expand Up @@ -167,4 +167,4 @@ def try_reload():
from lazylinker_ext.lazylinker_ext import CLazyLinker, get_version # noqa
from lazylinker_ext.lazylinker_ext import * # noqa

assert force_compile or (version == get_version()) # noqa
assert force_compile or (version == get_version())
30 changes: 16 additions & 14 deletions pytensor/link/c/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,20 +596,22 @@ def c_code(self, node, name, inp, out, sub):

# Generate the C code
return """
%(define_macros)s
{
if (%(func_name)s(%(func_args)s%(params)s) != 0) {
%(fail)s
}
}
%(undef_macros)s
""" % dict(
func_name=self.func_name,
fail=sub["fail"],
params=params,
func_args=self.format_c_function_args(inp, out),
define_macros=define_macros,
undef_macros=undef_macros,
{define_macros}
{{
if ({func_name}({func_args}{params}) != 0) {{
{fail}
}}
}}
{undef_macros}
""".format(
**dict(
func_name=self.func_name,
fail=sub["fail"],
params=params,
func_args=self.format_c_function_args(inp, out),
define_macros=define_macros,
undef_macros=undef_macros,
)
)
else:
if "code" in self.code_sections:
Expand Down
Loading