Skip to content

Commit 75a430f

Browse files
committed
βœ… test: fix remove_unused_named_expr ut
1 parent 982a9ba commit 75a430f

File tree

4 files changed

+10
-9
lines changed

4 files changed

+10
-9
lines changed

β€Žsrc/expr_simplifier/transforms/constant_folding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def fold_to_constant(node: ast.AST) -> ast.Constant:
1515

1616
class ConstantFolding(ast.NodeTransformer):
1717
def visit(self, node: ast.AST) -> ast.AST:
18-
transformed_node = super().visit(node)
18+
transformed_node = self.generic_visit(node)
1919
if isinstance(node, ast.BinOp) and isinstance(node.left, ast.Constant) and isinstance(node.right, ast.Constant):
2020
return fold_to_constant(node)
2121
if isinstance(node, ast.UnaryOp) and isinstance(node.operand, ast.Constant):

β€Žsrc/expr_simplifier/transforms/cse.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def __init__(self):
1818
super().__init__()
1919

2020
def visit(self, node: ast.AST) -> None:
21-
super().visit(node)
21+
self.generic_visit(node)
2222
expr_string = ast.unparse(node)
2323
if isinstance(node, ast.expr):
2424
if isinstance(node, ast.Name):
@@ -38,10 +38,10 @@ def __init__(self, subexpressions: dict[str, tuple[str, int]]):
3838
self.declared_symbols = set[str]()
3939
super().__init__()
4040

41-
def visit(self, node: ast.AST) -> ast.expr:
41+
def visit(self, node: ast.AST) -> ast.AST:
4242
expr_string = ast.unparse(node)
43-
transformed_node = super().visit(node)
44-
if isinstance(node, ast.expr) and expr_string in self.subexpressions:
43+
transformed_node = self.generic_visit(node)
44+
if isinstance(transformed_node, ast.expr) and expr_string in self.subexpressions:
4545
symbol, count = self.subexpressions[expr_string]
4646
if count > 1:
4747
if symbol not in self.declared_symbols:

β€Žsrc/expr_simplifier/transforms/remove_unused_named_expr.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,12 @@ def __init__(self, used_symbols: set[str]) -> None:
2020
self.used_symbols = used_symbols
2121

2222
def visit_NamedExpr(self, node: ast.NamedExpr) -> ast.expr:
23-
value = self.visit(node.value)
23+
transformed_node = self.generic_visit(node)
24+
assert isinstance(transformed_node, ast.NamedExpr)
2425
name = node.target.id
2526
if name not in self.used_symbols:
26-
return value
27-
return node
27+
return transformed_node.value
28+
return transformed_node
2829

2930

3031
def apply_remove_unused_named_expr(expr: ast.AST) -> ast.AST:

β€Žtests/test_transforms/test_remove_unused_named_expr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
@pytest.mark.parametrize(
1111
["expr", "expected"],
1212
[
13-
("(___x := a.b) + ___x", "a.b + ___x"),
13+
("(___x := a.b) + ___x", "(___x := a.b) + ___x"),
1414
("(___y := (___x := a.b)) + ___y", "(___y := a.b) + ___y"),
1515
("(___y := (___x := a.b))", "a.b"),
1616
],

0 commit comments

Comments
Β (0)