diff --git a/crates/red_knot_python_semantic/resources/mdtest/narrow/isinstance.md b/crates/red_knot_python_semantic/resources/mdtest/narrow/isinstance.md new file mode 100644 index 0000000000000..7eb5298d908f7 --- /dev/null +++ b/crates/red_knot_python_semantic/resources/mdtest/narrow/isinstance.md @@ -0,0 +1,155 @@ +# Narrowing for `isinstance` checks + +Narrowing for `isinstance(object, classinfo)` expressions. + +## `classinfo` is a single type + +```py +x = 1 if flag else "a" + +if isinstance(x, int): + reveal_type(x) # revealed: Literal[1] + +if isinstance(x, str): + reveal_type(x) # revealed: Literal["a"] + if isinstance(x, int): + reveal_type(x) # revealed: Never + +if isinstance(x, (int, object)): + reveal_type(x) # revealed: Literal[1] | Literal["a"] +``` + +## `classinfo` is a tuple of types + +Note: `isinstance(x, (int, str))` should not be confused with +`isinstance(x, tuple[(int, str)])`. The former is equivalent to +`isinstance(x, int | str)`: + +```py +x = 1 if flag else "a" + +if isinstance(x, (int, str)): + reveal_type(x) # revealed: Literal[1] | Literal["a"] + +if isinstance(x, (int, bytes)): + reveal_type(x) # revealed: Literal[1] + +if isinstance(x, (bytes, str)): + reveal_type(x) # revealed: Literal["a"] + +# No narrowing should occur if a larger type is also +# one of the possibilities: +if isinstance(x, (int, object)): + reveal_type(x) # revealed: Literal[1] | Literal["a"] + +y = 1 if flag1 else "a" if flag2 else b"b" +if isinstance(y, (int, str)): + reveal_type(y) # revealed: Literal[1] | Literal["a"] + +if isinstance(y, (int, bytes)): + reveal_type(y) # revealed: Literal[1] | Literal[b"b"] + +if isinstance(y, (str, bytes)): + reveal_type(y) # revealed: Literal["a"] | Literal[b"b"] +``` + +## `classinfo` is a nested tuple of types + +```py +x = 1 if flag else "a" + +if isinstance(x, (bool, (bytes, int))): + reveal_type(x) # revealed: Literal[1] +``` + +## Class types + +```py +class A: ... + + +class B: ... + + +def get_object() -> object: ... + + +x = get_object() + +if isinstance(x, A): + reveal_type(x) # revealed: A + if isinstance(x, B): + reveal_type(x) # revealed: A & B +``` + +## No narrowing for instances of `builtins.type` + +```py +t = type("t", (), {}) + +# This isn't testing what we want it to test if we infer anything more precise here: +reveal_type(t) # revealed: type +x = 1 if flag else "foo" + +if isinstance(x, t): + reveal_type(x) # revealed: Literal[1] | Literal["foo"] +``` + +## Do not use custom `isinstance` for narrowing + +```py +def isinstance(x, t): + return True + + +x = 1 if flag else "a" +if isinstance(x, int): + reveal_type(x) # revealed: Literal[1] | Literal["a"] +``` + +## Do support narrowing if `isinstance` is aliased + +```py +isinstance_alias = isinstance + +x = 1 if flag else "a" +if isinstance_alias(x, int): + reveal_type(x) # revealed: Literal[1] +``` + +## Do support narrowing if `isinstance` is imported + +```py +from builtins import isinstance as imported_isinstance + +x = 1 if flag else "a" +if imported_isinstance(x, int): + reveal_type(x) # revealed: Literal[1] +``` + +## Do not narrow if second argument is not a type + +```py +x = 1 if flag else "a" + +# TODO: this should cause us to emit a diagnostic during +# type checking +if isinstance(x, "a"): + reveal_type(x) # revealed: Literal[1] | Literal["a"] + +# TODO: this should cause us to emit a diagnostic during +# type checking +if isinstance(x, "int"): + reveal_type(x) # revealed: Literal[1] | Literal["a"] +``` + +## Do not narrow if there are keyword arguments + +```py +x = 1 if flag else "a" + +# TODO: this should cause us to emit a diagnostic +# (`isinstance` has no `foo` parameter) +if isinstance(x, int, foo="bar"): + reveal_type(x) # revealed: Literal[1] | Literal["a"] +``` diff --git a/crates/red_knot_python_semantic/src/semantic_index/definition.rs b/crates/red_knot_python_semantic/src/semantic_index/definition.rs index 20a6647b3c7a7..3b1b284ad7003 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/definition.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/definition.rs @@ -47,7 +47,13 @@ impl<'db> Definition<'db> { self.kind(db).category().is_binding() } - /// Return true if this is a symbol was defined in the `typing` or `typing_extensions` modules + pub(crate) fn is_builtin_definition(self, db: &'db dyn Db) -> bool { + file_to_module(db, self.file(db)).is_some_and(|module| { + module.search_path().is_standard_library() && matches!(&**module.name(), "builtins") + }) + } + + /// Return true if this symbol was defined in the `typing` or `typing_extensions` modules pub(crate) fn is_typing_definition(self, db: &'db dyn Db) -> bool { file_to_module(db, self.file(db)).is_some_and(|module| { module.search_path().is_standard_library() diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 57fa1c8071fc8..b953f2865119f 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -868,13 +868,16 @@ impl<'db> Type<'db> { fn call(self, db: &'db dyn Db, arg_types: &[Type<'db>]) -> CallOutcome<'db> { match self { // TODO validate typed call arguments vs callable signature - Type::FunctionLiteral(function_type) => match function_type.known(db) { - None => CallOutcome::callable(function_type.return_type(db)), - Some(KnownFunction::RevealType) => CallOutcome::revealed( - function_type.return_type(db), - *arg_types.first().unwrap_or(&Type::Unknown), - ), - }, + Type::FunctionLiteral(function_type) => { + if function_type.is_known(db, KnownFunction::RevealType) { + CallOutcome::revealed( + function_type.return_type(db), + *arg_types.first().unwrap_or(&Type::Unknown), + ) + } else { + CallOutcome::callable(function_type.return_type(db)) + } + } // TODO annotated return type on `__new__` or metaclass `__call__` Type::ClassLiteral(class) => { @@ -1595,6 +1598,10 @@ impl<'db> FunctionType<'db> { }) .unwrap_or(Type::Unknown) } + + pub fn is_known(self, db: &'db dyn Db, known_function: KnownFunction) -> bool { + self.known(db) == Some(known_function) + } } /// Non-exhaustive enumeration of known functions (e.g. `builtins.reveal_type`, ...) that might @@ -1603,6 +1610,8 @@ impl<'db> FunctionType<'db> { pub enum KnownFunction { /// `builtins.reveal_type`, `typing.reveal_type` or `typing_extensions.reveal_type` RevealType, + /// `builtins.isinstance` + IsInstance, } #[salsa::interned] diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 9d191a47eba3d..4f4e14aa4bea2 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -779,6 +779,9 @@ impl<'db> TypeInferenceBuilder<'db> { "reveal_type" if definition.is_typing_definition(self.db) => { Some(KnownFunction::RevealType) } + "isinstance" if definition.is_builtin_definition(self.db) => { + Some(KnownFunction::IsInstance) + } _ => None, }; let function_ty = Type::FunctionLiteral(FunctionType::new( diff --git a/crates/red_knot_python_semantic/src/types/narrow.rs b/crates/red_knot_python_semantic/src/types/narrow.rs index d2c60a4ebc686..93d9cb43bfa92 100644 --- a/crates/red_knot_python_semantic/src/types/narrow.rs +++ b/crates/red_knot_python_semantic/src/types/narrow.rs @@ -4,7 +4,9 @@ use crate::semantic_index::definition::Definition; use crate::semantic_index::expression::Expression; use crate::semantic_index::symbol::{ScopeId, ScopedSymbolId, SymbolTable}; use crate::semantic_index::symbol_table; -use crate::types::{infer_expression_types, IntersectionBuilder, Type}; +use crate::types::{ + infer_expression_types, IntersectionBuilder, KnownFunction, Type, UnionBuilder, +}; use crate::Db; use itertools::Itertools; use ruff_python_ast as ast; @@ -60,6 +62,28 @@ fn all_narrowing_constraints_for_expression<'db>( NarrowingConstraintsBuilder::new(db, Constraint::Expression(expression)).finish() } +/// Generate a constraint from the *type* of the second argument of an `isinstance` call. +/// +/// Example: for `isinstance(…, str)`, we would infer `Type::ClassLiteral(str)` from the +/// second argument, but we need to generate a `Type::Instance(str)` constraint that can +/// be used to narrow down the type of the first argument. +fn generate_isinstance_constraint<'db>( + db: &'db dyn Db, + classinfo: &Type<'db>, +) -> Option> { + match classinfo { + Type::ClassLiteral(class) => Some(Type::Instance(*class)), + Type::Tuple(tuple) => { + let mut builder = UnionBuilder::new(db); + for element in tuple.elements(db) { + builder = builder.add(generate_isinstance_constraint(db, element)?); + } + Some(builder.build()) + } + _ => None, + } +} + type NarrowingConstraints<'db> = FxHashMap>; struct NarrowingConstraintsBuilder<'db> { @@ -88,10 +112,15 @@ impl<'db> NarrowingConstraintsBuilder<'db> { } fn evaluate_expression_constraint(&mut self, expression: Expression<'db>) { - if let ast::Expr::Compare(expr_compare) = expression.node_ref(self.db).node() { - self.add_expr_compare(expr_compare, expression); + match expression.node_ref(self.db).node() { + ast::Expr::Compare(expr_compare) => { + self.add_expr_compare(expr_compare, expression); + } + ast::Expr::Call(expr_call) => { + self.add_expr_call(expr_call, expression); + } + _ => {} // TODO other test expression kinds } - // TODO other test expression kinds } fn evaluate_pattern_constraint(&mut self, pattern: PatternConstraint<'db>) { @@ -194,6 +223,33 @@ impl<'db> NarrowingConstraintsBuilder<'db> { } } + fn add_expr_call(&mut self, expr_call: &ast::ExprCall, expression: Expression<'db>) { + let scope = self.scope(); + let inference = infer_expression_types(self.db, expression); + + if let Some(func_type) = inference + .expression_ty(expr_call.func.scoped_ast_id(self.db, scope)) + .into_function_literal_type() + { + if func_type.is_known(self.db, KnownFunction::IsInstance) + && expr_call.arguments.keywords.is_empty() + { + if let [ast::Expr::Name(ast::ExprName { id, .. }), rhs] = &*expr_call.arguments.args + { + let symbol = self.symbols().symbol_id_by_name(id).unwrap(); + + let rhs_type = inference.expression_ty(rhs.scoped_ast_id(self.db, scope)); + + // TODO: add support for PEP 604 union types on the right hand side: + // isinstance(x, str | (int | float)) + if let Some(constraint) = generate_isinstance_constraint(self.db, &rhs_type) { + self.constraints.insert(symbol, constraint); + } + } + } + } + } + fn add_match_pattern_singleton( &mut self, subject: &ast::Expr,