Skip to content

Commit 4e94b22

Browse files
authored
[ty] Support single-starred argument for overload call (#20223)
## Summary closes: astral-sh/ty#247 This PR adds support for variadic arguments to overload call evaluation. This basically boils down to making sure that the overloads are not filtered out incorrectly during the step 5 in the overload call evaluation algorithm. For context, the step 5 tries to filter out the remaining overloads after finding an overload where the materialization of argument types are assignable to the parameter types. The issue with the previous implementation was that it wouldn't unpack the variadic argument and wouldn't consider the many-to-one (multiple arguments mapping to a single variadic parameter) correctly. This PR fixes that. ## Test Plan Update existing test cases and resolve the TODOs.
1 parent 0639da2 commit 4e94b22

File tree

2 files changed

+99
-71
lines changed

2 files changed

+99
-71
lines changed

crates/ty_python_semantic/resources/mdtest/call/overloads.md

Lines changed: 20 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,7 @@ reveal_type(f(A())) # revealed: A
139139
reveal_type(f(*(A(),))) # revealed: A
140140

141141
reveal_type(f(B())) # revealed: A
142-
# TODO: revealed: A
143-
reveal_type(f(*(B(),))) # revealed: Unknown
142+
reveal_type(f(*(B(),))) # revealed: A
144143

145144
# But, in this case, the arity check filters out the first overload, so we only have one match:
146145
reveal_type(f(B(), 1)) # revealed: B
@@ -551,16 +550,13 @@ from overloaded import MyEnumSubclass, ActualEnum, f
551550

552551
def _(actual_enum: ActualEnum, my_enum_instance: MyEnumSubclass):
553552
reveal_type(f(actual_enum)) # revealed: Both
554-
# TODO: revealed: Both
555-
reveal_type(f(*(actual_enum,))) # revealed: Unknown
553+
reveal_type(f(*(actual_enum,))) # revealed: Both
556554

557555
reveal_type(f(ActualEnum.A)) # revealed: OnlyA
558-
# TODO: revealed: OnlyA
559-
reveal_type(f(*(ActualEnum.A,))) # revealed: Unknown
556+
reveal_type(f(*(ActualEnum.A,))) # revealed: OnlyA
560557

561558
reveal_type(f(ActualEnum.B)) # revealed: OnlyB
562-
# TODO: revealed: OnlyB
563-
reveal_type(f(*(ActualEnum.B,))) # revealed: Unknown
559+
reveal_type(f(*(ActualEnum.B,))) # revealed: OnlyB
564560

565561
reveal_type(f(my_enum_instance)) # revealed: MyEnumSubclass
566562
reveal_type(f(*(my_enum_instance,))) # revealed: MyEnumSubclass
@@ -1097,12 +1093,10 @@ reveal_type(f(*(1,))) # revealed: str
10971093

10981094
def _(list_int: list[int], list_any: list[Any]):
10991095
reveal_type(f(list_int)) # revealed: int
1100-
# TODO: revealed: int
1101-
reveal_type(f(*(list_int,))) # revealed: Unknown
1096+
reveal_type(f(*(list_int,))) # revealed: int
11021097

11031098
reveal_type(f(list_any)) # revealed: int
1104-
# TODO: revealed: int
1105-
reveal_type(f(*(list_any,))) # revealed: Unknown
1099+
reveal_type(f(*(list_any,))) # revealed: int
11061100
```
11071101

11081102
### Single list argument (ambiguous)
@@ -1136,8 +1130,7 @@ def _(list_int: list[int], list_any: list[Any]):
11361130
# All materializations of `list[int]` are assignable to `list[int]`, so it matches the first
11371131
# overload.
11381132
reveal_type(f(list_int)) # revealed: int
1139-
# TODO: revealed: int
1140-
reveal_type(f(*(list_int,))) # revealed: Unknown
1133+
reveal_type(f(*(list_int,))) # revealed: int
11411134

11421135
# All materializations of `list[Any]` are assignable to `list[int]` and `list[Any]`, but the
11431136
# return type of first and second overloads are not equivalent, so the overload matching
@@ -1170,25 +1163,21 @@ reveal_type(f("a")) # revealed: str
11701163
reveal_type(f(*("a",))) # revealed: str
11711164

11721165
reveal_type(f((1, "b"))) # revealed: int
1173-
# TODO: revealed: int
1174-
reveal_type(f(*((1, "b"),))) # revealed: Unknown
1166+
reveal_type(f(*((1, "b"),))) # revealed: int
11751167

11761168
reveal_type(f((1, 2))) # revealed: int
1177-
# TODO: revealed: int
1178-
reveal_type(f(*((1, 2),))) # revealed: Unknown
1169+
reveal_type(f(*((1, 2),))) # revealed: int
11791170

11801171
def _(int_str: tuple[int, str], int_any: tuple[int, Any], any_any: tuple[Any, Any]):
11811172
# All materializations are assignable to first overload, so second and third overloads are
11821173
# eliminated
11831174
reveal_type(f(int_str)) # revealed: int
1184-
# TODO: revealed: int
1185-
reveal_type(f(*(int_str,))) # revealed: Unknown
1175+
reveal_type(f(*(int_str,))) # revealed: int
11861176

11871177
# All materializations are assignable to second overload, so the third overload is eliminated;
11881178
# the return type of first and second overload is equivalent
11891179
reveal_type(f(int_any)) # revealed: int
1190-
# TODO: revealed: int
1191-
reveal_type(f(*(int_any,))) # revealed: Unknown
1180+
reveal_type(f(*(int_any,))) # revealed: int
11921181

11931182
# All materializations of `tuple[Any, Any]` are assignable to the parameters of all the
11941183
# overloads, but the return types aren't equivalent, so the overload matching is ambiguous
@@ -1266,26 +1255,22 @@ def _(list_int: list[int], list_any: list[Any], int_str: tuple[int, str], int_an
12661255
# All materializations of both argument types are assignable to the first overload, so the
12671256
# second and third overloads are filtered out
12681257
reveal_type(f(list_int, int_str)) # revealed: A
1269-
# TODO: revealed: A
1270-
reveal_type(f(*(list_int, int_str))) # revealed: Unknown
1258+
reveal_type(f(*(list_int, int_str))) # revealed: A
12711259

12721260
# All materialization of first argument is assignable to first overload and for the second
12731261
# argument, they're assignable to the second overload, so the third overload is filtered out
12741262
reveal_type(f(list_int, int_any)) # revealed: A
1275-
# TODO: revealed: A
1276-
reveal_type(f(*(list_int, int_any))) # revealed: Unknown
1263+
reveal_type(f(*(list_int, int_any))) # revealed: A
12771264

12781265
# All materialization of first argument is assignable to second overload and for the second
12791266
# argument, they're assignable to the first overload, so the third overload is filtered out
12801267
reveal_type(f(list_any, int_str)) # revealed: A
1281-
# TODO: revealed: A
1282-
reveal_type(f(*(list_any, int_str))) # revealed: Unknown
1268+
reveal_type(f(*(list_any, int_str))) # revealed: A
12831269

12841270
# All materializations of both arguments are assignable to the second overload, so the third
12851271
# overload is filtered out
12861272
reveal_type(f(list_any, int_any)) # revealed: A
1287-
# TODO: revealed: A
1288-
reveal_type(f(*(list_any, int_any))) # revealed: Unknown
1273+
reveal_type(f(*(list_any, int_any))) # revealed: A
12891274

12901275
# All materializations of first argument is assignable to the second overload and for the second
12911276
# argument, they're assignable to the third overload, so no overloads are filtered out; the
@@ -1316,8 +1301,7 @@ from overloaded import f
13161301

13171302
def _(literal: LiteralString, string: str, any: Any):
13181303
reveal_type(f(literal)) # revealed: LiteralString
1319-
# TODO: revealed: LiteralString
1320-
reveal_type(f(*(literal,))) # revealed: Unknown
1304+
reveal_type(f(*(literal,))) # revealed: LiteralString
13211305

13221306
reveal_type(f(string)) # revealed: str
13231307
reveal_type(f(*(string,))) # revealed: str
@@ -1355,12 +1339,10 @@ from overloaded import f
13551339

13561340
def _(list_int: list[int], list_str: list[str], list_any: list[Any], any: Any):
13571341
reveal_type(f(list_int)) # revealed: A
1358-
# TODO: revealed: A
1359-
reveal_type(f(*(list_int,))) # revealed: Unknown
1342+
reveal_type(f(*(list_int,))) # revealed: A
13601343

13611344
reveal_type(f(list_str)) # revealed: str
1362-
# TODO: Should be `str`
1363-
reveal_type(f(*(list_str,))) # revealed: Unknown
1345+
reveal_type(f(*(list_str,))) # revealed: str
13641346

13651347
reveal_type(f(list_any)) # revealed: Unknown
13661348
reveal_type(f(*(list_any,))) # revealed: Unknown
@@ -1561,12 +1543,10 @@ def _(any: Any):
15611543
reveal_type(f(*(any,), flag=False)) # revealed: str
15621544

15631545
def _(args: tuple[Any, Literal[True]]):
1564-
# TODO: revealed: int
1565-
reveal_type(f(*args)) # revealed: Unknown
1546+
reveal_type(f(*args)) # revealed: int
15661547

15671548
def _(args: tuple[Any, Literal[False]]):
1568-
# TODO: revealed: str
1569-
reveal_type(f(*args)) # revealed: Unknown
1549+
reveal_type(f(*args)) # revealed: str
15701550
```
15711551

15721552
### Argument type expansion

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 79 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ use crate::types::tuple::{TupleLength, TupleType};
3232
use crate::types::{
3333
BoundMethodType, ClassLiteral, DataclassParams, FieldInstance, KnownBoundMethodType,
3434
KnownClass, KnownInstanceType, MemberLookupPolicy, PropertyInstanceType, SpecialFormType,
35-
TrackedConstraintSet, TypeAliasType, TypeContext, TypeMapping, UnionType,
35+
TrackedConstraintSet, TypeAliasType, TypeContext, TypeMapping, UnionBuilder, UnionType,
3636
WrapperDescriptorKind, enums, ide_support, todo_type,
3737
};
3838
use ruff_db::diagnostic::{Annotation, Diagnostic, SubDiagnostic, SubDiagnosticSeverity};
@@ -1588,48 +1588,82 @@ impl<'db> CallableBinding<'db> {
15881588
arguments: &CallArguments<'_, 'db>,
15891589
matching_overload_indexes: &[usize],
15901590
) {
1591+
// The maximum number of parameters across all the overloads that are being considered
1592+
// for filtering.
1593+
let max_parameter_count = matching_overload_indexes
1594+
.iter()
1595+
.map(|&index| self.overloads[index].signature.parameters().len())
1596+
.max()
1597+
.unwrap_or(0);
1598+
15911599
// These are the parameter indexes that matches the arguments that participate in the
15921600
// filtering process.
15931601
//
15941602
// The parameter types at these indexes have at least one overload where the type isn't
15951603
// gradual equivalent to the parameter types at the same index for other overloads.
15961604
let mut participating_parameter_indexes = HashSet::new();
15971605

1598-
// These only contain the top materialized argument types for the corresponding
1599-
// participating parameter indexes.
1600-
let mut top_materialized_argument_types = vec![];
1601-
1602-
for (argument_index, argument_type) in arguments.iter_types().enumerate() {
1603-
let mut first_parameter_type: Option<Type<'db>> = None;
1604-
let mut participating_parameter_index = None;
1606+
// The parameter types at each index for the first overload containing a parameter at
1607+
// that index.
1608+
let mut first_parameter_types: Vec<Option<Type<'db>>> = vec![None; max_parameter_count];
16051609

1606-
'overload: for overload_index in matching_overload_indexes {
1610+
for argument_index in 0..arguments.len() {
1611+
for overload_index in matching_overload_indexes {
16071612
let overload = &self.overloads[*overload_index];
1608-
for parameter_index in &overload.argument_matches[argument_index].parameters {
1613+
for &parameter_index in &overload.argument_matches[argument_index].parameters {
16091614
// TODO: For an unannotated `self` / `cls` parameter, the type should be
16101615
// `typing.Self` / `type[typing.Self]`
1611-
let current_parameter_type = overload.signature.parameters()[*parameter_index]
1616+
let current_parameter_type = overload.signature.parameters()[parameter_index]
16121617
.annotated_type()
16131618
.unwrap_or(Type::unknown());
1619+
let first_parameter_type = &mut first_parameter_types[parameter_index];
16141620
if let Some(first_parameter_type) = first_parameter_type {
16151621
if !first_parameter_type.is_equivalent_to(db, current_parameter_type) {
1616-
participating_parameter_index = Some(*parameter_index);
1617-
break 'overload;
1622+
participating_parameter_indexes.insert(parameter_index);
16181623
}
16191624
} else {
1620-
first_parameter_type = Some(current_parameter_type);
1625+
*first_parameter_type = Some(current_parameter_type);
16211626
}
16221627
}
16231628
}
1629+
}
16241630

1625-
if let Some(parameter_index) = participating_parameter_index {
1626-
participating_parameter_indexes.insert(parameter_index);
1627-
top_materialized_argument_types.push(argument_type.top_materialization(db));
1631+
let mut union_argument_type_builders = std::iter::repeat_with(|| UnionBuilder::new(db))
1632+
.take(max_parameter_count)
1633+
.collect::<Vec<_>>();
1634+
1635+
for (argument_index, argument_type) in arguments.iter_types().enumerate() {
1636+
for overload_index in matching_overload_indexes {
1637+
let overload = &self.overloads[*overload_index];
1638+
for (parameter_index, variadic_argument_type) in
1639+
overload.argument_matches[argument_index].iter()
1640+
{
1641+
if !participating_parameter_indexes.contains(&parameter_index) {
1642+
continue;
1643+
}
1644+
union_argument_type_builders[parameter_index].add_in_place(
1645+
variadic_argument_type
1646+
.unwrap_or(argument_type)
1647+
.top_materialization(db),
1648+
);
1649+
}
16281650
}
16291651
}
16301652

1631-
let top_materialized_argument_type =
1632-
Type::heterogeneous_tuple(db, top_materialized_argument_types);
1653+
// These only contain the top materialized argument types for the corresponding
1654+
// participating parameter indexes.
1655+
let top_materialized_argument_type = Type::heterogeneous_tuple(
1656+
db,
1657+
union_argument_type_builders
1658+
.into_iter()
1659+
.filter_map(|builder| {
1660+
if builder.is_empty() {
1661+
None
1662+
} else {
1663+
Some(builder.build())
1664+
}
1665+
}),
1666+
);
16331667

16341668
// A flag to indicate whether we've found the overload that makes the remaining overloads
16351669
// unmatched for the given argument types.
@@ -1640,15 +1674,22 @@ impl<'db> CallableBinding<'db> {
16401674
self.overloads[*current_index].mark_as_unmatched_overload();
16411675
continue;
16421676
}
1643-
let mut parameter_types = Vec::with_capacity(arguments.len());
1677+
1678+
let mut union_parameter_types = std::iter::repeat_with(|| UnionBuilder::new(db))
1679+
.take(max_parameter_count)
1680+
.collect::<Vec<_>>();
1681+
1682+
// The number of parameters that have been skipped because they don't participate in
1683+
// the filtering process. This is used to make sure the types are added to the
1684+
// corresponding parameter index in `union_parameter_types`.
1685+
let mut skipped_parameters = 0;
1686+
16441687
for argument_index in 0..arguments.len() {
1645-
// The parameter types at the current argument index.
1646-
let mut current_parameter_types = vec![];
16471688
for overload_index in &matching_overload_indexes[..=upto] {
16481689
let overload = &self.overloads[*overload_index];
16491690
for parameter_index in &overload.argument_matches[argument_index].parameters {
16501691
if !participating_parameter_indexes.contains(parameter_index) {
1651-
// This parameter doesn't participate in the filtering process.
1692+
skipped_parameters += 1;
16521693
continue;
16531694
}
16541695
// TODO: For an unannotated `self` / `cls` parameter, the type should be
@@ -1664,17 +1705,24 @@ impl<'db> CallableBinding<'db> {
16641705
parameter_type =
16651706
parameter_type.apply_specialization(db, inherited_specialization);
16661707
}
1667-
current_parameter_types.push(parameter_type);
1708+
union_parameter_types[parameter_index.saturating_sub(skipped_parameters)]
1709+
.add_in_place(parameter_type);
16681710
}
16691711
}
1670-
if current_parameter_types.is_empty() {
1671-
continue;
1672-
}
1673-
parameter_types.push(UnionType::from_elements(db, current_parameter_types));
16741712
}
1675-
if top_materialized_argument_type
1676-
.is_assignable_to(db, Type::heterogeneous_tuple(db, parameter_types))
1677-
{
1713+
1714+
let parameter_types = Type::heterogeneous_tuple(
1715+
db,
1716+
union_parameter_types.into_iter().filter_map(|builder| {
1717+
if builder.is_empty() {
1718+
None
1719+
} else {
1720+
Some(builder.build())
1721+
}
1722+
}),
1723+
);
1724+
1725+
if top_materialized_argument_type.is_assignable_to(db, parameter_types) {
16781726
filter_remaining_overloads = true;
16791727
}
16801728
}

0 commit comments

Comments
 (0)