Skip to content

Commit

Permalink
[red-knot] Fix call expression inference edge case for decorated func…
Browse files Browse the repository at this point in the history
…tions (#13191)
  • Loading branch information
AlexWaygood authored Sep 1, 2024
1 parent 5661353 commit 2014cba
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 15 deletions.
35 changes: 24 additions & 11 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ impl<'db> Type<'db> {
#[must_use]
pub fn call(&self, db: &'db dyn Db) -> Option<Type<'db>> {
match self {
Type::Function(function_type) => function_type.returns(db).or(Some(Type::Unknown)),
Type::Function(function_type) => Some(function_type.return_type(db)),

// TODO annotated return type on `__new__` or metaclass `__call__`
Type::Class(class) => Some(Type::Instance(*class)),
Expand Down Expand Up @@ -374,21 +374,34 @@ impl<'db> FunctionType<'db> {
self.decorators(db).contains(&decorator)
}

/// annotated return type for this function, if any
pub fn returns(&self, db: &'db dyn Db) -> Option<Type<'db>> {
/// inferred return type for this function
pub fn return_type(&self, db: &'db dyn Db) -> Type<'db> {
let definition = self.definition(db);
let DefinitionKind::Function(function_stmt_node) = definition.node(db) else {
panic!("Function type definition must have `DefinitionKind::Function`")
};

function_stmt_node.returns.as_ref().map(|returns| {
if function_stmt_node.is_async {
// TODO: generic `types.CoroutineType`!
Type::Unknown
} else {
definition_expression_ty(db, definition, returns.as_ref())
}
})
// TODO if a function `bar` is decorated by `foo`,
// where `foo` is annotated as returning a type `X` that is a subtype of `Callable`,
// we need to infer the return type from `X`'s return annotation
// rather than from `bar`'s return annotation
// in order to determine the type that `bar` returns
if !function_stmt_node.decorator_list.is_empty() {
return Type::Unknown;
}

function_stmt_node
.returns
.as_ref()
.map(|returns| {
if function_stmt_node.is_async {
// TODO: generic `types.CoroutineType`!
Type::Unknown
} else {
definition_expression_ty(db, definition, returns.as_ref())
}
})
.unwrap_or(Type::Unknown)
}
}

Expand Down
34 changes: 30 additions & 4 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2806,10 +2806,7 @@ mod tests {
panic!("example is not a function");
};

let returns = function
.returns(&db)
.expect("There is a return type on the function");

let returns = function.return_type(&db);
assert_eq!(returns.display(&db).to_string(), "int");

Ok(())
Expand Down Expand Up @@ -2854,6 +2851,35 @@ mod tests {
Ok(())
}

#[test]
fn basic_decorated_call_expression() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
"
from typing import Callable
def foo() -> int:
return 42
def decorator(func) -> Callable[[], int]:
return foo
@decorator
def bar() -> str:
return 'bar'
x = bar()
",
)?;

// TODO: should be `int`!
assert_public_ty(&db, "src/a.py", "x", "Unknown");

Ok(())
}

#[test]
fn class_constructor_call_expression() -> anyhow::Result<()> {
let mut db = setup_db();
Expand Down

0 comments on commit 2014cba

Please sign in to comment.