Skip to content

Commit c707ad9

Browse files
Fixed #708 Check fastmath flags (#1068)
* empty commit * fix fastmath for aamp._compute_diagonal * fix fastmath for aamp._aamp * fix fastmath for core._calculate_squared_distance_profile * fix fastmath for core.calculate_distance_profile * fix fastmath for core._apply_exclusion_zone * fix fastmath for mstump._compute_multi_D * fix fastmath for scraamp._compute_PI * fix fastmath for scraamp._prescraamp * fix fastmath for scrump._compute_PI * fix fastmath for scrump._prescrump * fix fastmath for stump._compute_diagonal * fix fastmath for stump._stump * temp commit * fix fastmath for maamp._compute_multi_p_norm * Add note to docstring for case p=np.inf * deleted wrong file * Add check for fastmath flags of callstacks * minor changes * minor changes and fixes * fix black and flake8 * minor changes * fixed typo and add comment * minor changes
1 parent d5e7607 commit c707ad9

File tree

8 files changed

+403
-13
lines changed

8 files changed

+403
-13
lines changed

fastmath.py

Lines changed: 386 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,391 @@ def check_fastmath(pkg_dir, pkg_name):
8989
return
9090

9191

92+
class FunctionCallVisitor(ast.NodeVisitor):
93+
"""
94+
A class to traverse the AST of the modules of a package to collect
95+
the call stacks of njit functions.
96+
97+
Parameters
98+
----------
99+
pkg_dir : str
100+
The path to the package directory containing some .py files.
101+
102+
pkg_name : str
103+
The name of the package.
104+
105+
Attributes
106+
----------
107+
module_names : list
108+
A list of module names to track the modules as the visitor traverses them.
109+
110+
call_stack : list
111+
A list of njit functions, representing a chain of function calls,
112+
where each element is a string of the form "module_name.func_name".
113+
114+
out : list
115+
A list of unique `call_stack`s.
116+
117+
njit_funcs : list
118+
A list of all njit functions in `pkg_dir`'s modules. Each element is a tuple
119+
of the form `(module_name, func_name)`.
120+
121+
njit_modules : set
122+
A set that contains the names of all modules, each of which contains at least
123+
one njit function.
124+
125+
njit_nodes : dict
126+
A dictionary mapping njit function names to their corresponding AST nodes.
127+
A key is a string, and it is of the form "module_name.func_name", and its
128+
corresponding value is the AST node- with type ast.FunctionDef- of that
129+
function.
130+
131+
ast_modules : dict
132+
A dictionary mapping module names to their corresponding AST objects. A key
133+
is the name of a module, and its corresponding value is the content of that
134+
module as an AST object.
135+
136+
Methods
137+
-------
138+
push_module(module_name)
139+
Push the name of a module onto the stack `module_names`.
140+
141+
pop_module()
142+
Pop the last module name from the stack `module_names`.
143+
144+
push_call_stack(module_name, func_name)
145+
Push a function call onto the stack of function calls, `call_stack`.
146+
147+
pop_call_stack()
148+
Pop the last function call from the stack of function calls, `call_stack`
149+
150+
goto_deeper_func(node)
151+
Calls the visit method from class `ast.NodeVisitor` on all children of
152+
the `node`.
153+
154+
goto_next_func(node)
155+
Calls the visit method from class `ast.NodeVisitor` on all children of
156+
the `node`.
157+
158+
push_out()
159+
Push the current function call stack, `call_stack`, onto the output list, `out`,
160+
unless it is already included in one of the so-far-collected call stacks.
161+
162+
visit_Call(node)
163+
This method is called when the visitor encounters a function call in the AST. It
164+
checks if the called function is a njit function and, if so, traverses its AST
165+
to collect its call stack.
166+
"""
167+
168+
def __init__(self, pkg_dir, pkg_name):
169+
"""
170+
Initialize the FunctionCallVisitor class. This method sets up the necessary
171+
attributes and prepares the visitor for traversing the AST of STUMPY's modules.
172+
173+
Parameters
174+
----------
175+
pkg_dir : str
176+
The path to the package directory containing some .py files.
177+
178+
pkg_name : str
179+
The name of the package.
180+
181+
Returns
182+
-------
183+
None
184+
"""
185+
super().__init__()
186+
self.module_names = []
187+
self.call_stack = []
188+
self.out = []
189+
190+
# Setup lists, dicts, and ast objects
191+
self.njit_funcs = get_njit_funcs(pkg_dir)
192+
self.njit_modules = set(mod_name for mod_name, func_name in self.njit_funcs)
193+
self.njit_nodes = {}
194+
self.ast_modules = {}
195+
196+
filepaths = sorted(f for f in pathlib.Path(pkg_dir).iterdir() if f.is_file())
197+
ignore = ["__init__.py", "__pycache__"]
198+
199+
for filepath in filepaths:
200+
file_name = filepath.name
201+
if (
202+
file_name not in ignore
203+
and not file_name.startswith("gpu")
204+
and str(filepath).endswith(".py")
205+
):
206+
module_name = file_name.replace(".py", "")
207+
file_contents = ""
208+
with open(filepath, encoding="utf8") as f:
209+
file_contents = f.read()
210+
self.ast_modules[module_name] = ast.parse(file_contents)
211+
212+
for node in self.ast_modules[module_name].body:
213+
if isinstance(node, ast.FunctionDef):
214+
func_name = node.name
215+
if (module_name, func_name) in self.njit_funcs:
216+
self.njit_nodes[f"{module_name}.{func_name}"] = node
217+
218+
def push_module(self, module_name):
219+
"""
220+
Push a module name onto the stack of module names.
221+
222+
Parameters
223+
----------
224+
module_name : str
225+
The name of the module to be pushed onto the stack.
226+
227+
Returns
228+
-------
229+
None
230+
"""
231+
self.module_names.append(module_name)
232+
233+
return
234+
235+
def pop_module(self):
236+
"""
237+
Pop the last module name from the stack of module names.
238+
239+
Parameters
240+
----------
241+
None
242+
243+
Returns
244+
-------
245+
None
246+
"""
247+
if self.module_names:
248+
self.module_names.pop()
249+
250+
return
251+
252+
def push_call_stack(self, module_name, func_name):
253+
"""
254+
Push a function call onto the stack of function calls.
255+
256+
Parameters
257+
----------
258+
module_name : str
259+
A module's name
260+
261+
func_name : str
262+
A function's name
263+
264+
Returns
265+
-------
266+
None
267+
"""
268+
self.call_stack.append(f"{module_name}.{func_name}")
269+
270+
return
271+
272+
def pop_call_stack(self):
273+
"""
274+
Pop the last function call from the stack of function calls.
275+
276+
Parameters
277+
----------
278+
None
279+
280+
Returns
281+
-------
282+
None
283+
"""
284+
if self.call_stack:
285+
self.call_stack.pop()
286+
287+
return
288+
289+
def goto_deeper_func(self, node):
290+
"""
291+
Calls the visit method from class `ast.NodeVisitor` on
292+
all children of the `node`.
293+
294+
Parameters
295+
----------
296+
node : ast.AST
297+
The AST node to be visited.
298+
299+
Returns
300+
-------
301+
None
302+
"""
303+
self.generic_visit(node)
304+
305+
return
306+
307+
def goto_next_func(self, node):
308+
"""
309+
Calls the visit method from class `ast.NodeVisitor` on
310+
all children of the node.
311+
312+
Parameters
313+
----------
314+
node : ast.AST
315+
The AST node to be visited.
316+
317+
Returns
318+
-------
319+
None
320+
"""
321+
self.generic_visit(node)
322+
323+
return
324+
325+
def push_out(self):
326+
"""
327+
Push the current function call stack onto the output list unless it
328+
is already included in one of the so-far-collected call stacks.
329+
330+
331+
Parameters
332+
----------
333+
None
334+
335+
Returns
336+
-------
337+
None
338+
"""
339+
unique = True
340+
for cs in self.out:
341+
if " ".join(self.call_stack) in " ".join(cs):
342+
unique = False
343+
break
344+
345+
if unique:
346+
self.out.append(self.call_stack.copy())
347+
348+
return
349+
350+
def visit_Call(self, node):
351+
"""
352+
Called when visiting an AST node of type `ast.Call`.
353+
354+
Parameters
355+
----------
356+
node : ast.Call
357+
The AST node representing a function call.
358+
359+
Returns
360+
-------
361+
None
362+
"""
363+
callee_name = ast.unparse(node.func)
364+
365+
module_changed = False
366+
if "." in callee_name:
367+
new_module_name, new_func_name = callee_name.split(".")[:2]
368+
369+
if new_module_name in self.njit_modules:
370+
self.push_module(new_module_name)
371+
module_changed = True
372+
else:
373+
if self.module_names:
374+
new_module_name = self.module_names[-1]
375+
new_func_name = callee_name
376+
callee_name = f"{new_module_name}.{new_func_name}"
377+
378+
if callee_name in self.njit_nodes.keys():
379+
callee_node = self.njit_nodes[callee_name]
380+
self.push_call_stack(new_module_name, new_func_name)
381+
self.goto_deeper_func(callee_node)
382+
self.push_out()
383+
self.pop_call_stack()
384+
if module_changed:
385+
self.pop_module()
386+
387+
self.goto_next_func(node)
388+
389+
return
390+
391+
392+
def get_njit_call_stacks(pkg_dir, pkg_name):
393+
"""
394+
Get the call stacks of all njit functions in `pkg_dir`
395+
396+
Parameters
397+
----------
398+
pkg_dir : str
399+
The path to the package directory containing some .py files
400+
401+
pkg_name : str
402+
The name of the package
403+
404+
Returns
405+
-------
406+
out : list
407+
A list of unique function call stacks. Each item is of type list,
408+
representing a chain of function calls.
409+
"""
410+
visitor = FunctionCallVisitor(pkg_dir, pkg_name)
411+
412+
for module_name in visitor.njit_modules:
413+
visitor.push_module(module_name)
414+
415+
for node in visitor.ast_modules[module_name].body:
416+
if isinstance(node, ast.FunctionDef):
417+
func_name = node.name
418+
if (module_name, func_name) in visitor.njit_funcs:
419+
visitor.push_call_stack(module_name, func_name)
420+
visitor.visit(node)
421+
visitor.pop_call_stack()
422+
423+
visitor.pop_module()
424+
425+
return visitor.out
426+
427+
428+
def check_call_stack_fastmath(pkg_dir, pkg_name):
429+
"""
430+
Check if all njit functions in a call stack have the same `fastmath` flag.
431+
This function raises a ValueError if it finds any inconsistencies in the
432+
`fastmath` flags in at lease one call stack of njit functions.
433+
434+
Parameters
435+
----------
436+
pkg_dir : str
437+
The path to the directory containing some .py files
438+
439+
pkg_name : str
440+
The name of the package
441+
442+
Returns
443+
-------
444+
None
445+
"""
446+
# List of call stacks with inconsistent fastmath flags
447+
inconsistent_call_stacks = []
448+
449+
njit_call_stacks = get_njit_call_stacks(pkg_dir, pkg_name)
450+
for cs in njit_call_stacks:
451+
# Set the fastmath flag of the first function in the call stack
452+
# as the reference flag
453+
module_name, func_name = cs[0].split(".")
454+
module = importlib.import_module(f".{module_name}", package="stumpy")
455+
func = getattr(module, func_name)
456+
flag_ref = func.targetoptions["fastmath"]
457+
458+
for item in cs[1:]:
459+
module_name, func_name = cs[0].split(".")
460+
module = importlib.import_module(f".{module_name}", package="stumpy")
461+
func = getattr(module, func_name)
462+
flag = func.targetoptions["fastmath"]
463+
if flag != flag_ref:
464+
inconsistent_call_stacks.append(cs)
465+
break
466+
467+
if len(inconsistent_call_stacks) > 0:
468+
msg = (
469+
"Found at least one call stack that has inconsistent `fastmath` flags. "
470+
+ f"Those call stacks are:\n {inconsistent_call_stacks}\n"
471+
)
472+
raise ValueError(msg)
473+
474+
return
475+
476+
92477
if __name__ == "__main__":
93478
parser = argparse.ArgumentParser()
94479
parser.add_argument("--check", dest="pkg_dir")
@@ -98,3 +483,4 @@ def check_fastmath(pkg_dir, pkg_name):
98483
pkg_dir = pathlib.Path(args.pkg_dir)
99484
pkg_name = pkg_dir.name
100485
check_fastmath(str(pkg_dir), pkg_name)
486+
check_call_stack_fastmath(str(pkg_dir), pkg_name)

0 commit comments

Comments
 (0)