Skip to content

Commit 7693df9

Browse files
committed
[ty] Filter overloads based on Any / Unknown
1 parent 7342d3c commit 7693df9

File tree

2 files changed

+137
-11
lines changed

2 files changed

+137
-11
lines changed

crates/ty_python_semantic/src/types.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5766,9 +5766,9 @@ impl<'db> KnownInstanceType<'db> {
57665766

57675767
#[derive(Copy, Clone, Debug, Eq, Hash, PartialEq)]
57685768
pub enum DynamicType {
5769-
// An explicitly annotated `typing.Any`
5769+
/// An explicitly annotated `typing.Any`
57705770
Any,
5771-
// An unannotated value, or a dynamic type resulting from an error
5771+
/// An unannotated value, or a dynamic type resulting from an error
57725772
Unknown,
57735773
/// Temporary type for symbols that can't be inferred yet because of missing implementations.
57745774
///

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 135 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,7 +1025,7 @@ impl<'db> From<Binding<'db>> for Bindings<'db> {
10251025
signature_type,
10261026
dunder_call_is_possibly_unbound: false,
10271027
bound_type: None,
1028-
return_type: None,
1028+
overload_call_return_type: None,
10291029
overloads: smallvec![from],
10301030
};
10311031
Bindings {
@@ -1070,7 +1070,7 @@ pub(crate) struct CallableBinding<'db> {
10701070
/// performed, and one of the expansion evaluated successfully for all of the argument lists.
10711071
/// This type is then the union of all the return types of the matched overloads for the
10721072
/// expanded argument lists.
1073-
return_type: Option<Type<'db>>,
1073+
overload_call_return_type: Option<OverloadCallReturnType<'db>>,
10741074

10751075
/// The bindings of each overload of this callable. Will be empty if the type is not callable.
10761076
///
@@ -1093,7 +1093,7 @@ impl<'db> CallableBinding<'db> {
10931093
signature_type,
10941094
dunder_call_is_possibly_unbound: false,
10951095
bound_type: None,
1096-
return_type: None,
1096+
overload_call_return_type: None,
10971097
overloads,
10981098
}
10991099
}
@@ -1104,7 +1104,7 @@ impl<'db> CallableBinding<'db> {
11041104
signature_type,
11051105
dunder_call_is_possibly_unbound: false,
11061106
bound_type: None,
1107-
return_type: None,
1107+
overload_call_return_type: None,
11081108
overloads: smallvec![],
11091109
}
11101110
}
@@ -1192,9 +1192,18 @@ impl<'db> CallableBinding<'db> {
11921192
// If only one overload evaluates without error, it is the winning match.
11931193
return;
11941194
}
1195-
MatchingOverloadIndex::Multiple(_) => {
1195+
MatchingOverloadIndex::Multiple(indexes) => {
11961196
// If two or more candidate overloads remain, proceed to step 4.
1197-
// TODO: Step 4 and Step 5 goes here...
1197+
tracing::info!(
1198+
"Multiple overloads match: {:?}, filtering based on Any",
1199+
indexes
1200+
);
1201+
1202+
// TODO: Step 4
1203+
1204+
// Step 5
1205+
self.filter_overloads_using_any_or_unknown(db, argument_types.types(), &indexes);
1206+
11981207
// We're returning here because this shouldn't lead to argument type expansion.
11991208
return;
12001209
}
@@ -1273,7 +1282,10 @@ impl<'db> CallableBinding<'db> {
12731282
// If the number of return types is equal to the number of expanded argument lists,
12741283
// they all evaluated successfully. So, we need to combine their return types by
12751284
// union to determine the final return type.
1276-
self.return_type = Some(UnionType::from_elements(db, return_types));
1285+
self.overload_call_return_type =
1286+
Some(OverloadCallReturnType::ArgumentTypeExpansion(
1287+
UnionType::from_elements(db, return_types),
1288+
));
12771289

12781290
// Restore the bindings state to the one that merges the bindings state evaluating
12791291
// each of the expanded argument list.
@@ -1292,6 +1304,99 @@ impl<'db> CallableBinding<'db> {
12921304
snapshotter.restore(self, post_evaluation_snapshot);
12931305
}
12941306

1307+
/// Filter overloads based on [`Any`] or [`Unknown`] argument types.
1308+
///
1309+
/// This is the step 5 of the [overload call evaluation algorithm][1].
1310+
///
1311+
/// The filtering works on the remaining overloads that are present at the
1312+
/// `matching_overload_indexes` and are filtered out by marking them as unmatched overloads
1313+
/// using the [`mark_as_unmatched_overload`] method.
1314+
///
1315+
/// [`Any`]: crate::types::DynamicType::Any
1316+
/// [`Unknown`]: crate::types::DynamicType::Unknown
1317+
/// [`mark_as_unmatched_overload`]: Binding::mark_as_unmatched_overload
1318+
/// [1]: https://typing.python.org/en/latest/spec/overload.html#overload-call-evaluation
1319+
fn filter_overloads_using_any_or_unknown(
1320+
&mut self,
1321+
db: &'db dyn Db,
1322+
argument_types: &[Type<'db>],
1323+
matching_overload_indexes: &[usize],
1324+
) {
1325+
let top_materialized_argument_type = TupleType::from_elements(
1326+
db,
1327+
argument_types.iter().map(|argument_type| {
1328+
argument_type.top_materialization(db, TypeVarVariance::Covariant)
1329+
}),
1330+
);
1331+
1332+
// A flag to indicate whether we've found the overload that makes the remaining overloads
1333+
// unmatched for the given argument types.
1334+
let mut filter_remaining_overloads = false;
1335+
1336+
for (upto, current_index) in matching_overload_indexes.iter().enumerate() {
1337+
if filter_remaining_overloads {
1338+
self.overloads[*current_index].mark_as_unmatched_overload();
1339+
continue;
1340+
}
1341+
let mut unions = Vec::with_capacity(argument_types.len());
1342+
for argument_index in 0..argument_types.len() {
1343+
let mut union = vec![];
1344+
for overload_index in &matching_overload_indexes[..=upto] {
1345+
let overload = &self.overloads[*overload_index];
1346+
let Some(parameter_index) = overload.argument_parameters[argument_index] else {
1347+
// There is no parameter for this argument in this overload.
1348+
continue;
1349+
};
1350+
union.push(
1351+
overload.signature.parameters()[parameter_index]
1352+
.annotated_type()
1353+
.unwrap_or(Type::unknown()),
1354+
);
1355+
}
1356+
if union.is_empty() {
1357+
continue;
1358+
}
1359+
unions.push(UnionType::from_elements(db, union));
1360+
}
1361+
if unions.len() != argument_types.len() {
1362+
continue;
1363+
}
1364+
if top_materialized_argument_type
1365+
.is_assignable_to(db, TupleType::from_elements(db, unions))
1366+
{
1367+
filter_remaining_overloads = true;
1368+
}
1369+
}
1370+
1371+
// Once this filtering process is applied for all arguments, examine the return types of
1372+
// the remaining overloads. If the resulting return types for all remaining overloads are
1373+
// equivalent, proceed to step 6.
1374+
let are_return_types_equivalent_for_all_matching_overloads = {
1375+
let mut matching_overloads = self.matching_overloads();
1376+
if let Some(first_overload_return_type) = matching_overloads
1377+
.next()
1378+
.map(|(_, overload)| overload.return_type())
1379+
{
1380+
matching_overloads.all(|(_, overload)| {
1381+
overload
1382+
.return_type()
1383+
.is_equivalent_to(db, first_overload_return_type)
1384+
})
1385+
} else {
1386+
// No matching overload
1387+
true
1388+
}
1389+
};
1390+
1391+
if !are_return_types_equivalent_for_all_matching_overloads {
1392+
// Overload matching is ambiguous.
1393+
for (_, overload) in self.matching_overloads_mut() {
1394+
overload.mark_as_unmatched_overload();
1395+
}
1396+
self.overload_call_return_type = Some(OverloadCallReturnType::Ambiguous);
1397+
}
1398+
}
1399+
12951400
fn as_result(&self) -> Result<(), CallErrorKind> {
12961401
if !self.is_callable() {
12971402
return Err(CallErrorKind::NotCallable);
@@ -1366,8 +1471,11 @@ impl<'db> CallableBinding<'db> {
13661471
/// For an invalid call to an overloaded function, we return `Type::unknown`, since we cannot
13671472
/// make any useful conclusions about which overload was intended to be called.
13681473
pub(crate) fn return_type(&self) -> Type<'db> {
1369-
if let Some(return_type) = self.return_type {
1370-
return return_type;
1474+
if let Some(overload_call_return_type) = self.overload_call_return_type {
1475+
return match overload_call_return_type {
1476+
OverloadCallReturnType::ArgumentTypeExpansion(return_type) => return_type,
1477+
OverloadCallReturnType::Ambiguous => Type::any(),
1478+
};
13711479
}
13721480
if let Some((_, first_overload)) = self.matching_overloads().next() {
13731481
return first_overload.return_type();
@@ -1410,6 +1518,10 @@ impl<'db> CallableBinding<'db> {
14101518
return;
14111519
}
14121520

1521+
if self.overload_call_return_type.is_some() {
1522+
return;
1523+
}
1524+
14131525
match self.overloads.as_slice() {
14141526
[] => {}
14151527
[overload] => {
@@ -1517,6 +1629,12 @@ impl<'a, 'db> IntoIterator for &'a CallableBinding<'db> {
15171629
}
15181630
}
15191631

1632+
#[derive(Debug, Copy, Clone)]
1633+
enum OverloadCallReturnType<'db> {
1634+
ArgumentTypeExpansion(Type<'db>),
1635+
Ambiguous,
1636+
}
1637+
15201638
#[derive(Debug)]
15211639
enum MatchingOverloadIndex {
15221640
/// No matching overloads found.
@@ -1851,6 +1969,10 @@ impl<'db> Binding<'db> {
18511969
.map(|(arg_and_type, _)| arg_and_type)
18521970
}
18531971

1972+
fn mark_as_unmatched_overload(&mut self) {
1973+
self.errors.push(BindingError::UnmatchedOverload);
1974+
}
1975+
18541976
fn report_diagnostics(
18551977
&self,
18561978
context: &InferContext<'db, '_>,
@@ -2136,6 +2258,8 @@ pub(crate) enum BindingError<'db> {
21362258
/// We use this variant to report errors in `property.__get__` and `property.__set__`, which
21372259
/// can occur when the call to the underlying getter/setter fails.
21382260
InternalCallError(&'static str),
2261+
/// This overload of the callable does not match the arguments.
2262+
UnmatchedOverload,
21392263
}
21402264

21412265
impl<'db> BindingError<'db> {
@@ -2328,6 +2452,8 @@ impl<'db> BindingError<'db> {
23282452
}
23292453
}
23302454
}
2455+
2456+
Self::UnmatchedOverload => {}
23312457
}
23322458
}
23332459

0 commit comments

Comments
 (0)