Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Comparison chaining #439

Merged
merged 9 commits into from
Sep 15, 2014
5 changes: 4 additions & 1 deletion mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
TypeApplication, DictExpr, SliceExpr, FuncExpr, TempNode, SymbolTableNode,
Context, ListComprehension, ConditionalExpr, GeneratorExpr,
Decorator, SetExpr, PassStmt, TypeVarExpr, UndefinedExpr, PrintStmt,
LITERAL_TYPE, BreakStmt, ContinueStmt
LITERAL_TYPE, BreakStmt, ContinueStmt, ComparisonExpr
)
from mypy.nodes import function_type, method_type
from mypy import nodes
Expand Down Expand Up @@ -1579,6 +1579,9 @@ def visit_float_expr(self, e: FloatExpr) -> Type:
def visit_op_expr(self, e: OpExpr) -> Type:
return self.expr_checker.visit_op_expr(e)

def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
return self.expr_checker.visit_comparison_expr(e)

def visit_unary_expr(self, e: UnaryExpr) -> Type:
return self.expr_checker.visit_unary_expr(e)

Expand Down
84 changes: 59 additions & 25 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
OpExpr, UnaryExpr, IndexExpr, CastExpr, TypeApplication, ListExpr,
TupleExpr, DictExpr, FuncExpr, SuperExpr, ParenExpr, SliceExpr, Context,
ListComprehension, GeneratorExpr, SetExpr, MypyFile, Decorator,
UndefinedExpr, ConditionalExpr, TempNode, LITERAL_TYPE
UndefinedExpr, ConditionalExpr, ComparisonExpr, TempNode, LITERAL_TYPE
)
from mypy.errors import Errors
from mypy.nodes import function_type, method_type
Expand Down Expand Up @@ -745,38 +745,72 @@ def visit_op_expr(self, e: OpExpr) -> Type:
# Expressions of form [...] * e get special type inference.
return self.check_list_multiply(e)
left_type = self.accept(e.left)
right_type = self.accept(e.right) # TODO only evaluate if needed
if e.op == 'in' or e.op == 'not in':
local_errors = self.msg.copy()
result, method_type = self.check_op_local('__contains__', right_type,
e.left, e, local_errors)
if (local_errors.is_errors() and
# is_valid_var_arg is True for any Iterable
self.is_valid_var_arg(right_type)):
itertype = self.chk.analyse_iterable_item_type(e.right)
method_type = Callable([left_type], [nodes.ARG_POS], [None],
self.chk.bool_type(), False)
result = self.chk.bool_type()
if not is_subtype(left_type, itertype):
self.msg.unsupported_operand_types('in', left_type, right_type, e)
else:
self.msg.add_errors(local_errors)
e.method_type = method_type
if e.op == 'in':
return result
else:
return self.chk.bool_type()
elif e.op in nodes.op_methods:

if e.op in nodes.op_methods:
method = self.get_operator_method(e.op)
result, method_type = self.check_op(method, left_type, e.right, e,
allow_reverse=True)
e.method_type = method_type
return result
elif e.op == 'is' or e.op == 'is not':
return self.chk.bool_type()
else:
raise RuntimeError('Unknown operator {}'.format(e.op))

def visit_comparison_expr(self, e: ComparisonExpr) -> Type:
"""Type check a comparison expression.

Comparison expressions are type checked consecutive-pair-wise
That is, 'a < b > c == d' is check as 'a < b and b > c and c == d'
"""
result = None # type: mypy.types.Type

# Check each consecutive operand pair and their operator
for left, right, operator in zip(e.operands, e.operands[1:], e.operators):
left_type = self.accept(left)

method_type = None # type: mypy.types.Type

if operator == 'in' or operator == 'not in':
right_type = self.accept(right) # TODO only evaluate if needed

local_errors = self.msg.copy()
sub_result, method_type = self.check_op_local('__contains__', right_type,
left, e, local_errors)
if (local_errors.is_errors() and
# is_valid_var_arg is True for any Iterable
self.is_valid_var_arg(right_type)):
itertype = self.chk.analyse_iterable_item_type(right)
method_type = Callable([left_type], [nodes.ARG_POS], [None],
self.chk.bool_type(), False)
sub_result = self.chk.bool_type()
if not is_subtype(left_type, itertype):
self.msg.unsupported_operand_types('in', left_type, right_type, e)
else:
self.msg.add_errors(local_errors)
if operator == 'not in':
sub_result = self.chk.bool_type()
elif operator in nodes.op_methods:
method = self.get_operator_method(operator)
sub_result, method_type = self.check_op(method, left_type, right, e,
allow_reverse=True)

elif operator == 'is' or operator == 'is not':
sub_result = self.chk.bool_type()
method_type = None
else:
raise RuntimeError('Unknown comparison operator {}'.format(operator))

