6
6
import pathlib
7
7
8
8
9
- def get_njit_funcs (pkg_dir ):
9
+ def get_func_nodes (pkg_dir ):
10
10
"""
11
- Identify all njit functions
11
+ Retrun a dictionary where the keys are the module names and the values are
12
+ the function AST nodes
12
13
13
14
Parameters
14
15
----------
15
16
pkg_dir : str
16
- The path to the directory containing some .py files
17
+ The path to the directory containing some .py files
17
18
18
19
Returns
19
20
-------
20
- njit_funcs : list
21
- A list of all njit functions, where each element is a tuple of the form
22
- (module_name, func_name)
21
+ out : dict
22
+ A dictionary where the keys are the module names and the values are a list of
23
+ AST nodes for each njit function in the module
23
24
"""
24
25
ignore_py_files = ["__init__" , "__pycache__" ]
25
26
pkg_dir = pathlib .Path (pkg_dir )
@@ -29,29 +30,56 @@ def get_njit_funcs(pkg_dir):
29
30
if fname .stem not in ignore_py_files and not fname .stem .startswith ("." ):
30
31
module_names .append (fname .stem )
31
32
32
- njit_funcs = []
33
+ out = {}
33
34
for module_name in module_names :
34
35
filepath = pkg_dir / f"{ module_name } .py"
35
36
file_contents = ""
36
37
with open (filepath , encoding = "utf8" ) as f :
37
38
file_contents = f .read ()
38
39
module = ast .parse (file_contents )
40
+
41
+ module_funcs_nodes = []
39
42
for node in module .body :
40
43
if isinstance (node , ast .FunctionDef ):
41
- func_name = node .name
42
- for decorator in node .decorator_list :
43
- decorator_name = None
44
- if isinstance (decorator , ast .Name ):
45
- # Bare decorator
46
- decorator_name = decorator .id
47
- if isinstance (decorator , ast .Call ) and isinstance (
48
- decorator .func , ast .Name
49
- ):
50
- # Decorator is a function
51
- decorator_name = decorator .func .id
52
-
53
- if decorator_name == "njit" :
54
- njit_funcs .append ((module_name , func_name ))
44
+ module_funcs_nodes .append (node )
45
+ out [module_name ] = module_funcs_nodes
46
+
47
+ return out
48
+
49
+
50
+ def get_njit_funcs (pkg_dir ):
51
+ """
52
+ Identify all njit functions
53
+
54
+ Parameters
55
+ ----------
56
+ pkg_dir : str
57
+ The path to the directory containing some .py files
58
+
59
+ Returns
60
+ -------
61
+ njit_funcs : list
62
+ A list of all njit functions, where each element is a tuple of the form
63
+ (module_name, func_name)
64
+ """
65
+ njit_funcs = []
66
+ modules_funcs_nodes = get_func_nodes (pkg_dir )
67
+ for module_name , func_nodes in modules_funcs_nodes .items ():
68
+ for node in func_nodes :
69
+ func_name = node .name
70
+ for decorator in node .decorator_list :
71
+ decorator_name = None
72
+ if isinstance (decorator , ast .Name ):
73
+ # Bare decorator
74
+ decorator_name = decorator .id
75
+ if isinstance (decorator , ast .Call ) and isinstance (
76
+ decorator .func , ast .Name
77
+ ):
78
+ # Decorator is a function
79
+ decorator_name = decorator .func .id
80
+
81
+ if decorator_name == "njit" :
82
+ njit_funcs .append ((module_name , func_name ))
55
83
56
84
return njit_funcs
57
85
@@ -89,6 +117,60 @@ def check_fastmath(pkg_dir, pkg_name):
89
117
return
90
118
91
119
120
+ def check_hardcoded_flag (pkg_dir , pkg_name ):
121
+ """
122
+ Check if all `fastmath` flags are set to a config variable
123
+
124
+ Parameters
125
+ ----------
126
+ pkg_dir : str
127
+ The path to the directory containing some .py files
128
+
129
+ pkg_name : str
130
+ The name of the package
131
+
132
+ Returns
133
+ -------
134
+ None
135
+ """
136
+ ignore = [("fastmath" , "_add_assoc" )]
137
+
138
+ hardcoded_fastmath = [] # list of njit functions with hardcoded fastmath flags
139
+ modules_funcs_nodes = get_func_nodes (pkg_dir )
140
+ for module_name , func_nodes in modules_funcs_nodes .items ():
141
+ for node in func_nodes :
142
+ if (module_name , node .name ) in ignore :
143
+ continue
144
+
145
+ njit_decorator_func = None
146
+ for decorator in node .decorator_list :
147
+ if (
148
+ isinstance (decorator , ast .Call )
149
+ and isinstance (decorator .func , ast .Name )
150
+ and decorator .func .id == "njit"
151
+ ):
152
+ njit_decorator_func = decorator
153
+ break
154
+
155
+ if njit_decorator_func is None :
156
+ continue
157
+
158
+ for kwrd in njit_decorator_func .keywords :
159
+ if kwrd .arg == "fastmath" :
160
+ value = kwrd .value .value
161
+ if not hasattr (value , "id" ) or value .id != "config" :
162
+ hardcoded_fastmath .append (f"{ module_name } .{ node .name } " )
163
+
164
+ if len (hardcoded_fastmath ) > 0 :
165
+ msg = (
166
+ "Found one or more `@njit()` functions with hardcoded `fastmath` flag. "
167
+ + f"The functions are:\n { hardcoded_fastmath } \n "
168
+ )
169
+ raise ValueError (msg )
170
+
171
+ return
172
+
173
+
92
174
if __name__ == "__main__" :
93
175
parser = argparse .ArgumentParser ()
94
176
parser .add_argument ("--check" , dest = "pkg_dir" )
@@ -98,3 +180,4 @@ def check_fastmath(pkg_dir, pkg_name):
98
180
pkg_dir = pathlib .Path (args .pkg_dir )
99
181
pkg_name = pkg_dir .name
100
182
check_fastmath (str (pkg_dir ), pkg_name )
183
+ check_hardcoded_flag (str (pkg_dir ), pkg_name )
0 commit comments