Skip to content

Commit c116820

Browse files
committed
infer correct arity for splatted tuples
1 parent cf39e0d commit c116820

File tree

3 files changed

+149
-23
lines changed

3 files changed

+149
-23
lines changed

crates/ty_python_semantic/resources/mdtest/call/function.md

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def _(args: tuple[int, ...]) -> None:
9898
takes_at_least_two(*args)
9999
```
100100

101-
### Known argument length
101+
### Fixed-length tuple argument
102102

103103
```py
104104
def takes_zero() -> None: ...
@@ -109,11 +109,114 @@ def takes_at_least_one(x: int, *args) -> None: ...
109109
def takes_at_least_two(x: int, y: int, *args) -> None: ...
110110

111111
def _(args: tuple[int]) -> None:
112+
# error: [too-many-positional-arguments]
112113
takes_zero(*args)
113114
takes_one(*args)
115+
# error: [missing-argument]
114116
takes_two(*args)
115117
takes_at_least_zero(*args)
116118
takes_at_least_one(*args)
119+
# error: [missing-argument]
120+
takes_at_least_two(*args)
121+
122+
def _(args: tuple[int, int]) -> None:
123+
# error: [too-many-positional-arguments]
124+
takes_zero(*args)
125+
# error: [too-many-positional-arguments]
126+
takes_one(*args)
127+
takes_two(*args)
128+
takes_at_least_zero(*args)
129+
takes_at_least_one(*args)
130+
takes_at_least_two(*args)
131+
132+
def _(args: tuple[int, str]) -> None:
133+
# error: [too-many-positional-arguments]
134+
takes_zero(*args)
135+
# error: [too-many-positional-arguments]
136+
takes_one(*args)
137+
# error: [invalid-argument-type]
138+
takes_two(*args)
139+
takes_at_least_zero(*args)
140+
takes_at_least_one(*args)
141+
# error: [invalid-argument-type]
142+
takes_at_least_two(*args)
143+
```
144+
145+
### Mixed tuple argument
146+
147+
```toml
148+
[environment]
149+
python-version = "3.11"
150+
```
151+
152+
```py
153+
def takes_zero() -> None: ...
154+
def takes_one(x: int) -> None: ...
155+
def takes_two(x: int, y: int) -> None: ...
156+
def takes_at_least_zero(*args) -> None: ...
157+
def takes_at_least_one(x: int, *args) -> None: ...
158+
def takes_at_least_two(x: int, y: int, *args) -> None: ...
159+
160+
def _(args: tuple[int, *tuple[int, ...]]) -> None:
161+
# error: [too-many-positional-arguments]
162+
takes_zero(*args)
163+
takes_one(*args)
164+
takes_two(*args)
165+
takes_at_least_zero(*args)
166+
takes_at_least_one(*args)
167+
takes_at_least_two(*args)
168+
169+
def _(args: tuple[int, *tuple[str, ...]]) -> None:
170+
# error: [too-many-positional-arguments]
171+
takes_zero(*args)
172+
takes_one(*args)
173+
# error: [invalid-argument-type]
174+
takes_two(*args)
175+
takes_at_least_zero(*args)
176+
takes_at_least_one(*args)
177+
# error: [invalid-argument-type]
178+
takes_at_least_two(*args)
179+
180+
def _(args: tuple[int, int, *tuple[int, ...]]) -> None:
181+
# error: [too-many-positional-arguments]
182+
takes_zero(*args)
183+
# error: [too-many-positional-arguments]
184+
takes_one(*args)
185+
takes_two(*args)
186+
takes_at_least_zero(*args)
187+
takes_at_least_one(*args)
188+
takes_at_least_two(*args)
189+
190+
def _(args: tuple[int, int, *tuple[str, ...]]) -> None:
191+
# error: [too-many-positional-arguments]
192+
takes_zero(*args)
193+
# error: [too-many-positional-arguments]
194+
takes_one(*args)
195+
takes_two(*args)
196+
takes_at_least_zero(*args)
197+
takes_at_least_one(*args)
198+
takes_at_least_two(*args)
199+
200+
def _(args: tuple[int, *tuple[int, ...], int]) -> None:
201+
# error: [too-many-positional-arguments]
202+
takes_zero(*args)
203+
# error: [too-many-positional-arguments]
204+
takes_one(*args)
205+
takes_two(*args)
206+
takes_at_least_zero(*args)
207+
takes_at_least_one(*args)
208+
takes_at_least_two(*args)
209+
210+
def _(args: tuple[int, *tuple[str, ...], int]) -> None:
211+
# error: [too-many-positional-arguments]
212+
takes_zero(*args)
213+
# error: [too-many-positional-arguments]
214+
takes_one(*args)
215+
# error: [invalid-argument-type]
216+
takes_two(*args)
217+
takes_at_least_zero(*args)
218+
takes_at_least_one(*args)
219+
# error: [invalid-argument-type]
117220
takes_at_least_two(*args)
118221
```
119222

crates/ty_python_semantic/src/types/call/arguments.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,9 @@ impl<'a> CallArguments<'a> {
3737
}
3838
}
3939

40-
impl<'a> FromIterator<Argument<'a>> for CallArguments<'a> {
41-
fn from_iter<T: IntoIterator<Item = Argument<'a>>>(iter: T) -> Self {
42-
Self(iter.into_iter().collect())
40+
impl<'a> From<Vec<Argument<'a>>> for CallArguments<'a> {
41+
fn from(arguments: Vec<Argument<'a>>) -> Self {
42+
Self(arguments)
4343
}
4444
}
4545

crates/ty_python_semantic/src/types/infer.rs

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1862,9 +1862,9 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
18621862
self.infer_type_parameters(type_params);
18631863

18641864
if let Some(arguments) = class.arguments.as_deref() {
1865-
let call_arguments = Self::parse_arguments(arguments);
1865+
let (call_arguments, argument_types) = self.parse_arguments(arguments);
18661866
let argument_forms = vec![Some(ParameterForm::Value); call_arguments.len()];
1867-
self.infer_argument_types(arguments, call_arguments, &argument_forms);
1867+
self.infer_argument_types(arguments, call_arguments, argument_types, &argument_forms);
18681868
}
18691869
}
18701870

@@ -4536,48 +4536,63 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
45364536
self.infer_expression(expression)
45374537
}
45384538

4539-
fn parse_arguments(arguments: &ast::Arguments) -> CallArguments<'_> {
4540-
arguments
4539+
fn parse_arguments<'a>(
4540+
&mut self,
4541+
arguments: &'a ast::Arguments,
4542+
) -> (CallArguments<'a>, Vec<Option<Type<'db>>>) {
4543+
let (arguments, types): (Vec<_>, Vec<_>) = arguments
45414544
.arguments_source_order()
45424545
.map(|arg_or_keyword| {
45434546
match arg_or_keyword {
45444547
ast::ArgOrKeyword::Arg(arg) => match arg {
4545-
ast::Expr::Starred(ast::ExprStarred { .. }) => {
4546-
Argument::Variadic(TupleLength::unknown())
4548+
ast::Expr::Starred(ast::ExprStarred { value, .. }) => {
4549+
let ty = self.infer_expression(value);
4550+
self.store_expression_type(arg, ty);
4551+
let length = match ty {
4552+
Type::Tuple(tuple) => tuple.tuple(self.db()).len(),
4553+
// TODO: have `Type::try_iterator` return a tuple spec, and use its
4554+
// length as this argument's arity
4555+
_ => TupleLength::unknown(),
4556+
};
4557+
(Argument::Variadic(length), Some(ty))
45474558
}
45484559
// TODO diagnostic if after a keyword argument
4549-
_ => Argument::Positional,
4560+
_ => (Argument::Positional, None),
45504561
},
45514562
ast::ArgOrKeyword::Keyword(ast::Keyword { arg, .. }) => {
45524563
if let Some(arg) = arg {
4553-
Argument::Keyword(&arg.id)
4564+
(Argument::Keyword(&arg.id), None)
45544565
} else {
45554566
// TODO diagnostic if not last
4556-
Argument::Keywords
4567+
(Argument::Keywords, None)
45574568
}
45584569
}
45594570
}
45604571
})
4561-
.collect()
4572+
.unzip();
4573+
let arguments = CallArguments::from(arguments);
4574+
(arguments, types)
45624575
}
45634576

45644577
fn infer_argument_types<'a>(
45654578
&mut self,
45664579
ast_arguments: &ast::Arguments,
45674580
arguments: CallArguments<'a>,
4581+
argument_types: Vec<Option<Type<'db>>>,
45684582
argument_forms: &[Option<ParameterForm>],
45694583
) -> CallArgumentTypes<'a, 'db> {
45704584
let mut ast_arguments = ast_arguments.arguments_source_order();
45714585
CallArgumentTypes::new(arguments, |index, _| {
4586+
if let Some(argument_type) = argument_types[index] {
4587+
return argument_type;
4588+
}
45724589
let arg_or_keyword = ast_arguments
45734590
.next()
45744591
.expect("argument lists should have consistent lengths");
45754592
match arg_or_keyword {
45764593
ast::ArgOrKeyword::Arg(arg) => match arg {
4577-
ast::Expr::Starred(ast::ExprStarred { value, .. }) => {
4578-
let ty = self.infer_argument_type(value, argument_forms[index]);
4579-
self.store_expression_type(arg, ty);
4580-
ty
4594+
ast::Expr::Starred(ast::ExprStarred { .. }) => {
4595+
panic!("should have already inferred a type for splatted argument");
45814596
}
45824597
_ => self.infer_argument_type(arg, argument_forms[index]),
45834598
},
@@ -5286,7 +5301,7 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
52865301
// We don't call `Type::try_call`, because we want to perform type inference on the
52875302
// arguments after matching them to parameters, but before checking that the argument types
52885303
// are assignable to any parameter annotations.
5289-
let call_arguments = Self::parse_arguments(arguments);
5304+
let (call_arguments, argument_types) = self.parse_arguments(arguments);
52905305
let callable_type = self.infer_expression(func);
52915306

52925307
if let Type::FunctionLiteral(function) = callable_type {
@@ -5358,8 +5373,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
53585373
.is_none_or(|enum_class| !class.is_subclass_of(self.db(), enum_class))
53595374
{
53605375
let argument_forms = vec![Some(ParameterForm::Value); call_arguments.len()];
5361-
let call_argument_types =
5362-
self.infer_argument_types(arguments, call_arguments, &argument_forms);
5376+
let call_argument_types = self.infer_argument_types(
5377+
arguments,
5378+
call_arguments,
5379+
argument_types,
5380+
&argument_forms,
5381+
);
53635382

53645383
return callable_type
53655384
.try_call_constructor(self.db(), call_argument_types)
@@ -5373,8 +5392,12 @@ impl<'db, 'ast> TypeInferenceBuilder<'db, 'ast> {
53735392
let bindings = callable_type
53745393
.bindings(self.db())
53755394
.match_parameters(&call_arguments);
5376-
let call_argument_types =
5377-
self.infer_argument_types(arguments, call_arguments, &bindings.argument_forms);
5395+
let call_argument_types = self.infer_argument_types(
5396+
arguments,
5397+
call_arguments,
5398+
argument_types,
5399+
&bindings.argument_forms,
5400+
);
53785401

53795402
let mut bindings = match bindings.check_types(self.db(), &call_argument_types) {
53805403
Ok(bindings) => bindings,

0 commit comments

Comments
 (0)