Skip to content

Commit dd5cc64

Browse files
committed
Add code to check for hardcoded fastmath flags
1 parent 2369e33 commit dd5cc64

File tree

1 file changed

+104
-21
lines changed

1 file changed

+104
-21
lines changed

fastmath.py

Lines changed: 104 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,21 @@
66
import pathlib
77

88

9-
def get_njit_funcs(pkg_dir):
9+
def get_func_nodes(pkg_dir):
1010
"""
11-
Identify all njit functions
11+
Retrun a dictionary where the keys are the module names and the values are
12+
the function AST nodes
1213
1314
Parameters
1415
----------
1516
pkg_dir : str
16-
The path to the directory containing some .py files
17+
The path to the directory containing some .py files
1718
1819
Returns
1920
-------
20-
njit_funcs : list
21-
A list of all njit functions, where each element is a tuple of the form
22-
(module_name, func_name)
21+
out : dict
22+
A dictionary where the keys are the module names and the values are a list of
23+
AST nodes for each njit function in the module
2324
"""
2425
ignore_py_files = ["__init__", "__pycache__"]
2526
pkg_dir = pathlib.Path(pkg_dir)
@@ -29,29 +30,56 @@ def get_njit_funcs(pkg_dir):
2930
if fname.stem not in ignore_py_files and not fname.stem.startswith("."):
3031
module_names.append(fname.stem)
3132

32-
njit_funcs = []
33+
out = {}
3334
for module_name in module_names:
3435
filepath = pkg_dir / f"{module_name}.py"
3536
file_contents = ""
3637
with open(filepath, encoding="utf8") as f:
3738
file_contents = f.read()
3839
module = ast.parse(file_contents)
40+
41+
module_funcs_nodes = []
3942
for node in module.body:
4043
if isinstance(node, ast.FunctionDef):
41-
func_name = node.name
42-
for decorator in node.decorator_list:
43-
decorator_name = None
44-
if isinstance(decorator, ast.Name):
45-
# Bare decorator
46-
decorator_name = decorator.id
47-
if isinstance(decorator, ast.Call) and isinstance(
48-
decorator.func, ast.Name
49-
):
50-
# Decorator is a function
51-
decorator_name = decorator.func.id
52-
53-
if decorator_name == "njit":
54-
njit_funcs.append((module_name, func_name))
44+
module_funcs_nodes.append(node)
45+
out[module_name] = module_funcs_nodes
46+
47+
return out
48+
49+
50+
def get_njit_funcs(pkg_dir):
51+
"""
52+
Identify all njit functions
53+
54+
Parameters
55+
----------
56+
pkg_dir : str
57+
The path to the directory containing some .py files
58+
59+
Returns
60+
-------
61+
njit_funcs : list
62+
A list of all njit functions, where each element is a tuple of the form
63+
(module_name, func_name)
64+
"""
65+
njit_funcs = []
66+
modules_funcs_nodes = get_func_nodes(pkg_dir)
67+
for module_name, func_nodes in modules_funcs_nodes.items():
68+
for node in func_nodes:
69+
func_name = node.name
70+
for decorator in node.decorator_list:
71+
decorator_name = None
72+
if isinstance(decorator, ast.Name):
73+
# Bare decorator
74+
decorator_name = decorator.id
75+
if isinstance(decorator, ast.Call) and isinstance(
76+
decorator.func, ast.Name
77+
):
78+
# Decorator is a function
79+
decorator_name = decorator.func.id
80+
81+
if decorator_name == "njit":
82+
njit_funcs.append((module_name, func_name))
5583

5684
return njit_funcs
5785

@@ -89,6 +117,60 @@ def check_fastmath(pkg_dir, pkg_name):
89117
return
90118

91119

120+
def check_hardcoded_flag(pkg_dir, pkg_name):
121+
"""
122+
Check if all `fastmath` flags are set to a config variable
123+
124+
Parameters
125+
----------
126+
pkg_dir : str
127+
The path to the directory containing some .py files
128+
129+
pkg_name : str
130+
The name of the package
131+
132+
Returns
133+
-------
134+
None
135+
"""
136+
ignore = [("fastmath", "_add_assoc")]
137+
138+
hardcoded_fastmath = [] # list of njit functions with hardcoded fastmath flags
139+
modules_funcs_nodes = get_func_nodes(pkg_dir)
140+
for module_name, func_nodes in modules_funcs_nodes.items():
141+
for node in func_nodes:
142+
if (module_name, node.name) in ignore:
143+
continue
144+
145+
njit_decorator_func = None
146+
for decorator in node.decorator_list:
147+
if (
148+
isinstance(decorator, ast.Call)
149+
and isinstance(decorator.func, ast.Name)
150+
and decorator.func.id == "njit"
151+
):
152+
njit_decorator_func = decorator
153+
break
154+
155+
if njit_decorator_func is None:
156+
continue
157+
158+
for kwrd in njit_decorator_func.keywords:
159+
if kwrd.arg == "fastmath":
160+
value = kwrd.value.value
161+
if not hasattr(value, "id") or value.id != "config":
162+
hardcoded_fastmath.append(f"{module_name}.{node.name}")
163+
164+
if len(hardcoded_fastmath) > 0:
165+
msg = (
166+
"Found one or more `@njit()` functions with hardcoded `fastmath` flag. "
167+
+ f"The functions are:\n {hardcoded_fastmath}\n"
168+
)
169+
raise ValueError(msg)
170+
171+
return
172+
173+
92174
if __name__ == "__main__":
93175
parser = argparse.ArgumentParser()
94176
parser.add_argument("--check", dest="pkg_dir")
@@ -98,3 +180,4 @@ def check_fastmath(pkg_dir, pkg_name):
98180
pkg_dir = pathlib.Path(args.pkg_dir)
99181
pkg_name = pkg_dir.name
100182
check_fastmath(str(pkg_dir), pkg_name)
183+
check_hardcoded_flag(str(pkg_dir), pkg_name)

0 commit comments

Comments
 (0)