⚡️ Speed up method InjectPerfOnly.visit_FunctionDef by 24% in PR #363 (part-1-windows-fixes) #368

Open
wants to merge 1 commit into
base: part-1-windows-fixes

@codeflash-ai codeflash-ai bot commented Jun 22, 2025

⚡️ This pull request contains optimizations for PR #363

If you approve this dependent PR, these changes will be merged into the original PR branch part-1-windows-fixes.

This PR will be automatically closed if the original PR is merged.


📄 24% (0.24x) speedup for InjectPerfOnly.visit_FunctionDef in codeflash/code_utils/instrument_existing_tests.py

⏱️ Runtime: 5.76 milliseconds → 4.65 milliseconds (best of 191 runs)
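
The reported 24% (0.24x) is consistent with measuring the gain relative to the optimized runtime rather than the original one. A minimal sketch of that arithmetic, assuming this definition (the helper name `speedup` is made up for illustration and is not part of the PR):

```python
def speedup(original_ms: float, optimized_ms: float) -> float:
    """Relative gain measured against the optimized runtime (assumed definition)."""
    return (original_ms - optimized_ms) / optimized_ms


gain = speedup(5.76, 4.65)
print(f"{gain:.0%} ({gain:.2f}x)")  # prints "24% (0.24x)"
```

Measured against the original runtime instead, the same two numbers would give roughly 19%, so the choice of denominator matters when comparing reports.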

📝 Explanation and details

Here's an optimized rewrite of your original code, focusing on critical hotspots from the profiler data.

Optimization summary:

  • Inline the node_in_call_position logic directly into find_and_update_line_node, because the inner loop is extremely hot and the helper would otherwise be called for every AST node.
  • Pre-split self.call_positions into an efficient lookup format, which pays off when positions are reused often.
  • Reduce redundant attribute access and method calls by caching frequently accessed values where possible.
  • Move the branch for the most frequent case (ast.Name) up and short-circuit to avoid unnecessary checks.
  • Give the common case (ast.Name) a fast path that skips .unparse and unnecessary packing/mapping.
  • Avoid repeated ast.Name(id="codeflash_loop_index", ctx=ast.Load()) construction by storing such nodes as fields (self.ast_codeflash_loop_index etc.) and re-using them, since they are needed many times during a single method walk.
  • Stop walking after the first relevant call in the node; don't continue iterating once a replacement has been performed.

Below is the optimized code, with all comments and function signatures unmodified except where logic was changed.

Key performance wins:

  • Hot inner loop now inlines the call position check, caches common constants, and breaks early.
  • Repeated AST node creation for names and constants is avoided; where possible, nodes are re-used or built up front.
  • Access to self fields and function attributes is limited to the top of find_and_update_line_node.
  • Fast path (ast.Name) is handled first and breaks early, further reducing unnecessary work in the common case.
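
As an illustration of building nodes up front and re-using them, here is a small sketch. The class `WrapperArgBuilder` and its `extra_args` method are hypothetical and do not appear in the PR; only the field name `ast_codeflash_loop_index` and the identifiers `codeflash_cur` and `codeflash_con` are taken from the description and tests.

```python
from __future__ import annotations

import ast


class WrapperArgBuilder:
    """Hypothetical helper that creates shared ast.Name nodes once."""

    def __init__(self, mode: str) -> None:
        self.mode = mode
        # Built a single time and re-used for every wrapped call.
        self.ast_codeflash_loop_index = ast.Name(id="codeflash_loop_index", ctx=ast.Load())
        self.ast_codeflash_cur = ast.Name(id="codeflash_cur", ctx=ast.Load())
        self.ast_codeflash_con = ast.Name(id="codeflash_con", ctx=ast.Load())

    def extra_args(self) -> list[ast.Name]:
        # Cache attribute lookups in locals at the top of the method.
        mode = self.mode
        args = [self.ast_codeflash_loop_index]
        if mode == "BEHAVIOR":
            args += [self.ast_codeflash_cur, self.ast_codeflash_con]
        return args


builder = WrapperArgBuilder("BEHAVIOR")
call = ast.Call(
    func=ast.Name(id="codeflash_wrap", ctx=ast.Load()),
    args=builder.extra_args(),
    keywords=[],
)
print(ast.unparse(call))
```

Sharing one node object between several parents is fine for read-only use such as `ast.unparse`, but a later in-place mutation of a shared node would affect every place it appears, which is the trade-off of this approach.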

These changes should help most when processing many test nodes that contain many function call ASTs; the measured speedup reported above is 24%.

Correctness verification report:

| Test | Status |
|---|---|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 14 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
🌀 Generated Regression Tests and Runtime
from __future__ import annotations

import ast
import os
import sqlite3
import sys
from collections.abc import Iterable
from pathlib import Path
from tempfile import TemporaryDirectory

# imports
import pytest  # used for our unit tests
from codeflash.code_utils.instrument_existing_tests import InjectPerfOnly


# Minimal stubs for external classes/enums used by InjectPerfOnly
class TestingMode:
    BEHAVIOR = "BEHAVIOR"
    PERF = "PERF"

class CodePosition:
    def __init__(self, line_no, col_no, end_col_offset=None):
        self.line_no = line_no
        self.col_no = col_no
        self.end_col_offset = end_col_offset

class FunctionToOptimize:
    def __init__(self, function_name, qualified_name, parents=None, top_level_parent_name=None):
        self.function_name = function_name
        self.qualified_name = qualified_name
        self.parents = parents or []
        self.top_level_parent_name = top_level_parent_name

class Parent:
    def __init__(self, type_):
        self.type = type_

# Helper for code generation
def ast_to_code(tree: ast.AST) -> str:
    """Convert AST back to code string."""
    if hasattr(ast, "unparse"):
        return ast.unparse(tree)
    # Fallback for older Python: use astor if available
    import astor
    return astor.to_source(tree)

# ------------------ UNIT TESTS ------------------

# Helper to parse code and get a FunctionDef node
def get_funcdef_node(code: str) -> ast.FunctionDef:
    tree = ast.parse(code)
    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef):
            return node
    raise ValueError("No FunctionDef found")

# Helper to find ast.Call nodes in a FunctionDef
def find_calls(node: ast.FunctionDef):
    return [n for n in ast.walk(node) if isinstance(n, ast.Call)]

# Helper to simulate call positions for a given code string
def get_call_positions(code: str) -> list[CodePosition]:
    """Find the first call in the function and return its position."""
    tree = ast.parse(code)
    for node in ast.walk(tree):
        if isinstance(node, ast.Call):
            return [CodePosition(node.lineno, node.col_offset, getattr(node, "end_col_offset", None))]
    return []

# ---- BASIC TEST CASES ----

from __future__ import annotations

import ast
import os
import sqlite3
from collections.abc import Iterable
from pathlib import Path
from tempfile import TemporaryDirectory
from types import ModuleType
from typing import Any

# imports
import pytest  # used for our unit tests
from codeflash.code_utils.instrument_existing_tests import InjectPerfOnly


# Minimal stubs for external dependencies
class CodePosition:
    def __init__(self, line_no=None, col_no=None, end_col_offset=None):
        self.line_no = line_no
        self.col_no = col_no
        self.end_col_offset = end_col_offset

class TestingMode:
    BEHAVIOR = "BEHAVIOR"
    PERF = "PERF"

class FunctionToOptimize:
    def __init__(self, function_name, qualified_name=None, parents=None, top_level_parent_name=None):
        self.function_name = function_name
        self.qualified_name = qualified_name or function_name
        self.parents = parents or []
        self.top_level_parent_name = top_level_parent_name

# Helper functions for testing
def parse_and_visit(func_src: str, call_positions: list[CodePosition], function: FunctionToOptimize, mode=TestingMode.BEHAVIOR, test_framework="pytest", test_class_name=None):
    """
    Helper to parse a function source, run visit_FunctionDef, and return the modified AST node.
    """
    module = ast.parse(func_src)
    func_node = next(node for node in module.body if isinstance(node, ast.FunctionDef))
    visitor = InjectPerfOnly(
        function=function,
        module_path="dummy_module.py",
        test_framework=test_framework,
        call_positions=call_positions,
        mode=mode,
    )
    codeflash_output = visitor.visit_FunctionDef(func_node, test_class_name)
    new_node = codeflash_output
    return new_node

def get_first_call(node: ast.FunctionDef):
    """Find the first ast.Call node in the function body (recursively)."""
    for n in ast.walk(node):
        if isinstance(n, ast.Call):
            return n
    return None

# ----------------- BASIC TEST CASES -----------------


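def test_function_without_target_calls():  # test name reconstructed; the original header was lost in extraction
    func_src = """
def foo():
    pass
"""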
    function = FunctionToOptimize(function_name="bar")
    call_positions = []
    new_node = parse_and_visit(func_src, call_positions, function)

def test_adds_decorator_for_unittest():
    # Should add timeout_decorator.timeout(15) for unittest
    func_src = """
def test_example():
    pass
"""
    function = FunctionToOptimize(function_name="bar")
    call_positions = []
    new_node = parse_and_visit(func_src, call_positions, function, test_framework="unittest")

def test_inserts_codeflash_assignments_and_close_behavior_mode():
    # Should insert codeflash_loop_index and other assignments and close in BEHAVIOR mode
    func_src = """
def test_simple():
    foo()
"""
    function = FunctionToOptimize(function_name="foo", qualified_name="foo")
    # Simulate a call at line 3, col 4 (func_src starts with a blank line, so def is on line 2)
    call_positions = [CodePosition(line_no=3, col_no=4, end_col_offset=9)]
    new_node = parse_and_visit(func_src, call_positions, function, mode=TestingMode.BEHAVIOR)
    # Should contain codeflash_con and codeflash_cur assignments
    ids = [getattr(x.targets[0], "id", None) for x in new_node.body if isinstance(x, ast.Assign)]
    expr = new_node.body[-1].value

def test_wraps_function_call_in_body():
    # Should wrap the foo() call in codeflash_wrap
    func_src = """
def test_wrap():
    foo(1, 2)
"""
    function = FunctionToOptimize(function_name="foo", qualified_name="foo")
    call_positions = [CodePosition(line_no=3, col_no=4, end_col_offset=13)]
    new_node = parse_and_visit(func_src, call_positions, function)
    # Find the call, should now be codeflash_wrap(...)
    call = get_first_call(new_node)
    # Should include codeflash_cur and codeflash_con in args
    arg_ids = [a.id for a in call.args if isinstance(a, ast.Name)]

# ----------------- EDGE TEST CASES -----------------

def test_function_with_no_body():
    # Should handle an empty function gracefully
    func_src = """
def test_empty():
    '''docstring'''
"""
    function = FunctionToOptimize(function_name="foo")
    call_positions = []
    new_node = parse_and_visit(func_src, call_positions, function)

def test_nested_calls_in_for_loop():
    # Should wrap calls inside for loop bodies
    func_src = """
def test_loop():
    for i in range(3):
        foo(i)
"""
    function = FunctionToOptimize(function_name="foo", qualified_name="foo")
    call_positions = [CodePosition(line_no=4, col_no=8, end_col_offset=14)]
    new_node = parse_and_visit(func_src, call_positions, function)
    # Should find codeflash_wrap inside the for loop body
    for_node = next(x for x in new_node.body if isinstance(x, ast.For))
    call = get_first_call(for_node)

def test_multiple_calls_and_positions():
    # Should wrap only calls at specified positions
    func_src = """
def test_multi():
    foo(1)
    bar(2)
    foo(3)
"""
    function = FunctionToOptimize(function_name="foo", qualified_name="foo")
    # Only wrap the first and last foo, not bar (func_src starts with a blank line, so the calls are on lines 3 to 5)
    call_positions = [
        CodePosition(line_no=3, col_no=4, end_col_offset=10),
        CodePosition(line_no=5, col_no=4, end_col_offset=10)
    ]
    new_node = parse_and_visit(func_src, call_positions, function)
    calls = [n for n in ast.walk(new_node) if isinstance(n, ast.Call)]
    # Should have two codeflash_wrap and one bar (untouched)
    wrap_count = sum(1 for c in calls if isinstance(c.func, ast.Name) and c.func.id == "codeflash_wrap")
    bar_count = sum(1 for c in calls if isinstance(c.func, ast.Name) and c.func.id == "bar")

def test_handles_attribute_calls():
    # Should wrap attribute calls like obj.foo()
    func_src = """
def test_attr():
    obj.foo(42)
"""
    function = FunctionToOptimize(function_name="foo", qualified_name="obj.foo")
    call_positions = [CodePosition(line_no=3, col_no=4, end_col_offset=15)]
    new_node = parse_and_visit(func_src, call_positions, function)
    call = get_first_call(new_node)

def test_function_with_if_and_nested_call():
    # Should wrap calls inside if statements
    func_src = """
def test_if():
    if True:
        foo(99)
"""
    function = FunctionToOptimize(function_name="foo", qualified_name="foo")
    call_positions = [CodePosition(line_no=4, col_no=8, end_col_offset=15)]
    new_node = parse_and_visit(func_src, call_positions, function)
    if_node = next(x for x in new_node.body if isinstance(x, ast.If))
    call = get_first_call(if_node)

def test_perf_mode_does_not_add_behavior_assignments():
    # Should not add codeflash_con/cur/iteration/close in PERF mode
    func_src = """
def test_perf():
    foo()
"""
    function = FunctionToOptimize(function_name="foo", qualified_name="foo")
    call_positions = [CodePosition(line_no=3, col_no=4, end_col_offset=9)]
    new_node = parse_and_visit(func_src, call_positions, function, mode=TestingMode.PERF)
    ids = [getattr(x.targets[0], "id", None) for x in new_node.body if isinstance(x, ast.Assign)]

def test_handles_function_with_decorators():
    # Should preserve existing decorators and add new ones if needed
    func_src = """
@some_decorator
def test_decorated():
    foo()
"""
    function = FunctionToOptimize(function_name="foo", qualified_name="foo")
    call_positions = [CodePosition(line_no=4, col_no=4, end_col_offset=9)]
    new_node = parse_and_visit(func_src, call_positions, function, test_framework="unittest")
    # Should still have some_decorator and timeout_decorator.timeout
    decorator_ids = [d.func.id if isinstance(d, ast.Call) else getattr(d, 'id', None) for d in new_node.decorator_list]

# ----------------- LARGE SCALE TEST CASES -----------------

def test_large_number_of_calls():
    # Should handle a function with many calls efficiently
    N = 100
    func_src = "def test_many():\n" + "\n".join([f"    foo({i})" for i in range(N)])
    function = FunctionToOptimize(function_name="foo", qualified_name="foo")
    call_positions = [CodePosition(line_no=i+2, col_no=4, end_col_offset=9) for i in range(N)]
    new_node = parse_and_visit(func_src, call_positions, function)
    # Should wrap all foo() calls
    calls = [n for n in ast.walk(new_node) if isinstance(n, ast.Call)]
    wrap_count = sum(1 for c in calls if isinstance(c.func, ast.Name) and c.func.id == "codeflash_wrap")

def test_large_function_body_with_mixed_calls():
    # Should only wrap the correct calls in a large, mixed function
    N = 100
    func_src = "def test_mixed():\n"
    for i in range(N):
        if i % 2 == 0:
            func_src += f"    foo({i})\n"
        else:
            func_src += f"    bar({i})\n"
    function = FunctionToOptimize(function_name="foo", qualified_name="foo")
    call_positions = [CodePosition(line_no=i+2, col_no=4, end_col_offset=9) for i in range(N) if i % 2 == 0]
    new_node = parse_and_visit(func_src, call_positions, function)
    calls = [n for n in ast.walk(new_node) if isinstance(n, ast.Call)]
    wrap_count = sum(1 for c in calls if isinstance(c.func, ast.Name) and c.func.id == "codeflash_wrap")
    foo_count = sum(1 for c in calls if isinstance(c.func, ast.Name) and c.func.id == "foo")
    bar_count = sum(1 for c in calls if isinstance(c.func, ast.Name) and c.func.id == "bar")

def test_large_nested_structure():
    # Should handle nested for/if/while with many calls
    func_src = """
def test_nested():
    for i in range(10):
        if i % 2 == 0:
            foo(i)
        else:
            bar(i)
"""
    function = FunctionToOptimize(function_name="foo", qualified_name="foo")
    # Only wrap foo(i) at line 5
    call_positions = [CodePosition(line_no=5, col_no=12, end_col_offset=17)]
    new_node = parse_and_visit(func_src, call_positions, function)
    # Find the for node
    for_node = next(x for x in new_node.body if isinstance(x, ast.For))
    # Find the if node inside the for
    if_node = next(x for x in ast.walk(for_node) if isinstance(x, ast.If))
    # foo(i) should be wrapped, bar(i) should not
    foo_call = get_first_call(if_node.body[0])
    bar_call = get_first_call(if_node.orelse[0])
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes, run git checkout codeflash/optimize-pr363-2025-06-22T23.07.59, commit your edits, and push.

Codeflash

… (`part-1-windows-fixes`)

@codeflash-ai codeflash-ai bot added the ⚡️ codeflash Optimization PR opened by Codeflash AI label Jun 22, 2025