e.method_types.append(method_type)

# Determine type of boolean-and of result and sub_result
if result == None:
result = sub_result
else:
# TODO: check on void needed?
self.check_not_void(sub_result, e)
result = join.join_types(result, sub_result, self.chk.basic_types())

return result

def get_operator_method(self, op: str) -> str:
if op == '/' and self.chk.pyversion == 2:
# TODO also check for "from __future__ import division"
Expand Down
9 changes: 6 additions & 3 deletions mypy/noderepr.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,11 @@ def __init__(self, dot: Any, name: Any) -> None:
self.dot = dot
self.name = name

class ComparisonExprRepr:
def __init__(self, operators: List[Any]) -> None:
# List of tupples of (op, op2).
# Note: op2 may be empty; it is used for "is not" and "not in".
self.operators = operators

class CallExprRepr:
def __init__(self, lparen: Any, commas: List[Token], star: Any, star2: Any,
Expand Down Expand Up @@ -254,10 +259,8 @@ def __init__(self, op: Any) -> None:


class OpExprRepr:
def __init__(self, op: Any, op2: Any) -> None:
# Note: op2 may be empty; it is used for "is not" and "not in".
def __init__(self, op: Any) -> None:
self.op = op
self.op2 = op2


class CastExprRepr:
Expand Down
26 changes: 23 additions & 3 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,13 +1016,13 @@ def accept(self, visitor: NodeVisitor[T]) -> T:


class OpExpr(Node):
"""Binary operation (other than . or [], which have specific nodes)."""
"""Binary operation (other than . or [] or comparison operators,
which have specific nodes)."""

op = ''
left = Undefined(Node)
right = Undefined(Node)
# Inferred type for the operator method type (when relevant; None for
# 'is').
# Inferred type for the operator method type (when relevant).
method_type = None # type: mypy.types.Type

def __init__(self, op: str, left: Node, right: Node) -> None:
Expand All @@ -1036,6 +1036,26 @@ def accept(self, visitor: NodeVisitor[T]) -> T:
return visitor.visit_op_expr(self)


class ComparisonExpr(Node):
"""Comparison expression (e.g. a < b > c < d)."""

operators = Undefined(List[str])
operands = Undefined(List[Node])
# Inferred type for the operator methods (when relevant; None for 'is').
method_types = Undefined(List["mypy.types.Type"])

def __init__(self, operators: List[str], operands: List[Node]) -> None:
self.operators = operators
self.operands = operands
self.method_types = []
self.literal = min(o.literal for o in self.operands)
self.literal_hash = ( ('Comparison',) + tuple(operators) +
tuple(o.literal_hash for o in operands) )

def accept(self, visitor: NodeVisitor[T]) -> T:
return visitor.visit_comparison_expr(self)


class SliceExpr(Node):
"""Slice expression (e.g. 'x:y', 'x:', '::2' or ':').

Expand Down
9 changes: 8 additions & 1 deletion mypy/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,9 +377,16 @@ def visit_call_expr(self, o):

def visit_op_expr(self, o):
self.node(o.left)
self.tokens([o.repr.op, o.repr.op2])
self.tokens([o.repr.op])
self.node(o.right)

def visit_comparison_expr(self, o):
self.node(o.operands[0])
for ops, operand in zip(o.repr.operators, o.operands[1:]):
# ops = op, op2
self.tokens(list(ops))
self.node(operand)

def visit_cast_expr(self, o):
self.token(o.repr.lparen)
self.type(o.type)
Expand Down
58 changes: 44 additions & 14 deletions mypy/parse.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
TupleExpr, GeneratorExpr, ListComprehension, ListExpr, ConditionalExpr,
DictExpr, SetExpr, NameExpr, IntExpr, StrExpr, BytesExpr, UnicodeExpr,
FloatExpr, CallExpr, SuperExpr, MemberExpr, IndexExpr, SliceExpr, OpExpr,
UnaryExpr, FuncExpr, TypeApplication, PrintStmt, ImportBase
UnaryExpr, FuncExpr, TypeApplication, PrintStmt, ImportBase, ComparisonExpr
)
from mypy import nodes
from mypy import noderepr
Expand Down Expand Up @@ -57,6 +57,8 @@
'+=', '-=', '*=', '/=', '//=', '%=', '**=', '|=', '&=', '^=', '>>=',
'<<='])

op_comp = set([
'>', '<', '==', '>=', '<=', '<>', '!=', 'is', 'is', 'in', 'not'])

none = Token('') # Empty token

Expand Down Expand Up @@ -1107,7 +1109,10 @@ def parse_expression(self, prec: int = 0) -> Node:
# Either "not in" or an error.
op_prec = precedence['in']
if op_prec > prec:
expr = self.parse_bin_op_expr(expr, op_prec)
if op in op_comp:
expr = self.parse_comparison_expr(expr, op_prec)
else:
expr = self.parse_bin_op_expr(expr, op_prec)
else:
# The operation cannot be associated with the
# current left operand due to the precedence
Expand Down Expand Up @@ -1454,25 +1459,50 @@ def parse_index_expr(self, base: Any) -> IndexExpr:

def parse_bin_op_expr(self, left: Node, prec: int) -> OpExpr:
op = self.expect_type(Op)
op2 = none
op_str = op.string
if op_str == 'not':
if self.current_str() == 'in':
op_str = 'not in'
op2 = self.skip()
else:
self.parse_error()
elif op_str == 'is' and self.current_str() == 'not':
op_str = 'is not'
op2 = self.skip()
elif op_str == '~':
if op_str == '~':
self.ind -= 1
self.parse_error()
right = self.parse_expression(prec)
node = OpExpr(op_str, left, right)
self.set_repr(node, noderepr.OpExprRepr(op, op2))
self.set_repr(node, noderepr.OpExprRepr(op))
return node

def parse_comparison_expr(self, left: Node, prec: int) -> ComparisonExpr:
operators = [] # type: List[Tuple[Token, Token]]
operators_str = [] # type: List[str]
operands = [left]

while True:
op = self.expect_type(Op)
op2 = none
op_str = op.string
if op_str == 'not':
if self.current_str() == 'in':
op_str = 'not in'
op2 = self.skip()
else:
self.parse_error()
elif op_str == 'is' and self.current_str() == 'not':
op_str = 'is not'
op2 = self.skip()

operators_str.append(op_str)
operators.append( (op, op2) )
operand = self.parse_expression(prec)
operands.append(operand)

# Continue if next token is a comparison operator
t = self.current()
s = self.current_str()
if s not in op_comp:
break

node = ComparisonExpr(operators_str, operands)
self.set_repr(node, noderepr.ComparisonExprRepr(operators))
return node


def parse_unary_expr(self) -> UnaryExpr:
op_tok = self.skip()
op = op_tok.string
Expand Down
6 changes: 6 additions & 0 deletions mypy/pprinter.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,12 @@ def visit_op_expr(self, o):
self.string(' %s ' % o.op)
self.node(o.right)

def visit_comparison_expr(self, o):
self.node(o.operands[0])
for operator, operand in zip(o.operators, o.operands[1:]):
self.string(' %s ' % operator)
self.node(operand)

def visit_unary_expr(self, o):
self.string(o.op)
if o.op == 'not':
Expand Down
6 changes: 5 additions & 1 deletion mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
SymbolTableNode, TVAR, UNBOUND_TVAR, ListComprehension, GeneratorExpr,
FuncExpr, MDEF, FuncBase, Decorator, SetExpr, UndefinedExpr, TypeVarExpr,
StrExpr, PrintStmt, ConditionalExpr, DucktypeExpr, DisjointclassExpr,
ARG_POS, ARG_NAMED, MroError, type_aliases
ComparisonExpr, ARG_POS, ARG_NAMED, MroError, type_aliases
)
from mypy.visitor import NodeVisitor
from mypy.traverser import TraverserVisitor
Expand Down Expand Up @@ -1243,6 +1243,10 @@ def visit_op_expr(self, expr: OpExpr) -> None:
expr.left.accept(self)
expr.right.accept(self)

def visit_comparison_expr(self, expr: ComparisonExpr) -> None:
for operand in expr.operands:
operand.accept(self)

def visit_unary_expr(self, expr: UnaryExpr) -> None:
expr.expr.accept(self)

Expand Down
6 changes: 5 additions & 1 deletion mypy/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from mypy import nodes
from mypy.nodes import (
Node, FuncDef, TypeApplication, AssignmentStmt, NameExpr, CallExpr,
MemberExpr, OpExpr, IndexExpr, UnaryExpr
MemberExpr, OpExpr, ComparisonExpr, IndexExpr, UnaryExpr
)


Expand Down Expand Up @@ -126,6 +126,10 @@ def visit_op_expr(self, o: OpExpr) -> None:
self.process_node(o)
super().visit_op_expr(o)

def visit_comparison_expr(self, o: ComparisonExpr) -> None:
self.process_node(o)
super().visit_comparison_expr(o)

def visit_index_expr(self, o: IndexExpr) -> None:
self.process_node(o)
super().visit_index_expr(o)
Expand Down
3 changes: 3 additions & 0 deletions mypy/strconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,9 @@ def visit_call_expr(self, o):
def visit_op_expr(self, o):
return self.dump([o.op, o.left, o.right], o)

def visit_comparison_expr(self, o):
return self.dump([o.operators, o.operands], o)

def visit_cast_expr(self, o):
return self.dump([o.expr, o.type], o)

Expand Down
Loading