@@ -703,24 +703,17 @@ def run(self, mod: ast.Module) -> None:
703
703
return
704
704
pos = 0
705
705
for item in mod .body :
706
- if (
707
- expect_docstring
708
- and isinstance (item , ast .Expr )
709
- and isinstance (item .value , ast .Constant )
710
- and isinstance (item .value .value , str )
711
- ):
712
- doc = item .value .value
713
- if self .is_rewrite_disabled (doc ):
714
- return
715
- expect_docstring = False
716
- elif (
717
- isinstance (item , ast .ImportFrom )
718
- and item .level == 0
719
- and item .module == "__future__"
720
- ):
721
- pass
722
- else :
723
- break
706
+ match item :
707
+ case ast .Expr (value = ast .Constant (value = str () as doc )) if (
708
+ expect_docstring
709
+ ):
710
+ if self .is_rewrite_disabled (doc ):
711
+ return
712
+ expect_docstring = False
713
+ case ast .ImportFrom (level = 0 , module = "__future__" ):
714
+ pass
715
+ case _:
716
+ break
724
717
pos += 1
725
718
# Special case: for a decorated function, set the lineno to that of the
726
719
# first decorator, not the `def`. Issue #4984.
@@ -1017,20 +1010,17 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]:
1017
1010
# cond is set in a prior loop iteration below
1018
1011
self .expl_stmts .append (ast .If (cond , fail_inner , [])) # noqa: F821
1019
1012
self .expl_stmts = fail_inner
1020
- # Check if the left operand is a ast.NamedExpr and the value has already been visited
1021
- if (
1022
- isinstance (v , ast .Compare )
1023
- and isinstance (v .left , ast .NamedExpr )
1024
- and v .left .target .id
1025
- in [
1026
- ast_expr .id
1027
- for ast_expr in boolop .values [:i ]
1028
- if hasattr (ast_expr , "id" )
1029
- ]
1030
- ):
1031
- pytest_temp = self .variable ()
1032
- self .variables_overwrite [self .scope ][v .left .target .id ] = v .left # type:ignore[assignment]
1033
- v .left .target .id = pytest_temp
1013
+ match v :
1014
+ # Check if the left operand is an ast.NamedExpr and the value has already been visited
1015
+ case ast .Compare (
1016
+ left = ast .NamedExpr (target = ast .Name (id = target_id ))
1017
+ ) if target_id in [
1018
+ e .id for e in boolop .values [:i ] if hasattr (e , "id" )
1019
+ ]:
1020
+ pytest_temp = self .variable ()
1021
+ self .variables_overwrite [self .scope ][target_id ] = v .left # type:ignore[assignment]
1022
+ # mypy's false positive, we're checking that the 'target' attribute exists.
1023
+ v .left .target .id = pytest_temp # type:ignore[attr-defined]
1034
1024
self .push_format_context ()
1035
1025
res , expl = self .visit (v )
1036
1026
body .append (ast .Assign ([ast .Name (res_var , ast .Store ())], res ))
@@ -1080,10 +1070,11 @@ def visit_Call(self, call: ast.Call) -> tuple[ast.Name, str]:
1080
1070
arg_expls .append (expl )
1081
1071
new_args .append (res )
1082
1072
for keyword in call .keywords :
1083
- if isinstance (
1084
- keyword .value , ast .Name
1085
- ) and keyword .value .id in self .variables_overwrite .get (self .scope , {}):
1086
- keyword .value = self .variables_overwrite [self .scope ][keyword .value .id ] # type:ignore[assignment]
1073
+ match keyword .value :
1074
+ case ast .Name (id = id ) if id in self .variables_overwrite .get (
1075
+ self .scope , {}
1076
+ ):
1077
+ keyword .value = self .variables_overwrite [self .scope ][id ] # type:ignore[assignment]
1087
1078
res , expl = self .visit (keyword .value )
1088
1079
new_kwargs .append (ast .keyword (keyword .arg , res ))
1089
1080
if keyword .arg :
@@ -1119,12 +1110,13 @@ def visit_Attribute(self, attr: ast.Attribute) -> tuple[ast.Name, str]:
1119
1110
def visit_Compare (self , comp : ast .Compare ) -> tuple [ast .expr , str ]:
1120
1111
self .push_format_context ()
1121
1112
# We first check if we have overwritten a variable in the previous assert
1122
- if isinstance (
1123
- comp .left , ast .Name
1124
- ) and comp .left .id in self .variables_overwrite .get (self .scope , {}):
1125
- comp .left = self .variables_overwrite [self .scope ][comp .left .id ] # type:ignore[assignment]
1126
- if isinstance (comp .left , ast .NamedExpr ):
1127
- self .variables_overwrite [self .scope ][comp .left .target .id ] = comp .left # type:ignore[assignment]
1113
+ match comp .left :
1114
+ case ast .Name (id = name_id ) if name_id in self .variables_overwrite .get (
1115
+ self .scope , {}
1116
+ ):
1117
+ comp .left = self .variables_overwrite [self .scope ][name_id ] # type: ignore[assignment]
1118
+ case ast .NamedExpr (target = ast .Name (id = target_id )):
1119
+ self .variables_overwrite [self .scope ][target_id ] = comp .left # type: ignore[assignment]
1128
1120
left_res , left_expl = self .visit (comp .left )
1129
1121
if isinstance (comp .left , ast .Compare | ast .BoolOp ):
1130
1122
left_expl = f"({ left_expl } )"
@@ -1136,13 +1128,14 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]:
1136
1128
syms : list [ast .expr ] = []
1137
1129
results = [left_res ]
1138
1130
for i , op , next_operand in it :
1139
- if (
1140
- isinstance (next_operand , ast .NamedExpr )
1141
- and isinstance (left_res , ast .Name )
1142
- and next_operand .target .id == left_res .id
1143
- ):
1144
- next_operand .target .id = self .variable ()
1145
- self .variables_overwrite [self .scope ][left_res .id ] = next_operand # type:ignore[assignment]
1131
+ match (next_operand , left_res ):
1132
+ case (
1133
+ ast .NamedExpr (target = ast .Name (id = target_id )),
1134
+ ast .Name (id = name_id ),
1135
+ ) if target_id == name_id :
1136
+ next_operand .target .id = self .variable ()
1137
+ self .variables_overwrite [self .scope ][name_id ] = next_operand # type: ignore[assignment]
1138
+
1146
1139
next_res , next_expl = self .visit (next_operand )
1147
1140
if isinstance (next_operand , ast .Compare | ast .BoolOp ):
1148
1141
next_expl = f"({ next_expl } )"
0 commit comments