Skip to content

Commit

Permalink
assert/remove casts in semanal.py (#2341)
Browse files Browse the repository at this point in the history
  • Loading branch information
elazarg authored and gvanrossum committed Oct 27, 2016
1 parent ba85545 commit 7b6fbb7
Showing 1 changed file with 58 additions and 57 deletions.
115 changes: 58 additions & 57 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
"""

from typing import (
List, Dict, Set, Tuple, cast, Any, TypeVar, Union, Optional, Callable
List, Dict, Set, Tuple, cast, TypeVar, Union, Optional, Callable
)

from mypy.nodes import (
Expand Down Expand Up @@ -280,9 +280,7 @@ def visit_func_def(self, defn: FuncDef) -> None:
if defn.name() in self.type.names:
# Redefinition. Conditional redefinition is okay.
n = self.type.names[defn.name()].node
if self.is_conditional_func(n, defn):
defn.original_def = cast(FuncDef, n)
else:
if not self.set_original_def(n, defn):
self.name_already_defined(defn.name(), defn)
self.type.names[defn.name()] = SymbolTableNode(MDEF, defn)
self.prepare_method_signature(defn)
Expand All @@ -292,9 +290,7 @@ def visit_func_def(self, defn: FuncDef) -> None:
if defn.name() in self.locals[-1]:
# Redefinition. Conditional redefinition is okay.
n = self.locals[-1][defn.name()].node
if self.is_conditional_func(n, defn):
defn.original_def = cast(FuncDef, n)
else:
if not self.set_original_def(n, defn):
self.name_already_defined(defn.name(), defn)
else:
self.add_local(defn, defn)
Expand All @@ -304,11 +300,7 @@ def visit_func_def(self, defn: FuncDef) -> None:
symbol = self.globals.get(defn.name())
if isinstance(symbol.node, FuncDef) and symbol.node != defn:
# This is redefinition. Conditional redefinition is okay.
original_def = symbol.node
if self.is_conditional_func(original_def, defn):
# Conditional function definition -- multiple defs are ok.
defn.original_def = original_def
else:
if not self.set_original_def(symbol.node, defn):
# Report error.
self.check_no_global(defn.name(), defn, True)
if phase_info == FUNCTION_FIRST_PHASE_POSTPONE_SECOND:
Expand Down Expand Up @@ -341,19 +333,22 @@ def prepare_method_signature(self, func: FuncDef) -> None:
leading_type = self.class_type(self.type)
else:
leading_type = fill_typevars(self.type)
sig = cast(FunctionLike, func.type)
func.type = replace_implicit_first_type(sig, leading_type)
func.type = replace_implicit_first_type(functype, leading_type)

def is_conditional_func(self, previous: Node, new: FuncDef) -> bool:
"""Does 'new' conditionally redefine 'previous'?
def set_original_def(self, previous: Node, new: FuncDef) -> bool:
"""If 'new' conditionally redefine 'previous', set 'previous' as original
We reject straight redefinitions of functions, as they are usually
a programming error. For example:
. def f(): ...
. def f(): ... # Error: 'f' redefined
"""
return isinstance(previous, (FuncDef, Var)) and new.is_conditional
if isinstance(previous, (FuncDef, Var)) and new.is_conditional:
new.original_def = previous
return True
else:
return False

def update_function_type_variables(self, defn: FuncDef) -> None:
"""Make any type variables in the signature of defn explicit.
Expand All @@ -362,8 +357,8 @@ def update_function_type_variables(self, defn: FuncDef) -> None:
if defn is generic.
"""
if defn.type:
functype = cast(CallableType, defn.type)
typevars = self.infer_type_variables(functype)
assert isinstance(defn.type, CallableType)
typevars = self.infer_type_variables(defn.type)
# Do not define a new type variable if already defined in scope.
typevars = [(name, tvar) for name, tvar in typevars
if not self.is_defined_type_var(name, defn)]
Expand All @@ -373,7 +368,7 @@ def update_function_type_variables(self, defn: FuncDef) -> None:
tvar[1].values, tvar[1].upper_bound,
tvar[1].variance)
for i, tvar in enumerate(typevars)]
functype.variables = defs
defn.type.variables = defs

