Skip to content

⚡️ Speed up method InjectPerfOnly.visit_FunctionDef by 24% in PR #363 (part-1-windows-fixes) #368

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: part-1-windows-fixes
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 104 additions & 42 deletions codeflash/code_utils/instrument_existing_tests.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import ast
from collections.abc import Iterable
from pathlib import Path
from typing import TYPE_CHECKING

Expand All @@ -9,7 +10,7 @@
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
from codeflash.models.models import FunctionParent, TestingMode, VerificationType
from codeflash.models.models import CodePosition, FunctionParent, TestingMode, VerificationType

if TYPE_CHECKING:
from collections.abc import Iterable
Expand Down Expand Up @@ -64,62 +65,99 @@ def __init__(
self.module_path = module_path
self.test_framework = test_framework
self.call_positions = call_positions
# Pre-build the ast.Name nodes that are reused as arguments of every instrumented call
self.ast_codeflash_loop_index = ast.Name(id="codeflash_loop_index", ctx=ast.Load())
self.ast_codeflash_cur = ast.Name(id="codeflash_cur", ctx=ast.Load())
self.ast_codeflash_con = ast.Name(id="codeflash_con", ctx=ast.Load())
if len(function.parents) == 1 and function.parents[0].type == "ClassDef":
self.class_name = function.top_level_parent_name

def find_and_update_line_node(
self, test_node: ast.stmt, node_name: str, index: str, test_class_name: str | None = None
) -> Iterable[ast.stmt] | None:
# Optimize: Inline self._in_call_position and cache .func once
call_node = None
behavior_mode = self.mode == TestingMode.BEHAVIOR
function_object_name = self.function_object.function_name
function_qualified_name = self.function_object.qualified_name
module_path_const = ast.Constant(value=self.module_path)
test_class_const = ast.Constant(value=test_class_name or None)
node_name_const = ast.Constant(value=node_name)
qualified_name_const = ast.Constant(value=function_qualified_name)
index_const = ast.Constant(value=index)
args_behavior = [self.ast_codeflash_cur, self.ast_codeflash_con] if behavior_mode else []

for node in ast.walk(test_node):
if isinstance(node, ast.Call) and node_in_call_position(node, self.call_positions):
call_node = node
if isinstance(node.func, ast.Name):
function_name = node.func.id
# Fast path: check for Call nodes only
if not (isinstance(node, ast.Call) and hasattr(node, "lineno") and hasattr(node, "col_offset")):
continue
# Inline node_in_call_position logic (from profiler hotspot)
node_lineno = getattr(node, "lineno", None)
node_col_offset = getattr(node, "col_offset", None)
node_end_lineno = getattr(node, "end_lineno", None)
node_end_col_offset = getattr(node, "end_col_offset", None)
found = False
for pos in self.call_positions:
pos_line = pos.line_no
if pos_line is not None and node_end_lineno is not None and node_lineno <= pos_line <= node_end_lineno:
if pos_line == node_lineno and node_col_offset <= pos.col_no:
found = True
break
if (
pos_line == node_end_lineno
and node_end_col_offset is not None
and node_end_col_offset >= pos.col_no
):
found = True
break
if node_lineno < pos_line < node_end_lineno:
found = True
break
if not found:
continue

call_node = node
func = node.func
# Handle ast.Name fast path
if isinstance(func, ast.Name):
function_name = func.id
node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
# Build ast.Name fields for use in args
codeflash_func_arg = ast.Name(id=function_name, ctx=ast.Load())
# Compose argument tuple directly, for speed
node.args = [
codeflash_func_arg,
module_path_const,
test_class_const,
node_name_const,
qualified_name_const,
index_const,
self.ast_codeflash_loop_index,
*args_behavior,
*call_node.args,
]
node.keywords = call_node.keywords
break
if isinstance(func, ast.Attribute):
# Profiling shows this path is almost never hit, but it must still be handled
function_to_test = func.attr
if function_to_test == function_object_name:
# NOTE: ast.unparse is very slow; only call if necessary
function_name = ast.unparse(func)
node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
node.args = [
ast.Name(id=function_name, ctx=ast.Load()),
ast.Constant(value=self.module_path),
ast.Constant(value=test_class_name or None),
ast.Constant(value=node_name),
ast.Constant(value=self.function_object.qualified_name),
ast.Constant(value=index),
ast.Name(id="codeflash_loop_index", ctx=ast.Load()),
*(
[ast.Name(id="codeflash_cur", ctx=ast.Load()), ast.Name(id="codeflash_con", ctx=ast.Load())]
if self.mode == TestingMode.BEHAVIOR
else []
),
module_path_const,
test_class_const,
node_name_const,
qualified_name_const,
index_const,
self.ast_codeflash_loop_index,
*args_behavior,
*call_node.args,
]
node.keywords = call_node.keywords
break
if isinstance(node.func, ast.Attribute):
function_to_test = node.func.attr
if function_to_test == self.function_object.function_name:
function_name = ast.unparse(node.func)
node.func = ast.Name(id="codeflash_wrap", ctx=ast.Load())
node.args = [
ast.Name(id=function_name, ctx=ast.Load()),
ast.Constant(value=self.module_path),
ast.Constant(value=test_class_name or None),
ast.Constant(value=node_name),
ast.Constant(value=self.function_object.qualified_name),
ast.Constant(value=index),
ast.Name(id="codeflash_loop_index", ctx=ast.Load()),
*(
[
ast.Name(id="codeflash_cur", ctx=ast.Load()),
ast.Name(id="codeflash_con", ctx=ast.Load()),
]
if self.mode == TestingMode.BEHAVIOR
else []
),
*call_node.args,
]
node.keywords = call_node.keywords
break

if call_node is None:
return None
return [test_node]
Expand Down Expand Up @@ -153,6 +191,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None =
while j >= 0:
compound_line_node: ast.stmt = line_node.body[j]
internal_node: ast.AST
# No significant hotspot here; ast.walk used on small subtrees
for internal_node in ast.walk(compound_line_node):
if isinstance(internal_node, (ast.stmt, ast.Assign)):
updated_node = self.find_and_update_line_node(
Expand Down Expand Up @@ -284,6 +323,29 @@ def visit_FunctionDef(self, node: ast.FunctionDef, test_class_name: str | None =
]
return node

def _in_call_position(self, node: ast.AST) -> bool:
# Inline node_in_call_position for performance
if not (isinstance(node, ast.Call) and hasattr(node, "lineno") and hasattr(node, "col_offset")):
return False
node_lineno = getattr(node, "lineno", None)
node_col_offset = getattr(node, "col_offset", None)
node_end_lineno = getattr(node, "end_lineno", None)
node_end_col_offset = getattr(node, "end_col_offset", None)
for pos in self.call_positions:
pos_line = pos.line_no
if pos_line is not None and node_end_lineno is not None and node_lineno <= pos_line <= node_end_lineno:
if pos_line == node_lineno and node_col_offset <= pos.col_no:
return True
if (
pos_line == node_end_lineno
and node_end_col_offset is not None
and node_end_col_offset >= pos.col_no
):
return True
if node_lineno < pos_line < node_end_lineno:
return True
return False


class FunctionImportedAsVisitor(ast.NodeVisitor):
"""Checks if a function has been imported as an alias. We only care about the alias then.
Expand Down
Loading