Skip to content

Give "as" variables in with statements separate scopes #12254

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Feb 25, 2022
11 changes: 7 additions & 4 deletions mypy/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
from mypy.fscache import FileSystemCache
from mypy.metastore import MetadataStore, FilesystemMetadataStore, SqliteMetadataStore
from mypy.typestate import TypeState, reset_global_state
from mypy.renaming import VariableRenameVisitor
from mypy.renaming import VariableRenameVisitor, LimitedVariableRenameVisitor
from mypy.config_parser import parse_mypy_comments
from mypy.freetree import free_tree
from mypy.stubinfo import legacy_bundled_packages, is_legacy_bundled_package
Expand Down Expand Up @@ -2119,9 +2119,12 @@ def semantic_analysis_pass1(self) -> None:
analyzer.visit_file(self.tree, self.xpath, self.id, options)
# TODO: Do this while constructing the AST?
self.tree.names = SymbolTable()
if options.allow_redefinition:
# Perform renaming across the AST to allow variable redefinitions
self.tree.accept(VariableRenameVisitor())
if not self.tree.is_stub:
# Always perform some low-key variable renaming
self.tree.accept(LimitedVariableRenameVisitor())
if options.allow_redefinition:
# Perform more renaming across the AST to allow variable redefinitions
self.tree.accept(VariableRenameVisitor())

def add_dependency(self, dep: str) -> None:
if dep not in self.dependencies_set:
Expand Down
171 changes: 162 additions & 9 deletions mypy/renaming.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from contextlib import contextmanager
from typing import Dict, Iterator, List
from typing import Dict, Iterator, List, Set
from typing_extensions import Final

from mypy.nodes import (
Block, AssignmentStmt, NameExpr, MypyFile, FuncDef, Lvalue, ListExpr, TupleExpr,
WhileStmt, ForStmt, BreakStmt, ContinueStmt, TryStmt, WithStmt, MatchStmt, StarExpr,
ImportFrom, MemberExpr, IndexExpr, Import, ClassDef
ImportFrom, MemberExpr, IndexExpr, Import, ImportAll, ClassDef
)
from mypy.patterns import AsPattern
from mypy.traverser import TraverserVisitor
Expand Down Expand Up @@ -262,15 +262,9 @@ def flush_refs(self) -> None:
# as it will be publicly visible outside the module.
to_rename = refs[:-1]
for i, item in enumerate(to_rename):
self.rename_refs(item, i)
rename_refs(item, i)
self.refs.pop()

def rename_refs(self, names: List[NameExpr], index: int) -> None:
name = names[0].name
new_name = name + "'" * (index + 1)
for expr in names:
expr.name = new_name

# Helpers for determining which assignments define new variables

def clear(self) -> None:
Expand Down Expand Up @@ -392,3 +386,162 @@ def record_assignment(self, name: str, can_be_redefined: bool) -> bool:
else:
# Assigns to an existing variable.
return False


class LimitedVariableRenameVisitor(TraverserVisitor):
"""Perform some limited variable renaming in with statements.

This allows reusing a variable in multiple with statements with
different types. For example, the two instances of 'x' can have
incompatible types:

with C() as x:
f(x)
with D() as x:
g(x)

The above code gets renamed conceptually into this (not valid Python!):

with C() as x':
f(x')
with D() as x:
g(x)

If there's a reference to a variable defined in 'with' outside the
statement, or if there's any trickiness around variable visibility
(e.g. function definitions), we give up and won't perform renaming.

The main use case is to allow binding both readable and writable
binary files into the same variable. These have different types:

with open(fnam, 'rb') as f: ...
with open(fnam, 'wb') as f: ...
"""

def __init__(self) -> None:
# Short names of variables bound in with statements using "as"
# in a surrounding scope
self.bound_vars: List[str] = []
# Names that can't be safely renamed, per scope ('*' means that
# no names can be renamed)
self.skipped: List[Set[str]] = []
# References to variables that we may need to rename. List of
# scopes; each scope is a mapping from name to list of collections
# of names that refer to the same logical variable.
self.refs: List[Dict[str, List[List[NameExpr]]]] = []

def visit_mypy_file(self, file_node: MypyFile) -> None:
"""Rename variables within a file.

This is the main entry point to this class.
"""
with self.enter_scope():
for d in file_node.defs:
d.accept(self)

