Skip to content

Commit 35f4a75

Browse files
committed
[red-knot] infer function's return type
1 parent fe4051b commit 35f4a75

File tree

4 files changed

+117
-14
lines changed

4 files changed

+117
-14
lines changed

crates/ty_python_semantic/resources/mdtest/function/return_type.md

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,43 @@ def f(cond: bool) -> int:
182182
return 2
183183
```
184184

185+
## Inferred return type
186+
187+
If a function's return type is not annotated, it is inferred. The inferred type is the union of all
188+
possible return types.
189+
190+
```py
191+
def f():
192+
return 1
193+
194+
reveal_type(f()) # revealed: Literal[1]
195+
196+
def g(cond: bool):
197+
if cond:
198+
return 1
199+
else:
200+
return "a"
201+
202+
reveal_type(g(True)) # revealed: Literal[1, "a"]
203+
204+
# This function implicitly returns `None`.
205+
def h(x: int, y: str):
206+
if x > 10:
207+
return x
208+
elif x > 5:
209+
return y
210+
211+
reveal_type(h(1, "a")) # revealed: int | str | None
212+
213+
def generator():
214+
yield 1
215+
yield 2
216+
return None
217+
218+
# TODO: Should be `Generator[Literal[1, 2], Any, None]`
219+
reveal_type(generator()) # revealed: None
220+
```
221+
185222
## Invalid return type
186223

187224
<!-- snapshot-diagnostics -->

crates/ty_python_semantic/src/types.rs

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4179,6 +4179,14 @@ impl<'db> Type<'db> {
41794179
}
41804180
}
41814181

4182+
/// Returns the inferred return type of `self` if it is a function literal.
4183+
fn inferred_return_type(self, db: &'db dyn Db) -> Option<Type<'db>> {
4184+
match self {
4185+
Type::FunctionLiteral(function_type) => Some(function_type.inferred_return_type(db)),
4186+
_ => None,
4187+
}
4188+
}
4189+
41824190
/// Calls `self`. Returns a [`CallError`] if `self` is (always or possibly) not callable, or if
41834191
/// the arguments are not compatible with the formal parameters.
41844192
///
@@ -4191,7 +4199,9 @@ impl<'db> Type<'db> {
41914199
argument_types: &CallArgumentTypes<'_, 'db>,
41924200
) -> Result<Bindings<'db>, CallError<'db>> {
41934201
let signatures = self.signatures(db);
4194-
Bindings::match_parameters(signatures, argument_types).check_types(db, argument_types)
4202+
let inferred_return_ty = || self.inferred_return_type(db).unwrap_or(Type::unknown());
4203+
Bindings::match_parameters(signatures, inferred_return_ty, argument_types)
4204+
.check_types(db, argument_types)
41954205
}
41964206

41974207
/// Look up a dunder method on the meta-type of `self` and call it.
@@ -4229,8 +4239,14 @@ impl<'db> Type<'db> {
42294239
{
42304240
Symbol::Type(dunder_callable, boundness) => {
42314241
let signatures = dunder_callable.signatures(db);
4232-
let bindings = Bindings::match_parameters(signatures, argument_types)
4233-
.check_types(db, argument_types)?;
4242+
let inferred_return_ty = || {
4243+
dunder_callable
4244+
.inferred_return_type(db)
4245+
.unwrap_or(Type::unknown())
4246+
};
4247+
let bindings =
4248+
Bindings::match_parameters(signatures, inferred_return_ty, argument_types)
4249+
.check_types(db, argument_types)?;
42344250
if boundness == Boundness::PossiblyUnbound {
42354251
return Err(CallDunderError::PossiblyUnbound(Box::new(bindings)));
42364252
}
@@ -6689,6 +6705,13 @@ impl<'db> FunctionType<'db> {
66896705
signature
66906706
}
66916707

6708+
/// Infers this function scope's types and returns the inferred return type.
6709+
fn inferred_return_type(self, db: &'db dyn Db) -> Type<'db> {
6710+
let scope = self.body_scope(db);
6711+
let inference = infer_scope_types(db, scope);
6712+
inference.inferred_return_type(db)
6713+
}
6714+
66926715
pub(crate) fn is_known(self, db: &'db dyn Db, known_function: KnownFunction) -> bool {
66936716
self.known(db) == Some(known_function)
66946717
}

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ impl<'db> Bindings<'db> {
5656
/// verify that each argument type is assignable to the corresponding parameter type.
5757
pub(crate) fn match_parameters(
5858
signatures: Signatures<'db>,
59+
inferred_return_ty: impl Fn() -> Type<'db> + Copy,
5960
arguments: &CallArguments<'_>,
6061
) -> Self {
6162
let mut argument_forms = vec![None; arguments.len()];
@@ -68,6 +69,7 @@ impl<'db> Bindings<'db> {
6869
arguments,
6970
&mut argument_forms,
7071
&mut conflicting_forms,
72+
inferred_return_ty,
7173
)
7274
})
7375
.collect();
@@ -892,6 +894,7 @@ impl<'db> CallableBinding<'db> {
892894
arguments: &CallArguments<'_>,
893895
argument_forms: &mut [Option<ParameterForm>],
894896
conflicting_forms: &mut [bool],
897+
inferred_return_ty: impl Fn() -> Type<'db> + Copy,
895898
) -> Self {
896899
// If this callable is a bound method, prepend the self instance onto the arguments list
897900
// before checking.
@@ -911,6 +914,7 @@ impl<'db> CallableBinding<'db> {
911914
arguments.as_ref(),
912915
argument_forms,
913916
conflicting_forms,
917+
inferred_return_ty,
914918
)
915919
})
916920
.collect();
@@ -1079,6 +1083,7 @@ impl<'db> Binding<'db> {
10791083
arguments: &CallArguments<'_>,
10801084
argument_forms: &mut [Option<ParameterForm>],
10811085
conflicting_forms: &mut [bool],
1086+
inferred_return_ty: impl Fn() -> Type<'db>,
10821087
) -> Self {
10831088
let parameters = signature.parameters();
10841089
// The parameter that each argument is matched with.
@@ -1195,7 +1200,7 @@ impl<'db> Binding<'db> {
11951200
}
11961201

11971202
Self {
1198-
return_ty: signature.return_ty.unwrap_or(Type::unknown()),
1203+
return_ty: signature.return_ty.unwrap_or_else(inferred_return_ty),
11991204
specialization: None,
12001205
inherited_specialization: None,
12011206
argument_parameters: argument_parameters.into_boxed_slice(),

crates/ty_python_semantic/src/types/infer.rs

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ use crate::semantic_index::expression::{Expression, ExpressionKind};
5757
use crate::semantic_index::symbol::{
5858
FileScopeId, NodeWithScopeKind, NodeWithScopeRef, ScopeId, ScopeKind,
5959
};
60-
use crate::semantic_index::{semantic_index, EagerBindingsResult, SemanticIndex};
60+
use crate::semantic_index::{semantic_index, use_def_map, EagerBindingsResult, SemanticIndex};
6161
use crate::symbol::{
6262
builtins_module_scope, builtins_symbol, explicit_global_symbol,
6363
module_type_implicit_global_symbol, symbol, symbol_from_bindings, symbol_from_declarations,
@@ -366,6 +366,11 @@ pub(crate) struct TypeInference<'db> {
366366
/// The scope this region is part of.
367367
scope: ScopeId<'db>,
368368

369+
/// The returned types of this region (if this is a function body).
370+
///
371+
/// These are stored in `Vec` to delay the creation of the union type as long as possible.
372+
return_types: Vec<Type<'db>>,
373+
369374
/// The fallback type for missing expressions/bindings/declarations.
370375
///
371376
/// This is used only when constructing a cycle-recovery `TypeInference`.
@@ -381,6 +386,7 @@ impl<'db> TypeInference<'db> {
381386
deferred: FxHashSet::default(),
382387
diagnostics: TypeCheckDiagnostics::default(),
383388
scope,
389+
return_types: vec![],
384390
cycle_fallback_type: None,
385391
}
386392
}
@@ -393,6 +399,7 @@ impl<'db> TypeInference<'db> {
393399
deferred: FxHashSet::default(),
394400
diagnostics: TypeCheckDiagnostics::default(),
395401
scope,
402+
return_types: vec![],
396403
cycle_fallback_type: Some(cycle_fallback_type),
397404
}
398405
}
@@ -440,12 +447,27 @@ impl<'db> TypeInference<'db> {
440447
&self.diagnostics
441448
}
442449

450+
/// Returns the inferred return type of this function body (union of all possible return types),
451+
/// or `None` if the region is not a function body.
452+
pub(crate) fn inferred_return_type(&self, db: &'db dyn Db) -> Type<'db> {
453+
let mut union = UnionBuilder::new(db);
454+
for ty in &self.return_types {
455+
union = union.add(*ty);
456+
}
457+
let use_def = use_def_map(db, self.scope);
458+
if use_def.can_implicit_return(db) {
459+
union = union.add(Type::none(db));
460+
}
461+
union.build()
462+
}
463+
443464
fn shrink_to_fit(&mut self) {
444465
self.expressions.shrink_to_fit();
445466
self.bindings.shrink_to_fit();
446467
self.declarations.shrink_to_fit();
447468
self.diagnostics.shrink_to_fit();
448469
self.deferred.shrink_to_fit();
470+
self.return_types.shrink_to_fit();
449471
}
450472
}
451473

@@ -4674,7 +4696,12 @@ impl<'db> TypeInferenceBuilder<'db> {
46744696
}
46754697

46764698
let signatures = callable_type.signatures(self.db());
4677-
let bindings = Bindings::match_parameters(signatures, &call_arguments);
4699+
let inferred_return_ty = || {
4700+
callable_type
4701+
.inferred_return_type(self.db())
4702+
.unwrap_or(Type::unknown())
4703+
};
4704+
let bindings = Bindings::match_parameters(signatures, inferred_return_ty, &call_arguments);
46784705
let call_argument_types =
46794706
self.infer_argument_types(arguments, call_arguments, &bindings.argument_forms);
46804707

@@ -6689,15 +6716,21 @@ impl<'db> TypeInferenceBuilder<'db> {
66896716
value_ty,
66906717
generic_context.signature(self.db()),
66916718
));
6692-
let bindings = match Bindings::match_parameters(signatures, &call_argument_types)
6693-
.check_types(self.db(), &call_argument_types)
6694-
{
6695-
Ok(bindings) => bindings,
6696-
Err(CallError(_, bindings)) => {
6697-
bindings.report_diagnostics(&self.context, subscript.into());
6698-
return Type::unknown();
6699-
}
6719+
let inferred_return_ty = || {
6720+
value_ty
6721+
.inferred_return_type(self.db())
6722+
.unwrap_or(Type::unknown())
67006723
};
6724+
let bindings =
6725+
match Bindings::match_parameters(signatures, inferred_return_ty, &call_argument_types)
6726+
.check_types(self.db(), &call_argument_types)
6727+
{
6728+
Ok(bindings) => bindings,
6729+
Err(CallError(_, bindings)) => {
6730+
bindings.report_diagnostics(&self.context, subscript.into());
6731+
return Type::unknown();
6732+
}
6733+
};
67016734
let callable = bindings
67026735
.into_iter()
67036736
.next()
@@ -7105,6 +7138,11 @@ impl<'db> TypeInferenceBuilder<'db> {
71057138
self.infer_region();
71067139
self.types.diagnostics = self.context.finish();
71077140
self.types.shrink_to_fit();
7141+
self.types.return_types = self
7142+
.return_types_and_ranges
7143+
.into_iter()
7144+
.map(|ty_range| ty_range.ty)
7145+
.collect();
71087146
self.types
71097147
}
71107148
}

0 commit comments

Comments
 (0)