Skip to content

Commit e1bb74b

Browse files
authored
[ty] Match variadic argument to variadic parameter (#20511)
## Summary Closes: astral-sh/ty#1236 This PR fixes a bug where the variadic argument wouldn't match against the variadic parameter in certain scenarios. This was happening because I didn't realize that the `all_elements` iterator wouldn't keep on returning the variable element (which is correct, I just didn't realize it back then). I don't think we can use the `resize` method here because we don't know how many parameters this variadic argument is matching against as this is where the actual parameter matching occurs. ## Test Plan Expand test cases to consider a few more combinations of arguments and parameters which are variadic.
1 parent edeb458 commit e1bb74b

File tree

6 files changed

+150
-67
lines changed

6 files changed

+150
-67
lines changed

crates/ty_python_semantic/resources/mdtest/call/function.md

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,96 @@ def f(*args: int) -> int:
642642
reveal_type(f("foo")) # revealed: int
643643
```
644644

645+
### Variadic argument, variadic parameter
646+
647+
```toml
648+
[environment]
649+
python-version = "3.11"
650+
```
651+
652+
```py
653+
def f(*args: int) -> int:
654+
return 1
655+
656+
def _(args: list[str]) -> None:
657+
# error: [invalid-argument-type] "Argument to function `f` is incorrect: Expected `int`, found `str`"
658+
reveal_type(f(*args)) # revealed: int
659+
```
660+
661+
Considering a few different shapes of tuple for the splatted argument:
662+
663+
```py
664+
def f1(*args: str): ...
665+
def _(
666+
args1: tuple[str, ...],
667+
args2: tuple[str, *tuple[str, ...]],
668+
args3: tuple[str, *tuple[str, ...], str],
669+
args4: tuple[int, *tuple[str, ...]],
670+
args5: tuple[int, *tuple[str, ...], str],
671+
args6: tuple[*tuple[str, ...], str],
672+
args7: tuple[*tuple[str, ...], int],
673+
args8: tuple[int, *tuple[str, ...], int],
674+
args9: tuple[str, *tuple[str, ...], int],
675+
args10: tuple[str, *tuple[int, ...], str],
676+
):
677+
f1(*args1)
678+
f1(*args2)
679+
f1(*args3)
680+
f1(*args4) # error: [invalid-argument-type]
681+
f1(*args5) # error: [invalid-argument-type]
682+
f1(*args6)
683+
f1(*args7) # error: [invalid-argument-type]
684+
685+
# The reason for two errors here is because of the two fixed elements in the tuple of `args8`
686+
# which are both `int`
687+
# error: [invalid-argument-type]
688+
# error: [invalid-argument-type]
689+
f1(*args8)
690+
691+
f1(*args9) # error: [invalid-argument-type]
692+
f1(*args10) # error: [invalid-argument-type]
693+
```
694+
695+
### Mixed argument and parameter containing variadic
696+
697+
```toml
698+
[environment]
699+
python-version = "3.11"
700+
```
701+
702+
```py
703+
def f(x: int, *args: str) -> int:
704+
return 1
705+
706+
def _(
707+
args1: list[int],
708+
args2: tuple[int],
709+
args3: tuple[int, int],
710+
args4: tuple[int, ...],
711+
args5: tuple[int, *tuple[str, ...]],
712+
args6: tuple[int, int, *tuple[str, ...]],
713+
) -> None:
714+
# error: [invalid-argument-type] "Argument to function `f` is incorrect: Expected `str`, found `int`"
715+
reveal_type(f(*args1)) # revealed: int
716+
717+
# This shouldn't raise an error because the unpacking doesn't match the variadic parameter.
718+
reveal_type(f(*args2)) # revealed: int
719+
720+
# But, this should because the second tuple element is not assignable.
721+
# error: [invalid-argument-type] "Argument to function `f` is incorrect: Expected `str`, found `int`"
722+
reveal_type(f(*args3)) # revealed: int
723+
724+
# error: [invalid-argument-type] "Argument to function `f` is incorrect: Expected `str`, found `int`"
725+
reveal_type(f(*args4)) # revealed: int
726+
727+
# The first element of the tuple matches the required argument;
728+
# all subsequent elements match the variadic argument
729+
reveal_type(f(*args5)) # revealed: int
730+
731+
# error: [invalid-argument-type] "Argument to function `f` is incorrect: Expected `str`, found `int`"
732+
reveal_type(f(*args6)) # revealed: int
733+
```
734+
645735
### Keyword argument, positional-or-keyword parameter
646736

647737
```py

crates/ty_python_semantic/src/types/call/arguments.rs

Lines changed: 7 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use ruff_python_ast as ast;
66
use crate::Db;
77
use crate::types::KnownClass;
88
use crate::types::enums::{enum_member_literals, enum_metadata};
9-
use crate::types::tuple::{Tuple, TupleLength, TupleType};
9+
use crate::types::tuple::{Tuple, TupleType};
1010

1111
use super::Type;
1212

@@ -17,7 +17,7 @@ pub(crate) enum Argument<'a> {
1717
/// A positional argument.
1818
Positional,
1919
/// A starred positional argument (e.g. `*args`) containing the specified number of elements.
20-
Variadic(TupleLength),
20+
Variadic,
2121
/// A keyword argument (e.g. `a=1`).
2222
Keyword(&'a str),
2323
/// The double-starred keywords argument (e.g. `**kwargs`).
@@ -41,7 +41,6 @@ impl<'a, 'db> CallArguments<'a, 'db> {
4141
/// type of each splatted argument, so that we can determine its length. All other arguments
4242
/// will remain uninitialized as `Unknown`.
4343
pub(crate) fn from_arguments(
44-
db: &'db dyn Db,
4544
arguments: &'a ast::Arguments,
4645
mut infer_argument_type: impl FnMut(Option<&ast::Expr>, &ast::Expr) -> Type<'db>,
4746
) -> Self {
@@ -51,11 +50,7 @@ impl<'a, 'db> CallArguments<'a, 'db> {
5150
ast::ArgOrKeyword::Arg(arg) => match arg {
5251
ast::Expr::Starred(ast::ExprStarred { value, .. }) => {
5352
let ty = infer_argument_type(Some(arg), value);
54-
let length = ty
55-
.try_iterate(db)
56-
.map(|tuple| tuple.len())
57-
.unwrap_or(TupleLength::unknown());
58-
(Argument::Variadic(length), Some(ty))
53+
(Argument::Variadic, Some(ty))
5954
}
6055
_ => (Argument::Positional, None),
6156
},
@@ -203,25 +198,10 @@ impl<'a, 'db> CallArguments<'a, 'db> {
203198
for subtype in &expanded_types {
204199
let mut new_expanded_types = pre_expanded_types.to_vec();
205200
new_expanded_types[index] = Some(*subtype);
206-
207-
// Update the arguments list to handle variadic argument expansion
208-
let mut new_arguments = self.arguments.clone();
209-
if let Argument::Variadic(_) = self.arguments[index] {
210-
// If the argument corresponding to this type is variadic, we need to
211-
// update the tuple length because expanding could change the length.
212-
// For example, in `tuple[int] | tuple[int, int]`, the length of the
213-
// first type is 1, while the length of the second type is 2.
214-
if let Some(expanded_type) = new_expanded_types[index] {
215-
let length = expanded_type
216-
.try_iterate(db)
217-
.map(|tuple| tuple.len())
218-
.unwrap_or(TupleLength::unknown());
219-
new_arguments[index] = Argument::Variadic(length);
220-
}
221-
}
222-
223-
expanded_arguments
224-
.push(CallArguments::new(new_arguments, new_expanded_types));
201+
expanded_arguments.push(CallArguments::new(
202+
self.arguments.clone(),
203+
new_expanded_types,
204+
));
225205
}
226206
}
227207

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 42 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2135,24 +2135,36 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
21352135
Ok(())
21362136
}
21372137

2138+
/// Match a variadic argument to the remaining positional, standard or variadic parameters.
21382139
fn match_variadic(
21392140
&mut self,
21402141
db: &'db dyn Db,
21412142
argument_index: usize,
21422143
argument: Argument<'a>,
21432144
argument_type: Option<Type<'db>>,
2144-
length: TupleLength,
21452145
) -> Result<(), ()> {
21462146
let tuple = argument_type.map(|ty| ty.iterate(db));
2147-
let mut argument_types = match tuple.as_ref() {
2148-
Some(tuple) => Either::Left(tuple.all_elements().copied()),
2149-
None => Either::Right(std::iter::empty()),
2147+
let (mut argument_types, length, variable_element) = match tuple.as_ref() {
2148+
Some(tuple) => (
2149+
Either::Left(tuple.all_elements().copied()),
2150+
tuple.len(),
2151+
tuple.variable_element().copied(),
2152+
),
2153+
None => (
2154+
Either::Right(std::iter::empty()),
2155+
TupleLength::unknown(),
2156+
None,
2157+
),
21502158
};
21512159

21522160
// We must be able to match up the fixed-length portion of the argument with positional
21532161
// parameters, so we pass on any errors that occur.
21542162
for _ in 0..length.minimum() {
2155-
self.match_positional(argument_index, argument, argument_types.next())?;
2163+
self.match_positional(
2164+
argument_index,
2165+
argument,
2166+
argument_types.next().or(variable_element),
2167+
)?;
21562168
}
21572169

21582170
// If the tuple is variable-length, we assume that it will soak up all remaining positional
@@ -2163,7 +2175,24 @@ impl<'a, 'db> ArgumentMatcher<'a, 'db> {
21632175
.get_positional(self.next_positional)
21642176
.is_some()
21652177
{
2166-
self.match_positional(argument_index, argument, argument_types.next())?;
2178+
self.match_positional(
2179+
argument_index,
2180+
argument,
2181+
argument_types.next().or(variable_element),
2182+
)?;
2183+
}
2184+
}
2185+
2186+
// Finally, if there is a variadic parameter we can match any of the remaining unpacked
2187+
// argument types to it, but only if there is at least one remaining argument type. This is
2188+
// because a variadic parameter is optional, so if this was done unconditionally, ty could
2189+
// raise a false positive as "too many arguments".
2190+
if self.parameters.variadic().is_some() {
2191+
if let Some(argument_type) = argument_types.next().or(variable_element) {
2192+
self.match_positional(argument_index, argument, Some(argument_type))?;
2193+
for argument_type in argument_types {
2194+
self.match_positional(argument_index, argument, Some(argument_type))?;
2195+
}
21672196
}
21682197
}
21692198

@@ -2433,11 +2462,10 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
24332462
self.enumerate_argument_types()
24342463
{
24352464
match argument {
2436-
Argument::Variadic(_) => self.check_variadic_argument_type(
2465+
Argument::Variadic => self.check_variadic_argument_type(
24372466
argument_index,
24382467
adjusted_argument_index,
24392468
argument,
2440-
argument_type,
24412469
),
24422470
Argument::Keywords => self.check_keyword_variadic_argument_type(
24432471
argument_index,
@@ -2465,37 +2493,15 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
24652493
argument_index: usize,
24662494
adjusted_argument_index: Option<usize>,
24672495
argument: Argument<'a>,
2468-
argument_type: Type<'db>,
24692496
) {
2470-
// If the argument is splatted, convert its type into a tuple describing the splatted
2471-
// elements. For tuples, we don't have to do anything! For other types, we treat it as
2472-
// an iterator, and create a homogeneous tuple of its output type, since we don't know
2473-
// how many elements the iterator will produce.
2474-
let argument_types = argument_type.iterate(self.db);
2475-
2476-
// Resize the tuple of argument types to line up with the number of parameters this
2477-
// argument was matched against. If parameter matching succeeded, then we can (TODO:
2478-
// should be able to, see above) guarantee that all of the required elements of the
2479-
// splatted tuple will have been matched with a parameter. But if parameter matching
2480-
// failed, there might be more required elements. That means we can't use
2481-
// TupleLength::Fixed below, because we would otherwise get a "too many values" error
2482-
// when parameter matching failed.
2483-
let desired_size =
2484-
TupleLength::Variable(self.argument_matches[argument_index].parameters.len(), 0);
2485-
let argument_types = argument_types
2486-
.resize(self.db, desired_size)
2487-
.expect("argument type should be consistent with its arity");
2488-
2489-
// Check the types by zipping through the splatted argument types and their matched
2490-
// parameters.
2491-
for (argument_type, parameter_index) in
2492-
(argument_types.all_elements()).zip(&self.argument_matches[argument_index].parameters)
2497+
for (parameter_index, variadic_argument_type) in
2498+
self.argument_matches[argument_index].iter()
24932499
{
24942500
self.check_argument_type(
24952501
adjusted_argument_index,
24962502
argument,
2497-
*argument_type,
2498-
*parameter_index,
2503+
variadic_argument_type.unwrap_or_else(Type::unknown),
2504+
parameter_index,
24992505
);
25002506
}
25012507
}
@@ -2711,9 +2717,8 @@ impl<'db> Binding<'db> {
27112717
Argument::Keyword(name) => {
27122718
let _ = matcher.match_keyword(argument_index, argument, None, name);
27132719
}
2714-
Argument::Variadic(length) => {
2715-
let _ =
2716-
matcher.match_variadic(db, argument_index, argument, argument_type, length);
2720+
Argument::Variadic => {
2721+
let _ = matcher.match_variadic(db, argument_index, argument, argument_type);
27172722
}
27182723
Argument::Keywords => {
27192724
keywords_arguments.push((argument_index, argument_type));

crates/ty_python_semantic/src/types/ide_support.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -874,7 +874,7 @@ pub fn call_signature_details<'db>(
874874
// Use into_callable to handle all the complex type conversions
875875
if let Some(callable_type) = func_type.into_callable(db) {
876876
let call_arguments =
877-
CallArguments::from_arguments(db, &call_expr.arguments, |_, splatted_value| {
877+
CallArguments::from_arguments(&call_expr.arguments, |_, splatted_value| {
878878
splatted_value.inferred_type(model)
879879
});
880880
let bindings = callable_type

crates/ty_python_semantic/src/types/infer/builder.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1733,7 +1733,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
17331733
let previous_deferred_state =
17341734
std::mem::replace(&mut self.deferred_state, in_stub.into());
17351735
let mut call_arguments =
1736-
CallArguments::from_arguments(self.db(), arguments, |argument, splatted_value| {
1736+
CallArguments::from_arguments(arguments, |argument, splatted_value| {
17371737
let ty = self.infer_expression(splatted_value, TypeContext::default());
17381738
if let Some(argument) = argument {
17391739
self.store_expression_type(argument, ty);
@@ -5831,7 +5831,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
58315831
// arguments after matching them to parameters, but before checking that the argument types
58325832
// are assignable to any parameter annotations.
58335833
let mut call_arguments =
5834-
CallArguments::from_arguments(self.db(), arguments, |argument, splatted_value| {
5834+
CallArguments::from_arguments(arguments, |argument, splatted_value| {
58355835
let ty = self.infer_expression(splatted_value, TypeContext::default());
58365836
if let Some(argument) = argument {
58375837
self.store_expression_type(argument, ty);

crates/ty_python_semantic/src/types/tuple.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -970,6 +970,14 @@ impl<T> Tuple<T> {
970970
FixedLengthTuple::from_elements(elements).into()
971971
}
972972

973+
/// Returns the variable-length element of this tuple, if it has one.
974+
pub(crate) fn variable_element(&self) -> Option<&T> {
975+
match self {
976+
Tuple::Fixed(_) => None,
977+
Tuple::Variable(tuple) => Some(&tuple.variable),
978+
}
979+
}
980+
973981
/// Returns an iterator of all of the fixed-length element types of this tuple.
974982
pub(crate) fn fixed_elements(&self) -> impl Iterator<Item = &T> + '_ {
975983
match self {

0 commit comments

Comments
 (0)