def visit_func_def(self, fdef: FuncDef) -> None:
self.reject_redefinition_of_vars_in_scope()
with self.enter_scope():
for arg in fdef.arguments:
self.record_skipped(arg.variable.name)
super().visit_func_def(fdef)

def visit_class_def(self, cdef: ClassDef) -> None:
self.reject_redefinition_of_vars_in_scope()
with self.enter_scope():
super().visit_class_def(cdef)

def visit_with_stmt(self, stmt: WithStmt) -> None:
for expr in stmt.expr:
expr.accept(self)
old_len = len(self.bound_vars)
for target in stmt.target:
if target is not None:
self.analyze_lvalue(target)
for target in stmt.target:
if target:
target.accept(self)
stmt.body.accept(self)

while len(self.bound_vars) > old_len:
self.bound_vars.pop()

def analyze_lvalue(self, lvalue: Lvalue) -> None:
if isinstance(lvalue, NameExpr):
name = lvalue.name
if name in self.bound_vars:
# Name bound in a surrounding with statement, so it can be renamed
self.visit_name_expr(lvalue)
else:
var_info = self.refs[-1]
if name not in var_info:
var_info[name] = []
var_info[name].append([])
self.bound_vars.append(name)
elif isinstance(lvalue, (ListExpr, TupleExpr)):
for item in lvalue.items:
self.analyze_lvalue(item)
elif isinstance(lvalue, MemberExpr):
lvalue.expr.accept(self)
elif isinstance(lvalue, IndexExpr):
lvalue.base.accept(self)
lvalue.index.accept(self)
elif isinstance(lvalue, StarExpr):
self.analyze_lvalue(lvalue.expr)

def visit_import(self, imp: Import) -> None:
# We don't support renaming imports
for id, as_id in imp.ids:
self.record_skipped(as_id or id)

def visit_import_from(self, imp: ImportFrom) -> None:
# We don't support renaming imports
for id, as_id in imp.names:
self.record_skipped(as_id or id)

def visit_import_all(self, imp: ImportAll) -> None:
# Give up, since we don't know all imported names yet
self.reject_redefinition_of_vars_in_scope()

def visit_name_expr(self, expr: NameExpr) -> None:
name = expr.name
if name in self.bound_vars:
# Record reference so that it can be renamed later
for scope in reversed(self.refs):
if name in scope:
scope[name][-1].append(expr)
else:
self.record_skipped(name)

@contextmanager
def enter_scope(self) -> Iterator[None]:
self.skipped.append(set())
self.refs.append({})
yield None
self.flush_refs()

def reject_redefinition_of_vars_in_scope(self) -> None:
self.record_skipped('*')

def record_skipped(self, name: str) -> None:
self.skipped[-1].add(name)

def flush_refs(self) -> None:
ref_dict = self.refs.pop()
skipped = self.skipped.pop()
if '*' not in skipped:
for name, refs in ref_dict.items():
if len(refs) <= 1 or name in skipped:
continue
# At module top level we must not rename the final definition,
# as it may be publicly visible
to_rename = refs[:-1]
for i, item in enumerate(to_rename):
rename_refs(item, i)


def rename_refs(names: List[NameExpr], index: int) -> None:
name = names[0].name
new_name = name + "'" * (index + 1)
for expr in names:
expr.name = new_name
41 changes: 41 additions & 0 deletions test-data/unit/check-python310.test
Original file line number Diff line number Diff line change
Expand Up @@ -1175,3 +1175,44 @@ def foo(value) -> int: # E: Missing return statement
return 1
case 2:
return 2

[case testWithStatementScopeAndMatchStatement]
from m import A, B

with A() as x:
pass
with B() as x: \
# E: Incompatible types in assignment (expression has type "B", variable has type "A")
pass

with A() as y:
pass
with B() as y: \
# E: Incompatible types in assignment (expression has type "B", variable has type "A")
pass

with A() as z:
pass
with B() as z: \
# E: Incompatible types in assignment (expression has type "B", variable has type "A")
pass

with A() as zz:
pass
with B() as zz: \
# E: Incompatible types in assignment (expression has type "B", variable has type "A")
pass

match x:
case str(y) as z:
zz = y

[file m.pyi]
from typing import Any

class A:
def __enter__(self) -> A: ...
def __exit__(self, x, y, z) -> None: ...
class B:
def __enter__(self) -> B: ...
def __exit__(self, x, y, z) -> None: ...
Loading