Skip to content

Commit 29832ec

Browse files
committed
[ty] Re-try argument matching for argument type expansion
1 parent 5518c84 commit 29832ec

File tree

4 files changed

+147
-18
lines changed

4 files changed

+147
-18
lines changed

crates/ty_python_semantic/resources/mdtest/call/function.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -587,7 +587,7 @@ def _(t: tuple[int, str, int]) -> None:
587587
f(*t)
588588

589589
def _(t: tuple[int, str] | tuple[int, str, int]) -> None:
590-
# TODO: error: [invalid-argument-type]
590+
# error: [no-matching-overload]
591591
f(*t)
592592
```
593593

crates/ty_python_semantic/resources/mdtest/call/overloads.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -889,6 +889,8 @@ def _(a: int | None):
889889
)
890890
```
891891

892+
### Retry from parameter matching
893+
892894
## Filtering based on `Any` / `Unknown`
893895

894896
This is the step 5 of the overload call evaluation algorithm which specifies that:

crates/ty_python_semantic/src/types/call/arguments.rs

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -202,10 +202,22 @@ impl<'a, 'db> CallArguments<'a, 'db> {
202202
for subtype in &expanded_types {
203203
let mut new_expanded_types = pre_expanded_types.to_vec();
204204
new_expanded_types[index] = Some(*subtype);
205-
expanded_arguments.push(CallArguments::new(
206-
self.arguments.clone(),
207-
new_expanded_types,
208-
));
205+
206+
// Update the arguments list to handle variadic argument expansion
207+
let mut new_arguments = self.arguments.clone();
208+
if let Argument::Variadic(_) = self.arguments[index] {
209+
// If the expanded type is a tuple, update the TupleLength
210+
if let Some(expanded_type) = new_expanded_types[index] {
211+
let length = expanded_type
212+
.try_iterate(db)
213+
.map(|tuple| tuple.len())
214+
.unwrap_or(TupleLength::unknown());
215+
new_arguments[index] = Argument::Variadic(length);
216+
}
217+
}
218+
219+
expanded_arguments
220+
.push(CallArguments::new(new_arguments, new_expanded_types));
209221
}
210222
}
211223

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 128 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,14 @@ impl<'db> Bindings<'db> {
131131
argument_types: &CallArguments<'_, 'db>,
132132
) -> Result<Self, CallError<'db>> {
133133
for element in &mut self.elements {
134-
element.check_types(db, argument_types);
134+
if let Some((global_argument_forms, global_conflicting_forms)) =
135+
element.check_types(db, argument_types)
136+
{
137+
// If this element returned global forms (indicating successful argument type expansion),
138+
// update the Bindings with these forms
139+
self.argument_forms = global_argument_forms.into();
140+
self.conflicting_forms = global_conflicting_forms.into();
141+
}
135142
}
136143

137144
self.evaluate_known_cases(db);
@@ -1261,7 +1268,11 @@ impl<'db> CallableBinding<'db> {
12611268
}
12621269
}
12631270