def infer_type_variables(self,
type: CallableType) -> List[Tuple[str, TypeVarExpr]]:
Expand All @@ -387,8 +382,7 @@ def infer_type_variables(self,
tvars.append(tvar_expr)
return list(zip(names, tvars))

def find_type_variables_in_type(
self, type: Type) -> List[Tuple[str, TypeVarExpr]]:
def find_type_variables_in_type(self, type: Type) -> List[Tuple[str, TypeVarExpr]]:
"""Return a list of all unique type variable references in type.
This effectively does partial name binding, results of which are mostly thrown away.
Expand All @@ -398,7 +392,8 @@ def find_type_variables_in_type(
name = type.name
node = self.lookup_qualified(name, type)
if node and node.kind == UNBOUND_TVAR:
result.append((name, cast(TypeVarExpr, node.node)))
assert isinstance(node.node, TypeVarExpr)
result.append((name, node.node))
for arg in type.args:
result.extend(self.find_type_variables_in_type(arg))
elif isinstance(type, TypeList):
Expand All @@ -425,8 +420,9 @@ def visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
item.is_overload = True
item.func.is_overload = True
item.accept(self)
t.append(cast(CallableType, function_type(item.func,
self.builtin_type('builtins.function'))))
callable = function_type(item.func, self.builtin_type('builtins.function'))
assert isinstance(callable, CallableType)
t.append(callable)
if item.func.is_property and i == 0:
# This defines a property, probably with a setter and/or deleter.
self.analyze_property_with_multi_part_definition(defn)
Expand Down Expand Up @@ -524,8 +520,9 @@ def add_func_type_variables_to_symbol_table(
nodes = [] # type: List[SymbolTableNode]
if defn.type:
tt = defn.type
assert isinstance(tt, CallableType)
items = tt.variables
names = self.type_var_names()
items = cast(CallableType, tt).variables
for item in items:
name = item.name
if name in names:
Expand All @@ -549,7 +546,8 @@ def bind_type_var(self, fullname: str, tvar_def: TypeVarDef,
return node

def check_function_signature(self, fdef: FuncItem) -> None:
sig = cast(CallableType, fdef.type)
sig = fdef.type
assert isinstance(sig, CallableType)
if len(sig.arg_types) < len(fdef.arguments):
self.fail('Type signature has too few arguments', fdef)
# Add dummy Any arguments to prevent crashes later.
Expand Down Expand Up @@ -725,7 +723,8 @@ def analyze_unbound_tvar(self, t: Type) -> Tuple[str, TypeVarExpr]:
unbound = t
sym = self.lookup_qualified(unbound.name, unbound)
if sym is not None and sym.kind == UNBOUND_TVAR:
return unbound.name, cast(TypeVarExpr, sym.node)
assert isinstance(sym.node, TypeVarExpr)
return unbound.name, sym.node
return None

def analyze_namedtuple_classdef(self, defn: ClassDef) -> bool:
Expand Down Expand Up @@ -922,13 +921,15 @@ def class_type(self, info: TypeInfo) -> Type:

def named_type(self, qualified_name: str, args: List[Type] = None) -> Instance:
sym = self.lookup_qualified(qualified_name, None)
return Instance(cast(TypeInfo, sym.node), args or [])
assert isinstance(sym.node, TypeInfo)
return Instance(sym.node, args or [])

def named_type_or_none(self, qualified_name: str, args: List[Type] = None) -> Instance:
sym = self.lookup_fully_qualified_or_none(qualified_name)
if not sym:
return None
return Instance(cast(TypeInfo, sym.node), args or [])
assert isinstance(sym.node, TypeInfo)
return Instance(sym.node, args or [])

def bind_class_type_variables_in_symbol_table(
self, info: TypeInfo) -> List[SymbolTableNode]:
Expand Down Expand Up @@ -1300,11 +1301,10 @@ def analyze_lvalue(self, lval: Lvalue, nested: bool = False,
lval.accept(self)
elif (isinstance(lval, TupleExpr) or
isinstance(lval, ListExpr)):
items = cast(Any, lval).items
items = lval.items
if len(items) == 0 and isinstance(lval, TupleExpr):
self.fail("Can't assign to ()", lval)
self.analyze_tuple_or_list_lvalue(cast(Union[ListExpr, TupleExpr], lval),
add_global, explicit_type)
self.analyze_tuple_or_list_lvalue(lval, add_global, explicit_type)
elif isinstance(lval, StarExpr):
if nested:
self.analyze_lvalue(lval.expr, nested, add_global, explicit_type)
Expand All @@ -1318,9 +1318,7 @@ def analyze_tuple_or_list_lvalue(self, lval: Union[ListExpr, TupleExpr],
explicit_type: bool = False) -> None:
"""Analyze an lvalue or assignment target that is a list or tuple."""
items = lval.items
star_exprs = [cast(StarExpr, item)
for item in items
if isinstance(item, StarExpr)]
star_exprs = [item for item in items if isinstance(item, StarExpr)]

if len(star_exprs) > 1:
self.fail('Two starred expressions in assignment', lval)
Expand Down Expand Up @@ -1452,14 +1450,14 @@ def check_newtype_args(self, name: str, call: CallExpr, context: Context) -> Opt
if not isinstance(args[0], (StrExpr, BytesExpr, UnicodeExpr)):
self.fail("Argument 1 to NewType(...) must be a string literal", context)
has_failed = True
elif cast(StrExpr, call.args[0]).value != name:
elif args[0].value != name:
msg = "String argument 1 '{}' to NewType(...) does not match variable name '{}'"
self.fail(msg.format(cast(StrExpr, call.args[0]).value, name), context)
self.fail(msg.format(args[0].value, name), context)
has_failed = True

# Check second argument
try:
unanalyzed_type = expr_to_unanalyzed_type(call.args[1])
unanalyzed_type = expr_to_unanalyzed_type(args[1])
except TypeTranslationError:
self.fail("Argument 2 to NewType(...) must be a valid type", context)
return None
Expand Down Expand Up @@ -1497,7 +1495,8 @@ def process_typevar_declaration(self, s: AssignmentStmt) -> None:
if not call:
return

lvalue = cast(NameExpr, s.lvalues[0])
lvalue = s.lvalues[0]
assert isinstance(lvalue, NameExpr)
name = lvalue.name
if not lvalue.is_def:
if s.type:
Expand Down Expand Up @@ -1538,9 +1537,9 @@ def check_typevar_name(self, call: CallExpr, name: str, context: Context) -> boo
or not call.arg_kinds[0] == ARG_POS):
self.fail("TypeVar() expects a string literal as first argument", context)
return False
if cast(StrExpr, call.args[0]).value != name:
elif call.args[0].value != name:
msg = "String argument 1 '{}' to TypeVar(...) does not match variable name '{}'"
self.fail(msg.format(cast(StrExpr, call.args[0]).value, name), context)
self.fail(msg.format(call.args[0].value, name), context)
return False
return True

Expand Down Expand Up @@ -2308,7 +2307,8 @@ def visit_member_expr(self, expr: MemberExpr) -> None:
# This branch handles the case foo.bar where foo is a module.
# In this case base.node is the module's MypyFile and we look up
# bar in its namespace. This must be done for all types of bar.
file = cast(MypyFile, base.node)
file = base.node
assert isinstance(file, MypyFile)
n = file.names.get(expr.name, None) if file is not None else None
if n:
n = self.normalize_type_alias(n, expr)
Expand Down Expand Up @@ -2513,7 +2513,8 @@ def lookup(self, name: str, ctx: Context) -> SymbolTableNode:
# 5. Builtins
b = self.globals.get('__builtins__', None)
if b:
table = cast(MypyFile, b.node).names
assert isinstance(b.node, MypyFile)
table = b.node.names
if name in table:
if name[0] == "_" and name[1] != "_":
self.name_not_defined(name, ctx)
Expand Down Expand Up @@ -2568,8 +2569,8 @@ def lookup_qualified(self, name: str, ctx: Context) -> SymbolTableNode:

def builtin_type(self, fully_qualified_name: str) -> Instance:
node = self.lookup_fully_qualified(fully_qualified_name)
info = cast(TypeInfo, node.node)
return Instance(info, [])
assert isinstance(node.node, TypeInfo)
return Instance(node.node, [])

def lookup_fully_qualified(self, name: str) -> SymbolTableNode:
"""Lookup a fully qualified name.
Expand All @@ -2581,10 +2582,12 @@ def lookup_fully_qualified(self, name: str) -> SymbolTableNode:
parts = name.split('.')
n = self.modules[parts[0]]
for i in range(1, len(parts) - 1):
n = cast(MypyFile, n.names[parts[i]].node)
return n.names[parts[-1]]
next_sym = n.names[parts[i]]
assert isinstance(next_sym.node, MypyFile)
n = next_sym.node
return n.names.get(parts[-1])

def lookup_fully_qualified_or_none(self, name: str) -> SymbolTableNode:
def lookup_fully_qualified_or_none(self, name: str) -> Optional[SymbolTableNode]:
"""Lookup a fully qualified name.
Assume that the name is defined. This happens in the global namespace -- the local
Expand All @@ -2597,7 +2600,8 @@ def lookup_fully_qualified_or_none(self, name: str) -> SymbolTableNode:
next_sym = n.names.get(parts[i])
if not next_sym:
return None
n = cast(MypyFile, next_sym.node)
assert isinstance(next_sym.node, MypyFile)
n = next_sym.node
return n.names.get(parts[-1])

def qualified_name(self, n: str) -> str:
Expand Down Expand Up @@ -2811,11 +2815,7 @@ def visit_func_def(self, func: FuncDef) -> None:
# Ah this is an imported name. We can't resolve them now, so we'll postpone
# this until the main phase of semantic analysis.
return
original_def = original_sym.node
if sem.is_conditional_func(original_def, func):
# Conditional function definition -- multiple defs are ok.
func.original_def = cast(FuncDef, original_def)
else:
if not sem.set_original_def(original_sym.node, func):
# Report error.
sem.check_no_global(func.name(), func)
else:
Expand Down Expand Up @@ -3055,10 +3055,11 @@ def fill_typevars(typ: TypeInfo) -> Union[Instance, TupleType]:
def replace_implicit_first_type(sig: FunctionLike, new: Type) -> FunctionLike:
if isinstance(sig, CallableType):
return sig.copy_modified(arg_types=[new] + sig.arg_types[1:])
else:
sig = cast(Overloaded, sig)
elif isinstance(sig, Overloaded):
return Overloaded([cast(CallableType, replace_implicit_first_type(i, new))
for i in sig.items()])
else:
assert False


def set_callable_name(sig: Type, fdef: FuncDef) -> Type:
Expand Down

0 comments on commit 7b6fbb7

Please sign in to comment.