Skip to content

Commit da9df78

Browse files
committed
new tests
1 parent 771ba90 commit da9df78

File tree

2 files changed

+553
-60
lines changed

2 files changed

+553
-60
lines changed

codeflash/code_utils/code_replacer.py

Lines changed: 78 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
import isort
99
import libcst as cst
10-
import libcst.matchers as m
1110
from libcst.metadata import PositionProvider
1211

1312
from codeflash.cli_cmds.console import logger
@@ -41,29 +40,50 @@ def normalize_code(code: str) -> str:
4140
class AddRequestArgument(cst.CSTTransformer):
4241
METADATA_DEPENDENCIES = (PositionProvider,)
4342

44-
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002
45-
args = updated_node.params.params
46-
arg_names = {arg.name.value for arg in args}
47-
48-
# Skip if 'request' is already present
49-
if "request" in arg_names:
50-
return updated_node
51-
52-
# Create a new 'request' param
53-
request_param = cst.Param(name=cst.Name("request"))
54-
55-
# Add 'request' as the first argument (after 'self' or 'cls' if needed)
56-
if args:
57-
first_arg = args[0].name.value
58-
if first_arg in {"self", "cls"}:
59-
new_params = [args[0], request_param] + list(args[1:]) # noqa: RUF005
60-
else:
61-
new_params = [request_param] + list(args) # noqa: RUF005
62-
else:
63-
new_params = [request_param]
64-
65-
new_param_list = updated_node.params.with_changes(params=new_params)
66-
return updated_node.with_changes(params=new_param_list)
43+
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
44+
# Matcher for '@fixture' or '@pytest.fixture'
45+
for decorator in original_node.decorators:
46+
dec = decorator.decorator
47+
48+
if isinstance(dec, cst.Call):
49+
func_name = ""
50+
if isinstance(dec.func, cst.Attribute) and isinstance(dec.func.value, cst.Name):
51+
if dec.func.attr.value == "fixture" and dec.func.value.value == "pytest":
52+
func_name = "pytest.fixture"
53+
elif isinstance(dec.func, cst.Name) and dec.func.value == "fixture":
54+
func_name = "fixture"
55+
56+
if func_name:
57+
for arg in dec.args:
58+
if (
59+
arg.keyword
60+
and arg.keyword.value == "autouse"
61+
and isinstance(arg.value, cst.Name)
62+
and arg.value.value == "True"
63+
):
64+
args = updated_node.params.params
65+
arg_names = {arg.name.value for arg in args}
66+
67+
# Skip if 'request' is already present
68+
if "request" in arg_names:
69+
return updated_node
70+
71+
# Create a new 'request' param
72+
request_param = cst.Param(name=cst.Name("request"))
73+
74+
# Add 'request' as the first argument (after 'self' or 'cls' if needed)
75+
if args:
76+
first_arg = args[0].name.value
77+
if first_arg in {"self", "cls"}:
78+
new_params = [args[0], request_param] + list(args[1:]) # noqa: RUF005
79+
else:
80+
new_params = [request_param] + list(args) # noqa: RUF005
81+
else:
82+
new_params = [request_param]
83+
84+
new_param_list = updated_node.params.with_changes(params=new_params)
85+
return updated_node.with_changes(params=new_param_list)
86+
return updated_node
6787

6888

6989
class PytestMarkAdder(cst.CSTTransformer):
@@ -135,33 +155,41 @@ def _create_pytest_mark(self) -> cst.Decorator:
135155
class AutouseFixtureModifier(cst.CSTTransformer):
136156
def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef:
137157
# Matcher for '@fixture' or '@pytest.fixture'
138-
fixture_decorator_func = m.Name("fixture") | m.Attribute(value=m.Name("pytest"), attr=m.Name("fixture"))
139-
140158
for decorator in original_node.decorators:
141-
if m.matches(
142-
decorator,
143-
m.Decorator(
144-
decorator=m.Call(
145-
func=fixture_decorator_func, args=[m.Arg(value=m.Name("True"), keyword=m.Name("autouse"))]
146-
)
147-
),
148-
):
149-
# Found a matching fixture with autouse=True
150-
151-
# 1. The original body of the function will become the 'else' block.
152-
# updated_node.body is an IndentedBlock, which is what cst.Else expects.
153-
else_block = cst.Else(body=updated_node.body)
154-
155-
# 2. Create the new 'if' block that will exit the fixture early.
156-
if_test = cst.parse_expression('request.node.get_closest_marker("codeflash_no_autouse")')
157-
yield_statement = cst.parse_statement("yield")
158-
if_body = cst.IndentedBlock(body=[yield_statement])
159-
160-
# 3. Construct the full if/else statement.
161-
new_if_statement = cst.If(test=if_test, body=if_body, orelse=else_block)
162-
163-
# 4. Replace the entire function's body with our new single statement.
164-
return updated_node.with_changes(body=cst.IndentedBlock(body=[new_if_statement]))
159+
dec = decorator.decorator
160+
161+
if isinstance(dec, cst.Call):
162+
func_name = ""
163+
if isinstance(dec.func, cst.Attribute) and isinstance(dec.func.value, cst.Name):
164+
if dec.func.attr.value == "fixture" and dec.func.value.value == "pytest":
165+
func_name = "pytest.fixture"
166+
elif isinstance(dec.func, cst.Name) and dec.func.value == "fixture":
167+
func_name = "fixture"
168+
169+
if func_name:
170+
for arg in dec.args:
171+
if (
172+
arg.keyword
173+
and arg.keyword.value == "autouse"
174+
and isinstance(arg.value, cst.Name)
175+
and arg.value.value == "True"
176+
):
177+
# Found a matching fixture with autouse=True
178+
179+
# 1. The original body of the function will become the 'else' block.
180+
# updated_node.body is an IndentedBlock, which is what cst.Else expects.
181+
else_block = cst.Else(body=updated_node.body)
182+
183+
# 2. Create the new 'if' block that will exit the fixture early.
184+
if_test = cst.parse_expression('request.node.get_closest_marker("codeflash_no_autouse")')
185+
yield_statement = cst.parse_statement("yield")
186+
if_body = cst.IndentedBlock(body=[yield_statement])
187+
188+
# 3. Construct the full if/else statement.
189+
new_if_statement = cst.If(test=if_test, body=if_body, orelse=else_block)
190+
191+
# 4. Replace the entire function's body with our new single statement.
192+
return updated_node.with_changes(body=cst.IndentedBlock(body=[new_if_statement]))
165193
return updated_node
166194

167195

0 commit comments

Comments
 (0)