1264-
fn check_types(&mut self, db: &'db dyn Db, argument_types: &CallArguments<'_, 'db>) {
1271+
fn check_types(
1272+
&mut self,
1273+
db: &'db dyn Db,
1274+
argument_types: &CallArguments<'_, 'db>,
1275+
) -> Option<(Vec<Option<ParameterForm>>, Vec<bool>)> {
12651276
// If this callable is a bound method, prepend the self instance onto the arguments list
12661277
// before checking.
12671278
let argument_types = argument_types.with_self(self.bound_type);
@@ -1275,14 +1286,14 @@ impl<'db> CallableBinding<'db> {
12751286
if let [overload] = self.overloads.as_mut_slice() {
12761287
overload.check_types(db, argument_types.as_ref());
12771288
}
1278-
return;
1289+
return None;
12791290
}
12801291
MatchingOverloadIndex::Single(index) => {
12811292
// If only one candidate overload remains, it is the winning match. Evaluate it as
12821293
// a regular (non-overloaded) call.
12831294
self.matching_overload_index = Some(index);
12841295
self.overloads[index].check_types(db, argument_types.as_ref());
1285-
return;
1296+
return None;
12861297
}
12871298
MatchingOverloadIndex::Multiple(indexes) => {
12881299
// If two or more candidate overloads remain, proceed to step 2.
@@ -1308,7 +1319,7 @@ impl<'db> CallableBinding<'db> {
13081319
}
13091320
MatchingOverloadIndex::Single(_) => {
13101321
// If only one overload evaluates without error, it is the winning match.
1311-
return;
1322+
return None;
13121323
}
13131324
MatchingOverloadIndex::Multiple(indexes) => {
13141325
// If two or more candidate overloads remain, proceed to step 4.
@@ -1318,7 +1329,7 @@ impl<'db> CallableBinding<'db> {
13181329
self.filter_overloads_using_any_or_unknown(db, argument_types.as_ref(), &indexes);
13191330

13201331
// We're returning here because this shouldn't lead to argument type expansion.
1321-
return;
1332+
return None;
13221333
}
13231334
}
13241335

@@ -1328,7 +1339,7 @@ impl<'db> CallableBinding<'db> {
13281339

13291340
if expansions.peek().is_none() {
13301341
// Return early if there are no argument types to expand.
1331-
return;
1342+
return None;
13321343
}
13331344

13341345
// State of the bindings _after_ evaluating (type checking) the matching overloads using
@@ -1377,18 +1388,24 @@ impl<'db> CallableBinding<'db> {
13771388
argument_type.display(db)
13781389
);
13791390
snapshotter.restore(self, post_evaluation_snapshot);
1380-
return;
1391+
return None;
13811392
}
13821393
}
13831394

1395+
tracing::debug!("Performing argument type expansion");
1396+
1397+
// Global argument_forms and conflicting_forms that will be merged across all expansions
1398+
let mut global_argument_forms: Vec<Option<ParameterForm>> = Vec::new();
1399+
let mut global_conflicting_forms: Vec<bool> = Vec::new();
1400+
13841401
for expansion in expansions {
13851402
let expanded_argument_lists = match expansion {
13861403
Expansion::LimitReached(index) => {
13871404
snapshotter.restore(self, post_evaluation_snapshot);
13881405
self.overload_call_return_type = Some(
13891406
OverloadCallReturnType::ArgumentTypeExpansionLimitReached(index),
13901407
);
1391-
return;
1408+
return Some((global_argument_forms, global_conflicting_forms));
13921409
}
13931410
Expansion::Expanded(argument_lists) => argument_lists,
13941411
};
@@ -1400,11 +1417,41 @@ impl<'db> CallableBinding<'db> {
14001417

14011418
let mut return_types = Vec::new();
14021419

1403-
for expanded_argument_types in &expanded_argument_lists {
1420+
for expanded_arguments in &expanded_argument_lists {
14041421
let pre_evaluation_snapshot = snapshotter.take(self);
14051422

1423+
// Clear the state of all overloads before re-evaluating from step 1
1424+
for overload in &mut self.overloads {
1425+
overload.reset();
1426+
}
1427+
1428+
let mut argument_forms = vec![None; expanded_arguments.len()];
1429+
let mut conflicting_forms = vec![false; expanded_arguments.len()];
1430+
1431+
// The spec mentions that each expanded argument list should re-evaluate from step
1432+
// 2 which is the type checking step but we're re-evaluating from step 1. The tldr
1433+
// is that it allows ty to match the correct overload in case a variadic argument
1434+
// would expand into different number of arguments with each expansion. Refer to
1435+
// https://github.com/astral-sh/ty/issues/735 for more details.
1436+
for overload in &mut self.overloads {
1437+
overload.match_parameters(
1438+
db,
1439+
expanded_arguments,
1440+
&mut argument_forms,
1441+
&mut conflicting_forms,
1442+
);
1443+
}
1444+
1445+
// Merge argument_forms and conflicting_forms into global ones
1446+
Self::merge_argument_forms(
1447+
&mut global_argument_forms,
1448+
&mut global_conflicting_forms,
1449+
&argument_forms,
1450+
&conflicting_forms,
1451+
);
1452+
14061453
for (_, overload) in self.matching_overloads_mut() {
1407-
overload.check_types(db, expanded_argument_types);
1454+
overload.check_types(db, expanded_arguments);
14081455
}
14091456

14101457
let return_type = match self.matching_overload_index() {
@@ -1417,14 +1464,26 @@ impl<'db> CallableBinding<'db> {
14171464

14181465
self.filter_overloads_using_any_or_unknown(
14191466
db,
1420-
expanded_argument_types,
1467+
expanded_arguments,
14211468
&matching_overload_indexes,
14221469
);
14231470

14241471
Some(self.return_type())
14251472
}
14261473
};
14271474

1475+
tracing::debug!(
1476+
"Return type after evaluating expanded arguments (`{}`): {}",
1477+
expanded_arguments
1478+
.iter_types()
1479+
.map(|arg| arg.display(db).to_string())
1480+
.collect::<Vec<_>>()
1481+
.join(", "),
1482+
return_type
1483+
.map(|ty| ty.display(db).to_string())
1484+
.unwrap_or_else(|| "None".to_string())
1485+
);
1486+
14281487
// This split between initializing and updating the merged evaluation state is
14291488
// required because otherwise it's difficult to differentiate between the
14301489
// following:
@@ -1468,7 +1527,7 @@ impl<'db> CallableBinding<'db> {
14681527
UnionType::from_elements(db, return_types),
14691528
));
14701529

1471-
return;
1530+
return Some((global_argument_forms, global_conflicting_forms));
14721531
}
14731532
}
14741533

@@ -1477,6 +1536,44 @@ impl<'db> CallableBinding<'db> {
14771536
// argument types. This is necessary because we restore the state to the pre-evaluation
14781537
// snapshot when processing the expanded argument lists.
14791538
snapshotter.restore(self, post_evaluation_snapshot);
1539+
None
1540+
}
1541+
1542+
/// Merge argument forms and conflicting forms into global ones.
1543+
fn merge_argument_forms(
1544+
global_argument_forms: &mut Vec<Option<ParameterForm>>,
1545+
global_conflicting_forms: &mut Vec<bool>,
1546+
local_argument_forms: &[Option<ParameterForm>],
1547+
local_conflicting_forms: &[bool],
1548+
) {
1549+
// Resize global lists to match local length if needed
1550+
if global_argument_forms.len() < local_argument_forms.len() {
1551+
global_argument_forms.resize(local_argument_forms.len(), None);
1552+
global_conflicting_forms.resize(local_conflicting_forms.len(), false);
1553+
}
1554+
1555+
// Merge argument forms and conflicting forms
1556+
for (i, (local_form, local_conflict)) in local_argument_forms
1557+
.iter()
1558+
.zip(local_conflicting_forms.iter())
1559+
.enumerate()
1560+
{
1561+
// Update global argument form
1562+
if let Some(global_form) = &mut global_argument_forms[i] {
1563+
if let Some(local_form) = local_form {
1564+
if *global_form != *local_form {
1565+
// Different parameter forms, mark as conflicting
1566+
global_conflicting_forms[i] = true;
1567+
*global_form = *local_form; // Use the new form
1568+
}
1569+
}
1570+
} else {
1571+
global_argument_forms[i] = *local_form;
1572+
}
1573+
1574+
// Update global conflicting form (true takes precedence)
1575+
global_conflicting_forms[i] |= local_conflict;
1576+
}
14801577
}
14811578

14821579
/// Filter overloads based on [`Any`] or [`Unknown`] argument types.
@@ -2048,6 +2145,14 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
20482145
argument_type: Option<Type<'db>>,
20492146
length: TupleLength,
20502147
) -> Result<(), ()> {
2148+
tracing::debug!("tuple length: {length:?}");
2149+
tracing::debug!(
2150+
"argument type: {}",
2151+
argument_type
2152+
.map(|ty| ty.display(db).to_string())
2153+
.unwrap_or_else(|| "<None>".to_string())
2154+
);
2155+
20512156
let tuple = argument_type.map(|ty| ty.iterate(db));
20522157
let mut argument_types = match tuple.as_ref() {
20532158
Some(tuple) => Either::Left(tuple.all_elements().copied()),
@@ -2597,6 +2702,16 @@ impl<'db> Binding<'db> {
25972702
pub(crate) fn errors(&self) -> &[BindingError<'db>] {
25982703
&self.errors
25992704
}
2705+
2706+
/// Resets the state of this binding to its initial state.
2707+
fn reset(&mut self) {
2708+
self.return_ty = Type::unknown();
2709+
self.specialization = None;
2710+
self.inherited_specialization = None;
2711+
self.argument_matches = Box::from([]);
2712+
self.parameter_tys = Box::from([]);
2713+
self.errors.clear();
2714+
}
26002715
}
26012716

26022717
#[derive(Clone, Debug)]

0 commit comments

Comments
 (0)