diff --git a/mypy/solve.py b/mypy/solve.py index 52e6549e98a6..4d0ca6b7af24 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -109,6 +109,13 @@ def solve_constraints( else: candidate = AnyType(TypeOfAny.special_form) res.append(candidate) + + if not free_vars: + # Most of the validation for solutions is done in applytype.py, but here we can + # quickly test solutions w.r.t. to upper bounds, and use the latter (if possible), + # if solutions are actually not valid (due to poor inference context). + res = pre_validate_solutions(res, original_vars, constraints) + return res, free_vars @@ -487,3 +494,31 @@ def check_linear(scc: set[TypeVarId], lowers: Bounds, uppers: Bounds) -> bool: def get_vars(target: Type, vars: list[TypeVarId]) -> set[TypeVarId]: """Find type variables for which we are solving in a target type.""" return {tv.id for tv in get_all_type_vars(target)} & set(vars) + + +def pre_validate_solutions( + solutions: list[Type | None], + original_vars: Sequence[TypeVarLikeType], + constraints: list[Constraint], +) -> list[Type | None]: + """Check is each solution satisfies the upper bound of the corresponding type variable. + + If it doesn't satisfy the bound, check if bound itself satisfies all constraints, and + if yes, use it instead as a fallback solution. + """ + new_solutions: list[Type | None] = [] + for t, s in zip(original_vars, solutions): + if s is not None and not is_subtype(s, t.upper_bound): + bound_satisfies_all = True + for c in constraints: + if c.op == SUBTYPE_OF and not is_subtype(t.upper_bound, c.target): + bound_satisfies_all = False + break + if c.op == SUPERTYPE_OF and not is_subtype(c.target, t.upper_bound): + bound_satisfies_all = False + break + if bound_satisfies_all: + new_solutions.append(t.upper_bound) + continue + new_solutions.append(s) + return new_solutions diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index caa44cb40ad4..348eb8b60076 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -3542,6 +3542,14 @@ T = TypeVar("T") def type_or_callable(value: T, tp: Union[Type[T], Callable[[int], T]]) -> T: ... reveal_type(type_or_callable(A("test"), A)) # N: Revealed type is "__main__.A" +[case testUpperBoundAsInferenceFallback] +from typing import Callable, TypeVar, Any, Mapping, Optional +T = TypeVar("T", bound=Mapping[str, Any]) +def raises(opts: Optional[T]) -> T: pass +def assertRaises(cb: Callable[..., object]) -> None: pass +assertRaises(raises) # OK +[builtins fixtures/dict.pyi] + [case testJoinWithAnyFallback] from unknown import X # type: ignore[import]