|
7 | 7 |
|
8 | 8 | import isort |
9 | 9 | import libcst as cst |
10 | | -import libcst.matchers as m |
11 | 10 | from libcst.metadata import PositionProvider |
12 | 11 |
|
13 | 12 | from codeflash.cli_cmds.console import logger |
@@ -41,29 +40,50 @@ def normalize_code(code: str) -> str: |
41 | 40 | class AddRequestArgument(cst.CSTTransformer): |
42 | 41 | METADATA_DEPENDENCIES = (PositionProvider,) |
43 | 42 |
|
44 | | - def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: # noqa: ARG002 |
45 | | - args = updated_node.params.params |
46 | | - arg_names = {arg.name.value for arg in args} |
47 | | - |
48 | | - # Skip if 'request' is already present |
49 | | - if "request" in arg_names: |
50 | | - return updated_node |
51 | | - |
52 | | - # Create a new 'request' param |
53 | | - request_param = cst.Param(name=cst.Name("request")) |
54 | | - |
55 | | - # Add 'request' as the first argument (after 'self' or 'cls' if needed) |
56 | | - if args: |
57 | | - first_arg = args[0].name.value |
58 | | - if first_arg in {"self", "cls"}: |
59 | | - new_params = [args[0], request_param] + list(args[1:]) # noqa: RUF005 |
60 | | - else: |
61 | | - new_params = [request_param] + list(args) # noqa: RUF005 |
62 | | - else: |
63 | | - new_params = [request_param] |
64 | | - |
65 | | - new_param_list = updated_node.params.with_changes(params=new_params) |
66 | | - return updated_node.with_changes(params=new_param_list) |
| 43 | + def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: |
| 44 | + # Matcher for '@fixture' or '@pytest.fixture' |
| 45 | + for decorator in original_node.decorators: |
| 46 | + dec = decorator.decorator |
| 47 | + |
| 48 | + if isinstance(dec, cst.Call): |
| 49 | + func_name = "" |
| 50 | + if isinstance(dec.func, cst.Attribute) and isinstance(dec.func.value, cst.Name): |
| 51 | + if dec.func.attr.value == "fixture" and dec.func.value.value == "pytest": |
| 52 | + func_name = "pytest.fixture" |
| 53 | + elif isinstance(dec.func, cst.Name) and dec.func.value == "fixture": |
| 54 | + func_name = "fixture" |
| 55 | + |
| 56 | + if func_name: |
| 57 | + for arg in dec.args: |
| 58 | + if ( |
| 59 | + arg.keyword |
| 60 | + and arg.keyword.value == "autouse" |
| 61 | + and isinstance(arg.value, cst.Name) |
| 62 | + and arg.value.value == "True" |
| 63 | + ): |
| 64 | + args = updated_node.params.params |
| 65 | + arg_names = {arg.name.value for arg in args} |
| 66 | + |
| 67 | + # Skip if 'request' is already present |
| 68 | + if "request" in arg_names: |
| 69 | + return updated_node |
| 70 | + |
| 71 | + # Create a new 'request' param |
| 72 | + request_param = cst.Param(name=cst.Name("request")) |
| 73 | + |
| 74 | + # Add 'request' as the first argument (after 'self' or 'cls' if needed) |
| 75 | + if args: |
| 76 | + first_arg = args[0].name.value |
| 77 | + if first_arg in {"self", "cls"}: |
| 78 | + new_params = [args[0], request_param] + list(args[1:]) # noqa: RUF005 |
| 79 | + else: |
| 80 | + new_params = [request_param] + list(args) # noqa: RUF005 |
| 81 | + else: |
| 82 | + new_params = [request_param] |
| 83 | + |
| 84 | + new_param_list = updated_node.params.with_changes(params=new_params) |
| 85 | + return updated_node.with_changes(params=new_param_list) |
| 86 | + return updated_node |
67 | 87 |
|
68 | 88 |
|
69 | 89 | class PytestMarkAdder(cst.CSTTransformer): |
@@ -135,33 +155,41 @@ def _create_pytest_mark(self) -> cst.Decorator: |
135 | 155 | class AutouseFixtureModifier(cst.CSTTransformer): |
136 | 156 | def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: |
137 | 157 | # Matcher for '@fixture' or '@pytest.fixture' |
138 | | - fixture_decorator_func = m.Name("fixture") | m.Attribute(value=m.Name("pytest"), attr=m.Name("fixture")) |
139 | | - |
140 | 158 | for decorator in original_node.decorators: |
141 | | - if m.matches( |
142 | | - decorator, |
143 | | - m.Decorator( |
144 | | - decorator=m.Call( |
145 | | - func=fixture_decorator_func, args=[m.Arg(value=m.Name("True"), keyword=m.Name("autouse"))] |
146 | | - ) |
147 | | - ), |
148 | | - ): |
149 | | - # Found a matching fixture with autouse=True |
150 | | - |
151 | | - # 1. The original body of the function will become the 'else' block. |
152 | | - # updated_node.body is an IndentedBlock, which is what cst.Else expects. |
153 | | - else_block = cst.Else(body=updated_node.body) |
154 | | - |
155 | | - # 2. Create the new 'if' block that will exit the fixture early. |
156 | | - if_test = cst.parse_expression('request.node.get_closest_marker("codeflash_no_autouse")') |
157 | | - yield_statement = cst.parse_statement("yield") |
158 | | - if_body = cst.IndentedBlock(body=[yield_statement]) |
159 | | - |
160 | | - # 3. Construct the full if/else statement. |
161 | | - new_if_statement = cst.If(test=if_test, body=if_body, orelse=else_block) |
162 | | - |
163 | | - # 4. Replace the entire function's body with our new single statement. |
164 | | - return updated_node.with_changes(body=cst.IndentedBlock(body=[new_if_statement])) |
| 159 | + dec = decorator.decorator |
| 160 | + |
| 161 | + if isinstance(dec, cst.Call): |
| 162 | + func_name = "" |
| 163 | + if isinstance(dec.func, cst.Attribute) and isinstance(dec.func.value, cst.Name): |
| 164 | + if dec.func.attr.value == "fixture" and dec.func.value.value == "pytest": |
| 165 | + func_name = "pytest.fixture" |
| 166 | + elif isinstance(dec.func, cst.Name) and dec.func.value == "fixture": |
| 167 | + func_name = "fixture" |
| 168 | + |
| 169 | + if func_name: |
| 170 | + for arg in dec.args: |
| 171 | + if ( |
| 172 | + arg.keyword |
| 173 | + and arg.keyword.value == "autouse" |
| 174 | + and isinstance(arg.value, cst.Name) |
| 175 | + and arg.value.value == "True" |
| 176 | + ): |
| 177 | + # Found a matching fixture with autouse=True |
| 178 | + |
| 179 | + # 1. The original body of the function will become the 'else' block. |
| 180 | + # updated_node.body is an IndentedBlock, which is what cst.Else expects. |
| 181 | + else_block = cst.Else(body=updated_node.body) |
| 182 | + |
| 183 | + # 2. Create the new 'if' block that will exit the fixture early. |
| 184 | + if_test = cst.parse_expression('request.node.get_closest_marker("codeflash_no_autouse")') |
| 185 | + yield_statement = cst.parse_statement("yield") |
| 186 | + if_body = cst.IndentedBlock(body=[yield_statement]) |
| 187 | + |
| 188 | + # 3. Construct the full if/else statement. |
| 189 | + new_if_statement = cst.If(test=if_test, body=if_body, orelse=else_block) |
| 190 | + |
| 191 | + # 4. Replace the entire function's body with our new single statement. |
| 192 | + return updated_node.with_changes(body=cst.IndentedBlock(body=[new_if_statement])) |
165 | 193 | return updated_node |
166 | 194 |
|
167 | 195 |
|
|
0 commit comments