Skip to content

Commit

Permalink
Add code to check for hardcoded fastmath flags
Browse files Browse the repository at this point in the history
  • Loading branch information
NimaSarajpoor committed Jan 22, 2025
1 parent 2369e33 commit dd5cc64
Showing 1 changed file with 104 additions and 21 deletions.
125 changes: 104 additions & 21 deletions fastmath.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,21 @@
import pathlib


def get_njit_funcs(pkg_dir):
def get_func_nodes(pkg_dir):
"""
Identify all njit functions
Retrun a dictionary where the keys are the module names and the values are
the function AST nodes
Parameters
----------
pkg_dir : str
The path to the directory containing some .py files
The path to the directory containing some .py files
Returns
-------
njit_funcs : list
A list of all njit functions, where each element is a tuple of the form
(module_name, func_name)
out : dict
A dictionary where the keys are the module names and the values are a list of
AST nodes for each njit function in the module
"""
ignore_py_files = ["__init__", "__pycache__"]
pkg_dir = pathlib.Path(pkg_dir)
Expand All @@ -29,29 +30,56 @@ def get_njit_funcs(pkg_dir):
if fname.stem not in ignore_py_files and not fname.stem.startswith("."):
module_names.append(fname.stem)

njit_funcs = []
out = {}
for module_name in module_names:
filepath = pkg_dir / f"{module_name}.py"
file_contents = ""
with open(filepath, encoding="utf8") as f:
file_contents = f.read()
module = ast.parse(file_contents)

module_funcs_nodes = []
for node in module.body:
if isinstance(node, ast.FunctionDef):
func_name = node.name
for decorator in node.decorator_list:
decorator_name = None
if isinstance(decorator, ast.Name):
# Bare decorator
decorator_name = decorator.id
if isinstance(decorator, ast.Call) and isinstance(
decorator.func, ast.Name
):
# Decorator is a function
decorator_name = decorator.func.id

if decorator_name == "njit":
njit_funcs.append((module_name, func_name))
module_funcs_nodes.append(node)
out[module_name] = module_funcs_nodes

return out


def get_njit_funcs(pkg_dir):
"""
Identify all njit functions
Parameters
----------
pkg_dir : str
The path to the directory containing some .py files
Returns
-------
njit_funcs : list
A list of all njit functions, where each element is a tuple of the form
(module_name, func_name)
"""
njit_funcs = []
modules_funcs_nodes = get_func_nodes(pkg_dir)
for module_name, func_nodes in modules_funcs_nodes.items():
for node in func_nodes:
func_name = node.name
for decorator in node.decorator_list:
decorator_name = None
if isinstance(decorator, ast.Name):
# Bare decorator
decorator_name = decorator.id
if isinstance(decorator, ast.Call) and isinstance(
decorator.func, ast.Name
):
# Decorator is a function
decorator_name = decorator.func.id

if decorator_name == "njit":
njit_funcs.append((module_name, func_name))

return njit_funcs

Expand Down Expand Up @@ -89,6 +117,60 @@ def check_fastmath(pkg_dir, pkg_name):
return


def check_hardcoded_flag(pkg_dir, pkg_name):
"""
Check if all `fastmath` flags are set to a config variable
Parameters
----------
pkg_dir : str
The path to the directory containing some .py files
pkg_name : str
The name of the package
Returns
-------
None
"""
ignore = [("fastmath", "_add_assoc")]

hardcoded_fastmath = [] # list of njit functions with hardcoded fastmath flags
modules_funcs_nodes = get_func_nodes(pkg_dir)
for module_name, func_nodes in modules_funcs_nodes.items():
for node in func_nodes:
if (module_name, node.name) in ignore:
continue

njit_decorator_func = None
for decorator in node.decorator_list:
if (
isinstance(decorator, ast.Call)
and isinstance(decorator.func, ast.Name)
and decorator.func.id == "njit"
):
njit_decorator_func = decorator
break

if njit_decorator_func is None:
continue

for kwrd in njit_decorator_func.keywords:
if kwrd.arg == "fastmath":
value = kwrd.value.value
if not hasattr(value, "id") or value.id != "config":
hardcoded_fastmath.append(f"{module_name}.{node.name}")

if len(hardcoded_fastmath) > 0:
msg = (
"Found one or more `@njit()` functions with hardcoded `fastmath` flag. "
+ f"The functions are:\n {hardcoded_fastmath}\n"
)
raise ValueError(msg)

return


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--check", dest="pkg_dir")
Expand All @@ -98,3 +180,4 @@ def check_fastmath(pkg_dir, pkg_name):
pkg_dir = pathlib.Path(args.pkg_dir)
pkg_name = pkg_dir.name
check_fastmath(str(pkg_dir), pkg_name)
check_hardcoded_flag(str(pkg_dir), pkg_name)

0 comments on commit dd5cc64

Please sign in to comment.