@@ -131,7 +131,14 @@ impl<'db> Bindings<'db> {
131131 argument_types : & CallArguments < ' _ , ' db > ,
132132 ) -> Result < Self , CallError < ' db > > {
133133 for element in & mut self . elements {
134- element. check_types ( db, argument_types) ;
134+ if let Some ( ( global_argument_forms, global_conflicting_forms) ) =
135+ element. check_types ( db, argument_types)
136+ {
137+ // If this element returned global forms (indicating successful argument type expansion),
138+ // update the Bindings with these forms
139+ self . argument_forms = global_argument_forms. into ( ) ;
140+ self . conflicting_forms = global_conflicting_forms. into ( ) ;
141+ }
135142 }
136143
137144 self . evaluate_known_cases ( db) ;
@@ -1261,7 +1268,11 @@ impl<'db> CallableBinding<'db> {
12611268 }
12621269 }
12631270
1264- fn check_types ( & mut self , db : & ' db dyn Db , argument_types : & CallArguments < ' _ , ' db > ) {
1271+ fn check_types (
1272+ & mut self ,
1273+ db : & ' db dyn Db ,
1274+ argument_types : & CallArguments < ' _ , ' db > ,
1275+ ) -> Option < ( Vec < Option < ParameterForm > > , Vec < bool > ) > {
12651276 // If this callable is a bound method, prepend the self instance onto the arguments list
12661277 // before checking.
12671278 let argument_types = argument_types. with_self ( self . bound_type ) ;
@@ -1275,14 +1286,14 @@ impl<'db> CallableBinding<'db> {
12751286 if let [ overload] = self . overloads . as_mut_slice ( ) {
12761287 overload. check_types ( db, argument_types. as_ref ( ) ) ;
12771288 }
1278- return ;
1289+ return None ;
12791290 }
12801291 MatchingOverloadIndex :: Single ( index) => {
12811292 // If only one candidate overload remains, it is the winning match. Evaluate it as
12821293 // a regular (non-overloaded) call.
12831294 self . matching_overload_index = Some ( index) ;
12841295 self . overloads [ index] . check_types ( db, argument_types. as_ref ( ) ) ;
1285- return ;
1296+ return None ;
12861297 }
12871298 MatchingOverloadIndex :: Multiple ( indexes) => {
12881299 // If two or more candidate overloads remain, proceed to step 2.
@@ -1308,7 +1319,7 @@ impl<'db> CallableBinding<'db> {
13081319 }
13091320 MatchingOverloadIndex :: Single ( _) => {
13101321 // If only one overload evaluates without error, it is the winning match.
1311- return ;
1322+ return None ;
13121323 }
13131324 MatchingOverloadIndex :: Multiple ( indexes) => {
13141325 // If two or more candidate overloads remain, proceed to step 4.
@@ -1318,7 +1329,7 @@ impl<'db> CallableBinding<'db> {
13181329 self . filter_overloads_using_any_or_unknown ( db, argument_types. as_ref ( ) , & indexes) ;
13191330
13201331 // We're returning here because this shouldn't lead to argument type expansion.
1321- return ;
1332+ return None ;
13221333 }
13231334 }
13241335
@@ -1328,7 +1339,7 @@ impl<'db> CallableBinding<'db> {
13281339
13291340 if expansions. peek ( ) . is_none ( ) {
13301341 // Return early if there are no argument types to expand.
1331- return ;
1342+ return None ;
13321343 }
13331344
13341345 // State of the bindings _after_ evaluating (type checking) the matching overloads using
@@ -1377,18 +1388,24 @@ impl<'db> CallableBinding<'db> {
13771388 argument_type. display( db)
13781389 ) ;
13791390 snapshotter. restore ( self , post_evaluation_snapshot) ;
1380- return ;
1391+ return None ;
13811392 }
13821393 }
13831394
1395+ tracing:: debug!( "Performing argument type expansion" ) ;
1396+
1397+ // Global argument_forms and conflicting_forms that will be merged across all expansions
1398+ let mut global_argument_forms: Vec < Option < ParameterForm > > = Vec :: new ( ) ;
1399+ let mut global_conflicting_forms: Vec < bool > = Vec :: new ( ) ;
1400+
13841401 for expansion in expansions {
13851402 let expanded_argument_lists = match expansion {
13861403 Expansion :: LimitReached ( index) => {
13871404 snapshotter. restore ( self , post_evaluation_snapshot) ;
13881405 self . overload_call_return_type = Some (
13891406 OverloadCallReturnType :: ArgumentTypeExpansionLimitReached ( index) ,
13901407 ) ;
1391- return ;
1408+ return Some ( ( global_argument_forms , global_conflicting_forms ) ) ;
13921409 }
13931410 Expansion :: Expanded ( argument_lists) => argument_lists,
13941411 } ;
@@ -1400,11 +1417,41 @@ impl<'db> CallableBinding<'db> {
14001417
14011418 let mut return_types = Vec :: new ( ) ;
14021419
1403- for expanded_argument_types in & expanded_argument_lists {
1420+ for expanded_arguments in & expanded_argument_lists {
14041421 let pre_evaluation_snapshot = snapshotter. take ( self ) ;
14051422
1423+ // Clear the state of all overloads before re-evaluating from step 1
1424+ for overload in & mut self . overloads {
1425+ overload. reset ( ) ;
1426+ }
1427+
1428+ let mut argument_forms = vec ! [ None ; expanded_arguments. len( ) ] ;
1429+ let mut conflicting_forms = vec ! [ false ; expanded_arguments. len( ) ] ;
1430+
1431+ // The spec mentions that each expanded argument list should re-evaluate from step
1432+ // 2 which is the type checking step but we're re-evaluating from step 1. The tldr
1433+ // is that it allows ty to match the correct overload in case a variadic argument
1434+ // would expand into different number of arguments with each expansion. Refer to
1435+ // https://github.com/astral-sh/ty/issues/735 for more details.
1436+ for overload in & mut self . overloads {
1437+ overload. match_parameters (
1438+ db,
1439+ expanded_arguments,
1440+ & mut argument_forms,
1441+ & mut conflicting_forms,
1442+ ) ;
1443+ }
1444+
1445+ // Merge argument_forms and conflicting_forms into global ones
1446+ Self :: merge_argument_forms (
1447+ & mut global_argument_forms,
1448+ & mut global_conflicting_forms,
1449+ & argument_forms,
1450+ & conflicting_forms,
1451+ ) ;
1452+
14061453 for ( _, overload) in self . matching_overloads_mut ( ) {
1407- overload. check_types ( db, expanded_argument_types ) ;
1454+ overload. check_types ( db, expanded_arguments ) ;
14081455 }
14091456
14101457 let return_type = match self . matching_overload_index ( ) {
@@ -1417,14 +1464,26 @@ impl<'db> CallableBinding<'db> {
14171464
14181465 self . filter_overloads_using_any_or_unknown (
14191466 db,
1420- expanded_argument_types ,
1467+ expanded_arguments ,
14211468 & matching_overload_indexes,
14221469 ) ;
14231470
14241471 Some ( self . return_type ( ) )
14251472 }
14261473 } ;
14271474
1475+ tracing:: debug!(
1476+ "Return type after evaluating expanded arguments (`{}`): {}" ,
1477+ expanded_arguments
1478+ . iter_types( )
1479+ . map( |arg| arg. display( db) . to_string( ) )
1480+ . collect:: <Vec <_>>( )
1481+ . join( ", " ) ,
1482+ return_type
1483+ . map( |ty| ty. display( db) . to_string( ) )
1484+ . unwrap_or_else( || "None" . to_string( ) )
1485+ ) ;
1486+
14281487 // This split between initializing and updating the merged evaluation state is
14291488 // required because otherwise it's difficult to differentiate between the
14301489 // following:
@@ -1468,7 +1527,7 @@ impl<'db> CallableBinding<'db> {
14681527 UnionType :: from_elements ( db, return_types) ,
14691528 ) ) ;
14701529
1471- return ;
1530+ return Some ( ( global_argument_forms , global_conflicting_forms ) ) ;
14721531 }
14731532 }
14741533
@@ -1477,6 +1536,44 @@ impl<'db> CallableBinding<'db> {
14771536 // argument types. This is necessary because we restore the state to the pre-evaluation
14781537 // snapshot when processing the expanded argument lists.
14791538 snapshotter. restore ( self , post_evaluation_snapshot) ;
1539+ None
1540+ }
1541+
1542+ /// Merge argument forms and conflicting forms into global ones.
1543+ fn merge_argument_forms (
1544+ global_argument_forms : & mut Vec < Option < ParameterForm > > ,
1545+ global_conflicting_forms : & mut Vec < bool > ,
1546+ local_argument_forms : & [ Option < ParameterForm > ] ,
1547+ local_conflicting_forms : & [ bool ] ,
1548+ ) {
1549+ // Resize global lists to match local length if needed
1550+ if global_argument_forms. len ( ) < local_argument_forms. len ( ) {
1551+ global_argument_forms. resize ( local_argument_forms. len ( ) , None ) ;
1552+ global_conflicting_forms. resize ( local_conflicting_forms. len ( ) , false ) ;
1553+ }
1554+
1555+ // Merge argument forms and conflicting forms
1556+ for ( i, ( local_form, local_conflict) ) in local_argument_forms
1557+ . iter ( )
1558+ . zip ( local_conflicting_forms. iter ( ) )
1559+ . enumerate ( )
1560+ {
1561+ // Update global argument form
1562+ if let Some ( global_form) = & mut global_argument_forms[ i] {
1563+ if let Some ( local_form) = local_form {
1564+ if * global_form != * local_form {
1565+ // Different parameter forms, mark as conflicting
1566+ global_conflicting_forms[ i] = true ;
1567+ * global_form = * local_form; // Use the new form
1568+ }
1569+ }
1570+ } else {
1571+ global_argument_forms[ i] = * local_form;
1572+ }
1573+
1574+ // Update global conflicting form (true takes precedence)
1575+ global_conflicting_forms[ i] |= local_conflict;
1576+ }
14801577 }
14811578
14821579 /// Filter overloads based on [`Any`] or [`Unknown`] argument types.
@@ -2048,6 +2145,14 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
20482145 argument_type : Option < Type < ' db > > ,
20492146 length : TupleLength ,
20502147 ) -> Result < ( ) , ( ) > {
2148+ tracing:: debug!( "tuple length: {length:?}" ) ;
2149+ tracing:: debug!(
2150+ "argument type: {}" ,
2151+ argument_type
2152+ . map( |ty| ty. display( db) . to_string( ) )
2153+ . unwrap_or_else( || "<None>" . to_string( ) )
2154+ ) ;
2155+
20512156 let tuple = argument_type. map ( |ty| ty. iterate ( db) ) ;
20522157 let mut argument_types = match tuple. as_ref ( ) {
20532158 Some ( tuple) => Either :: Left ( tuple. all_elements ( ) . copied ( ) ) ,
@@ -2597,6 +2702,16 @@ impl<'db> Binding<'db> {
25972702 pub ( crate ) fn errors ( & self ) -> & [ BindingError < ' db > ] {
25982703 & self . errors
25992704 }
2705+
2706+ /// Resets the state of this binding to its initial state.
2707+ fn reset ( & mut self ) {
2708+ self . return_ty = Type :: unknown ( ) ;
2709+ self . specialization = None ;
2710+ self . inherited_specialization = None ;
2711+ self . argument_matches = Box :: from ( [ ] ) ;
2712+ self . parameter_tys = Box :: from ( [ ] ) ;
2713+ self . errors . clear ( ) ;
2714+ }
26002715}
26012716
26022717#[ derive( Clone , Debug ) ]
0 commit comments