diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index 8a9972bd6b289..3eb267638ab1a 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -326,7 +326,7 @@ impl<'db> Type<'db> { #[must_use] pub fn call(&self, db: &'db dyn Db) -> Option> { match self { - Type::Function(function_type) => function_type.returns(db).or(Some(Type::Unknown)), + Type::Function(function_type) => Some(function_type.return_type(db)), // TODO annotated return type on `__new__` or metaclass `__call__` Type::Class(class) => Some(Type::Instance(*class)), @@ -374,21 +374,34 @@ impl<'db> FunctionType<'db> { self.decorators(db).contains(&decorator) } - /// annotated return type for this function, if any - pub fn returns(&self, db: &'db dyn Db) -> Option> { + /// inferred return type for this function + pub fn return_type(&self, db: &'db dyn Db) -> Type<'db> { let definition = self.definition(db); let DefinitionKind::Function(function_stmt_node) = definition.node(db) else { panic!("Function type definition must have `DefinitionKind::Function`") }; - function_stmt_node.returns.as_ref().map(|returns| { - if function_stmt_node.is_async { - // TODO: generic `types.CoroutineType`! - Type::Unknown - } else { - definition_expression_ty(db, definition, returns.as_ref()) - } - }) + // TODO if a function `bar` is decorated by `foo`, + // where `foo` is annotated as returning a type `X` that is a subtype of `Callable`, + // we need to infer the return type from `X`'s return annotation + // rather than from `bar`'s return annotation + // in order to determine the type that `bar` returns + if !function_stmt_node.decorator_list.is_empty() { + return Type::Unknown; + } + + function_stmt_node + .returns + .as_ref() + .map(|returns| { + if function_stmt_node.is_async { + // TODO: generic `types.CoroutineType`! + Type::Unknown + } else { + definition_expression_ty(db, definition, returns.as_ref()) + } + }) + .unwrap_or(Type::Unknown) } } diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 2cd8ac2ed821d..5d0abf46b73d5 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -2806,10 +2806,7 @@ mod tests { panic!("example is not a function"); }; - let returns = function - .returns(&db) - .expect("There is a return type on the function"); - + let returns = function.return_type(&db); assert_eq!(returns.display(&db).to_string(), "int"); Ok(()) @@ -2854,6 +2851,35 @@ mod tests { Ok(()) } + #[test] + fn basic_decorated_call_expression() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + " + from typing import Callable + + def foo() -> int: + return 42 + + def decorator(func) -> Callable[[], int]: + return foo + + @decorator + def bar() -> str: + return 'bar' + + x = bar() + ", + )?; + + // TODO: should be `int`! + assert_public_ty(&db, "src/a.py", "x", "Unknown"); + + Ok(()) + } + #[test] fn class_constructor_call_expression() -> anyhow::Result<()> { let mut db = setup_db();