Skip to content
12 changes: 6 additions & 6 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,8 @@ def check_call(self, callee: Type, args: List[Node],
callee)
elif isinstance(callee, Instance):
call_function = analyze_member_access('__call__', callee, context,
False, False, self.named_type, self.not_ready_callback,
self.msg)
False, False, False, self.named_type,
self.not_ready_callback, self.msg)
return self.check_call(call_function, args, arg_kinds, context, arg_names,
callable_node, arg_messages)
elif isinstance(callee, TypeVarType):
Expand Down Expand Up @@ -861,7 +861,7 @@ def analyze_ordinary_member_access(self, e: MemberExpr,
else:
# This is a reference to a non-module attribute.
return analyze_member_access(e.name, self.accept(e.expr), e,
is_lvalue, False,
is_lvalue, False, False,
self.named_type, self.not_ready_callback, self.msg)

def analyze_external_member_access(self, member: str, base_type: Type,
Expand All @@ -870,7 +870,7 @@ def analyze_external_member_access(self, member: str, base_type: Type,
refer to private definitions. Return the result type.
"""
# TODO remove; no private definitions in mypy
return analyze_member_access(member, base_type, context, False, False,
return analyze_member_access(member, base_type, context, False, False, False,
self.named_type, self.not_ready_callback, self.msg)

def visit_int_expr(self, e: IntExpr) -> Type:
Expand Down Expand Up @@ -1008,7 +1008,7 @@ def check_op_local(self, method: str, base_type: Type, arg: Node,

Return tuple (result type, inferred operator method type).
"""
method_type = analyze_member_access(method, base_type, context, False, False,
method_type = analyze_member_access(method, base_type, context, False, False, True,
self.named_type, self.not_ready_callback, local_errors)
return self.check_call(method_type, [arg], [nodes.ARG_POS],
context, arg_messages=local_errors)
Expand Down Expand Up @@ -1434,7 +1434,7 @@ def analyze_super(self, e: SuperExpr, is_lvalue: bool) -> Type:
if not self.chk.typing_mode_full():
return AnyType()
return analyze_member_access(e.name, self_type(e.info), e,
is_lvalue, True,
is_lvalue, True, False,
self.named_type, self.not_ready_callback,
self.msg, base)
else:
Expand Down
47 changes: 32 additions & 15 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
DeletedType, NoneTyp, TypeType
)
from mypy.nodes import TypeInfo, FuncBase, Var, FuncDef, SymbolNode, Context
from mypy.nodes import ARG_POS, ARG_STAR, ARG_STAR2, function_type, Decorator, OverloadedFuncDef
from mypy.nodes import ARG_POS, ARG_STAR, ARG_STAR2, OpExpr, ComparisonExpr
from mypy.nodes import function_type, Decorator, OverloadedFuncDef
from mypy.messages import MessageBuilder
from mypy.maptype import map_instance_to_supertype
from mypy.expandtype import expand_type_by_instance
Expand All @@ -23,6 +24,7 @@ def analyze_member_access(name: str,
node: Context,
is_lvalue: bool,
is_super: bool,
is_operator: bool,
builtin_type: Callable[[str], Instance],
not_ready_callback: Callable[[str, Context], None],
msg: MessageBuilder,
Expand Down Expand Up @@ -79,45 +81,59 @@ def analyze_member_access(name: str,
elif isinstance(typ, NoneTyp):
# The only attribute NoneType has are those it inherits from object
return analyze_member_access(name, builtin_type('builtins.object'), node, is_lvalue,
is_super, builtin_type, not_ready_callback, msg,
is_super, is_operator, builtin_type, not_ready_callback, msg,
report_type=report_type)
elif isinstance(typ, UnionType):
# The base object has dynamic type.
msg.disable_type_names += 1
results = [analyze_member_access(name, subtype, node, is_lvalue,
is_super, builtin_type, not_ready_callback, msg)
results = [analyze_member_access(name, subtype, node, is_lvalue, is_super,
is_operator, builtin_type, not_ready_callback, msg)
for subtype in typ.items]
msg.disable_type_names -= 1
return UnionType.make_simplified_union(results)
elif isinstance(typ, TupleType):
# Actually look up from the fallback instance type.
return analyze_member_access(name, typ.fallback, node, is_lvalue,
is_super, builtin_type, not_ready_callback, msg)
return analyze_member_access(name, typ.fallback, node, is_lvalue, is_super,
is_operator, builtin_type, not_ready_callback, msg)
elif isinstance(typ, FunctionLike) and typ.is_type_obj():
# Class attribute.
# TODO super?
ret_type = typ.items()[0].ret_type
if isinstance(ret_type, TupleType):
ret_type = ret_type.fallback
if isinstance(ret_type, Instance):
result = analyze_class_attribute_access(ret_type, name, node, is_lvalue,
builtin_type, not_ready_callback, msg)
if result:
return result
if not is_operator:
# When Python sees an operator (eg `3 == 4`), it automatically translates that
# into something like `int.__eq__(3, 4)` instead of `(3).__eq__(4)` as an
# optimation.
#
# While it normally it doesn't matter which of the two versions are used, it
# does cause inconsistencies when working with classes. For example, translating
# `int == int` to `int.__eq__(int)` would not work since `int.__eq__` is meant to
# compare two int _instances_. What we really want is `type(int).__eq__`, which
# is meant to compare two types or classes.
#
# This check makes sure that when we encounter an operator, we skip looking up
# the corresponding method in the current instance to avoid this edge case.
# See https://github.com/python/mypy/pull/1787 for more info.
result = analyze_class_attribute_access(ret_type, name, node, is_lvalue,
builtin_type, not_ready_callback, msg)
if result:
return result
# Look up from the 'type' type.
return analyze_member_access(name, typ.fallback, node, is_lvalue, is_super,
builtin_type, not_ready_callback, msg,
is_operator, builtin_type, not_ready_callback, msg,
report_type=report_type)
else:
assert False, 'Unexpected type {}'.format(repr(ret_type))
elif isinstance(typ, FunctionLike):
# Look up from the 'function' type.
return analyze_member_access(name, typ.fallback, node, is_lvalue, is_super,
builtin_type, not_ready_callback, msg,
is_operator, builtin_type, not_ready_callback, msg,
report_type=report_type)
elif isinstance(typ, TypeVarType):
return analyze_member_access(name, typ.upper_bound, node, is_lvalue, is_super,
builtin_type, not_ready_callback, msg,
is_operator, builtin_type, not_ready_callback, msg,
report_type=report_type)
elif isinstance(typ, DeletedType):
msg.deleted_as_rvalue(typ, node)
Expand All @@ -130,14 +146,15 @@ def analyze_member_access(name: str,
elif isinstance(typ.item, TypeVarType):
if isinstance(typ.item.upper_bound, Instance):
item = typ.item.upper_bound
if item:
if item and not is_operator:
# See comment above for why operators are skipped
result = analyze_class_attribute_access(item, name, node, is_lvalue,
builtin_type, not_ready_callback, msg)
if result:
return result
fallback = builtin_type('builtins.type')
return analyze_member_access(name, fallback, node, is_lvalue, is_super,
builtin_type, not_ready_callback, msg,
is_operator, builtin_type, not_ready_callback, msg,
report_type=report_type)
return msg.has_no_attr(report_type, name, node)

Expand Down
41 changes: 41 additions & 0 deletions test-data/unit/check-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -2006,3 +2006,44 @@ reveal_type(User) # E: Revealed type is 'builtins.type'
[builtins fixtures/args.py]
[out]

[case testTypeTypeComparisonWorks]
class User: pass

User == User
User == type(User())
type(User()) == User
type(User()) == type(User())

User != User
User != type(User())
type(User()) != User
type(User()) != type(User())

int == int
int == type(3)
type(3) == int
type(3) == type(3)

int != int
int != type(3)
type(3) != int
type(3) != type(3)

User is User
User is type(User)
type(User) is User
type(User) is type(User)

int is int
int is type(3)
type(3) is int
type(3) is type(3)

int.__eq__(int)
int.__eq__(3, 4)
[builtins fixtures/args.py]
[out]
main:33: error: Too few arguments for "__eq__" of "int"
main:33: error: Unsupported operand types for == ("int" and "int")


6 changes: 5 additions & 1 deletion test-data/unit/fixtures/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,21 @@

class object:
def __init__(self) -> None: pass
def __eq__(self, o: object) -> bool: pass
def __ne__(self, o: object) -> bool: pass

class type:
@overload
def __init__(self, o: object) -> None: pass
@overload
def __init__(self, name: str, bases: Tuple[type, ...], dict: Dict[str, Any]) -> None: pass
def __call__(self, *args: Any, **kwargs: Any) -> Any: pass

class tuple(Iterable[Tco], Generic[Tco]): pass
class dict(Generic[T, S]): pass

class int: pass
class int:
def __eq__(self, o: object) -> bool: pass
class str: pass
class bool: pass
class function: pass