@@ -1025,7 +1025,7 @@ impl<'db> From<Binding<'db>> for Bindings<'db> {
10251025 signature_type,
10261026 dunder_call_is_possibly_unbound : false ,
10271027 bound_type : None ,
1028- return_type : None ,
1028+ overload_call_return_type : None ,
10291029 overloads : smallvec ! [ from] ,
10301030 } ;
10311031 Bindings {
@@ -1070,7 +1070,7 @@ pub(crate) struct CallableBinding<'db> {
10701070 /// performed, and one of the expansion evaluated successfully for all of the argument lists.
10711071 /// This type is then the union of all the return types of the matched overloads for the
10721072 /// expanded argument lists.
1073- return_type : Option < Type < ' db > > ,
1073+ overload_call_return_type : Option < OverloadCallReturnType < ' db > > ,
10741074
10751075 /// The bindings of each overload of this callable. Will be empty if the type is not callable.
10761076 ///
@@ -1093,7 +1093,7 @@ impl<'db> CallableBinding<'db> {
10931093 signature_type,
10941094 dunder_call_is_possibly_unbound : false ,
10951095 bound_type : None ,
1096- return_type : None ,
1096+ overload_call_return_type : None ,
10971097 overloads,
10981098 }
10991099 }
@@ -1104,7 +1104,7 @@ impl<'db> CallableBinding<'db> {
11041104 signature_type,
11051105 dunder_call_is_possibly_unbound : false ,
11061106 bound_type : None ,
1107- return_type : None ,
1107+ overload_call_return_type : None ,
11081108 overloads : smallvec ! [ ] ,
11091109 }
11101110 }
@@ -1192,9 +1192,18 @@ impl<'db> CallableBinding<'db> {
11921192 // If only one overload evaluates without error, it is the winning match.
11931193 return ;
11941194 }
1195- MatchingOverloadIndex :: Multiple ( _ ) => {
1195+ MatchingOverloadIndex :: Multiple ( indexes ) => {
11961196 // If two or more candidate overloads remain, proceed to step 4.
1197- // TODO: Step 4 and Step 5 goes here...
1197+ tracing:: info!(
1198+ "Multiple overloads match: {:?}, filtering based on Any" ,
1199+ indexes
1200+ ) ;
1201+
1202+ // TODO: Step 4
1203+
1204+ // Step 5
1205+ self . filter_overloads_using_any_or_unknown ( db, argument_types. types ( ) , & indexes) ;
1206+
11981207 // We're returning here because this shouldn't lead to argument type expansion.
11991208 return ;
12001209 }
@@ -1273,7 +1282,10 @@ impl<'db> CallableBinding<'db> {
12731282 // If the number of return types is equal to the number of expanded argument lists,
12741283 // they all evaluated successfully. So, we need to combine their return types by
12751284 // union to determine the final return type.
1276- self . return_type = Some ( UnionType :: from_elements ( db, return_types) ) ;
1285+ self . overload_call_return_type =
1286+ Some ( OverloadCallReturnType :: ArgumentTypeExpansion (
1287+ UnionType :: from_elements ( db, return_types) ,
1288+ ) ) ;
12771289
12781290 // Restore the bindings state to the one that merges the bindings state evaluating
12791291 // each of the expanded argument list.
@@ -1292,6 +1304,99 @@ impl<'db> CallableBinding<'db> {
12921304 snapshotter. restore ( self , post_evaluation_snapshot) ;
12931305 }
12941306
1307+ /// Filter overloads based on [`Any`] or [`Unknown`] argument types.
1308+ ///
1309+ /// This is the step 5 of the [overload call evaluation algorithm][1].
1310+ ///
1311+ /// The filtering works on the remaining overloads that are present at the
1312+ /// `matching_overload_indexes` and are filtered out by marking them as unmatched overloads
1313+ /// using the [`mark_as_unmatched_overload`] method.
1314+ ///
1315+ /// [`Any`]: crate::types::DynamicType::Any
1316+ /// [`Unknown`]: crate::types::DynamicType::Unknown
1317+ /// [`mark_as_unmatched_overload`]: Binding::mark_as_unmatched_overload
1318+ /// [1]: https://typing.python.org/en/latest/spec/overload.html#overload-call-evaluation
1319+ fn filter_overloads_using_any_or_unknown (
1320+ & mut self ,
1321+ db : & ' db dyn Db ,
1322+ argument_types : & [ Type < ' db > ] ,
1323+ matching_overload_indexes : & [ usize ] ,
1324+ ) {
1325+ let top_materialized_argument_type = TupleType :: from_elements (
1326+ db,
1327+ argument_types. iter ( ) . map ( |argument_type| {
1328+ argument_type. top_materialization ( db, TypeVarVariance :: Covariant )
1329+ } ) ,
1330+ ) ;
1331+
1332+ // A flag to indicate whether we've found the overload that makes the remaining overloads
1333+ // unmatched for the given argument types.
1334+ let mut filter_remaining_overloads = false ;
1335+
1336+ for ( upto, current_index) in matching_overload_indexes. iter ( ) . enumerate ( ) {
1337+ if filter_remaining_overloads {
1338+ self . overloads [ * current_index] . mark_as_unmatched_overload ( ) ;
1339+ continue ;
1340+ }
1341+ let mut unions = Vec :: with_capacity ( argument_types. len ( ) ) ;
1342+ for argument_index in 0 ..argument_types. len ( ) {
1343+ let mut union = vec ! [ ] ;
1344+ for overload_index in & matching_overload_indexes[ ..=upto] {
1345+ let overload = & self . overloads [ * overload_index] ;
1346+ let Some ( parameter_index) = overload. argument_parameters [ argument_index] else {
1347+ // There is no parameter for this argument in this overload.
1348+ continue ;
1349+ } ;
1350+ union. push (
1351+ overload. signature . parameters ( ) [ parameter_index]
1352+ . annotated_type ( )
1353+ . unwrap_or ( Type :: unknown ( ) ) ,
1354+ ) ;
1355+ }
1356+ if union. is_empty ( ) {
1357+ continue ;
1358+ }
1359+ unions. push ( UnionType :: from_elements ( db, union) ) ;
1360+ }
1361+ if unions. len ( ) != argument_types. len ( ) {
1362+ continue ;
1363+ }
1364+ if top_materialized_argument_type
1365+ . is_assignable_to ( db, TupleType :: from_elements ( db, unions) )
1366+ {
1367+ filter_remaining_overloads = true ;
1368+ }
1369+ }
1370+
1371+ // Once this filtering process is applied for all arguments, examine the return types of
1372+ // the remaining overloads. If the resulting return types for all remaining overloads are
1373+ // equivalent, proceed to step 6.
1374+ let are_return_types_equivalent_for_all_matching_overloads = {
1375+ let mut matching_overloads = self . matching_overloads ( ) ;
1376+ if let Some ( first_overload_return_type) = matching_overloads
1377+ . next ( )
1378+ . map ( |( _, overload) | overload. return_type ( ) )
1379+ {
1380+ matching_overloads. all ( |( _, overload) | {
1381+ overload
1382+ . return_type ( )
1383+ . is_equivalent_to ( db, first_overload_return_type)
1384+ } )
1385+ } else {
1386+ // No matching overload
1387+ true
1388+ }
1389+ } ;
1390+
1391+ if !are_return_types_equivalent_for_all_matching_overloads {
1392+ // Overload matching is ambiguous.
1393+ for ( _, overload) in self . matching_overloads_mut ( ) {
1394+ overload. mark_as_unmatched_overload ( ) ;
1395+ }
1396+ self . overload_call_return_type = Some ( OverloadCallReturnType :: Ambiguous ) ;
1397+ }
1398+ }
1399+
12951400 fn as_result ( & self ) -> Result < ( ) , CallErrorKind > {
12961401 if !self . is_callable ( ) {
12971402 return Err ( CallErrorKind :: NotCallable ) ;
@@ -1366,8 +1471,11 @@ impl<'db> CallableBinding<'db> {
13661471 /// For an invalid call to an overloaded function, we return `Type::unknown`, since we cannot
13671472 /// make any useful conclusions about which overload was intended to be called.
13681473 pub ( crate ) fn return_type ( & self ) -> Type < ' db > {
1369- if let Some ( return_type) = self . return_type {
1370- return return_type;
1474+ if let Some ( overload_call_return_type) = self . overload_call_return_type {
1475+ return match overload_call_return_type {
1476+ OverloadCallReturnType :: ArgumentTypeExpansion ( return_type) => return_type,
1477+ OverloadCallReturnType :: Ambiguous => Type :: any ( ) ,
1478+ } ;
13711479 }
13721480 if let Some ( ( _, first_overload) ) = self . matching_overloads ( ) . next ( ) {
13731481 return first_overload. return_type ( ) ;
@@ -1410,6 +1518,10 @@ impl<'db> CallableBinding<'db> {
14101518 return ;
14111519 }
14121520
1521+ if self . overload_call_return_type . is_some ( ) {
1522+ return ;
1523+ }
1524+
14131525 match self . overloads . as_slice ( ) {
14141526 [ ] => { }
14151527 [ overload] => {
@@ -1517,6 +1629,12 @@ impl<'a, 'db> IntoIterator for &'a CallableBinding<'db> {
15171629 }
15181630}
15191631
1632+ #[ derive( Debug , Copy , Clone ) ]
1633+ enum OverloadCallReturnType < ' db > {
1634+ ArgumentTypeExpansion ( Type < ' db > ) ,
1635+ Ambiguous ,
1636+ }
1637+
15201638#[ derive( Debug ) ]
15211639enum MatchingOverloadIndex {
15221640 /// No matching overloads found.
@@ -1851,6 +1969,10 @@ impl<'db> Binding<'db> {
18511969 . map ( |( arg_and_type, _) | arg_and_type)
18521970 }
18531971
1972+ fn mark_as_unmatched_overload ( & mut self ) {
1973+ self . errors . push ( BindingError :: UnmatchedOverload ) ;
1974+ }
1975+
18541976 fn report_diagnostics (
18551977 & self ,
18561978 context : & InferContext < ' db , ' _ > ,
@@ -2136,6 +2258,8 @@ pub(crate) enum BindingError<'db> {
21362258 /// We use this variant to report errors in `property.__get__` and `property.__set__`, which
21372259 /// can occur when the call to the underlying getter/setter fails.
21382260 InternalCallError ( & ' static str ) ,
2261+ /// This overload of the callable does not match the arguments.
2262+ UnmatchedOverload ,
21392263}
21402264
21412265impl < ' db > BindingError < ' db > {
@@ -2328,6 +2452,8 @@ impl<'db> BindingError<'db> {
23282452 }
23292453 }
23302454 }
2455+
2456+ Self :: UnmatchedOverload => { }
23312457 }
23322458 }
23332459
0 commit comments