@@ -661,7 +661,9 @@ def __init__(
661
661
self .enable_assertion_pass_hook = False
662
662
self .source = source
663
663
self .scope : tuple [ast .AST , ...] = ()
664
- self .variables_overwrite : defaultdict [tuple [ast .AST , ...], Dict [str , str ]] = defaultdict (dict )
664
+ self .variables_overwrite : defaultdict [
665
+ tuple [ast .AST , ...], Dict [str , str ]
666
+ ] = defaultdict (dict )
665
667
666
668
def run (self , mod : ast .Module ) -> None :
667
669
"""Find all assert statements in *mod* and rewrite them."""
@@ -1049,16 +1051,19 @@ def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
1049
1051
new_args = []
1050
1052
new_kwargs = []
1051
1053
for arg in call .args :
1052
- if isinstance (arg , ast .Name ) and arg .id in self .variables_overwrite .get (self .scope , {}):
1053
- arg = self .variables_overwrite [self .scope ][arg .id ] # type:ignore[assignment]
1054
+ if isinstance (arg , ast .Name ) and arg .id in self .variables_overwrite .get (
1055
+ self .scope , {}
1056
+ ):
1057
+ arg = self .variables_overwrite [self .scope ][
1058
+ arg .id
1059
+ ] # type:ignore[assignment]
1054
1060
res , expl = self .visit (arg )
1055
1061
arg_expls .append (expl )
1056
1062
new_args .append (res )
1057
1063
for keyword in call .keywords :
1058
- if (
1059
- isinstance (keyword .value , ast .Name )
1060
- and keyword .value .id in self .variables_overwrite .get (self .scope , {})
1061
- ):
1064
+ if isinstance (
1065
+ keyword .value , ast .Name
1066
+ ) and keyword .value .id in self .variables_overwrite .get (self .scope , {}):
1062
1067
keyword .value = self .variables_overwrite [self .scope ][
1063
1068
keyword .value .id
1064
1069
] # type:ignore[assignment]
@@ -1095,7 +1100,9 @@ def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
1095
1100
def visit_Compare (self , comp : ast .Compare ) -> Tuple [ast .expr , str ]:
1096
1101
self .push_format_context ()
1097
1102
# We first check if we have overwritten a variable in the previous assert
1098
- if isinstance (comp .left , ast .Name ) and comp .left .id in self .variables_overwrite .get (self .scope , {}):
1103
+ if isinstance (
1104
+ comp .left , ast .Name
1105
+ ) and comp .left .id in self .variables_overwrite .get (self .scope , {}):
1099
1106
comp .left = self .variables_overwrite [self .scope ][
1100
1107
comp .left .id
1101
1108
] # type:ignore[assignment]
0 commit comments