From dd5cc644644be630ff3c17b881f9d81139e3a39c Mon Sep 17 00:00:00 2001 From: NimaSarajpoor Date: Wed, 22 Jan 2025 00:02:25 -0500 Subject: [PATCH] Add code to check for hardcoded fastmath flags --- fastmath.py | 125 +++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 104 insertions(+), 21 deletions(-) diff --git a/fastmath.py b/fastmath.py index b6fea39af..63c7d239b 100755 --- a/fastmath.py +++ b/fastmath.py @@ -6,20 +6,21 @@ import pathlib -def get_njit_funcs(pkg_dir): +def get_func_nodes(pkg_dir): """ - Identify all njit functions + Retrun a dictionary where the keys are the module names and the values are + the function AST nodes Parameters ---------- pkg_dir : str - The path to the directory containing some .py files + The path to the directory containing some .py files Returns ------- - njit_funcs : list - A list of all njit functions, where each element is a tuple of the form - (module_name, func_name) + out : dict + A dictionary where the keys are the module names and the values are a list of + AST nodes for each njit function in the module """ ignore_py_files = ["__init__", "__pycache__"] pkg_dir = pathlib.Path(pkg_dir) @@ -29,29 +30,56 @@ def get_njit_funcs(pkg_dir): if fname.stem not in ignore_py_files and not fname.stem.startswith("."): module_names.append(fname.stem) - njit_funcs = [] + out = {} for module_name in module_names: filepath = pkg_dir / f"{module_name}.py" file_contents = "" with open(filepath, encoding="utf8") as f: file_contents = f.read() module = ast.parse(file_contents) + + module_funcs_nodes = [] for node in module.body: if isinstance(node, ast.FunctionDef): - func_name = node.name - for decorator in node.decorator_list: - decorator_name = None - if isinstance(decorator, ast.Name): - # Bare decorator - decorator_name = decorator.id - if isinstance(decorator, ast.Call) and isinstance( - decorator.func, ast.Name - ): - # Decorator is a function - decorator_name = decorator.func.id - - if decorator_name == "njit": - njit_funcs.append((module_name, func_name)) + module_funcs_nodes.append(node) + out[module_name] = module_funcs_nodes + + return out + + +def get_njit_funcs(pkg_dir): + """ + Identify all njit functions + + Parameters + ---------- + pkg_dir : str + The path to the directory containing some .py files + + Returns + ------- + njit_funcs : list + A list of all njit functions, where each element is a tuple of the form + (module_name, func_name) + """ + njit_funcs = [] + modules_funcs_nodes = get_func_nodes(pkg_dir) + for module_name, func_nodes in modules_funcs_nodes.items(): + for node in func_nodes: + func_name = node.name + for decorator in node.decorator_list: + decorator_name = None + if isinstance(decorator, ast.Name): + # Bare decorator + decorator_name = decorator.id + if isinstance(decorator, ast.Call) and isinstance( + decorator.func, ast.Name + ): + # Decorator is a function + decorator_name = decorator.func.id + + if decorator_name == "njit": + njit_funcs.append((module_name, func_name)) return njit_funcs @@ -89,6 +117,60 @@ def check_fastmath(pkg_dir, pkg_name): return +def check_hardcoded_flag(pkg_dir, pkg_name): + """ + Check if all `fastmath` flags are set to a config variable + + Parameters + ---------- + pkg_dir : str + The path to the directory containing some .py files + + pkg_name : str + The name of the package + + Returns + ------- + None + """ + ignore = [("fastmath", "_add_assoc")] + + hardcoded_fastmath = [] # list of njit functions with hardcoded fastmath flags + modules_funcs_nodes = get_func_nodes(pkg_dir) + for module_name, func_nodes in modules_funcs_nodes.items(): + for node in func_nodes: + if (module_name, node.name) in ignore: + continue + + njit_decorator_func = None + for decorator in node.decorator_list: + if ( + isinstance(decorator, ast.Call) + and isinstance(decorator.func, ast.Name) + and decorator.func.id == "njit" + ): + njit_decorator_func = decorator + break + + if njit_decorator_func is None: + continue + + for kwrd in njit_decorator_func.keywords: + if kwrd.arg == "fastmath": + value = kwrd.value.value + if not hasattr(value, "id") or value.id != "config": + hardcoded_fastmath.append(f"{module_name}.{node.name}") + + if len(hardcoded_fastmath) > 0: + msg = ( + "Found one or more `@njit()` functions with hardcoded `fastmath` flag. " + + f"The functions are:\n {hardcoded_fastmath}\n" + ) + raise ValueError(msg) + + return + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--check", dest="pkg_dir") @@ -98,3 +180,4 @@ def check_fastmath(pkg_dir, pkg_name): pkg_dir = pathlib.Path(args.pkg_dir) pkg_name = pkg_dir.name check_fastmath(str(pkg_dir), pkg_name) + check_hardcoded_flag(str(pkg_dir), pkg_name)