Skip to content

Commit cd46e48

Browse files
committed
Include it in argument type expansion loop
1 parent 7fedb2c commit cd46e48

File tree

1 file changed

+64
-50
lines changed
  • crates/ty_python_semantic/src/types/call

1 file changed

+64
-50
lines changed

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 64 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ use crate::types::function::{DataclassTransformerParams, FunctionDecorators, Kno
2323
use crate::types::generics::{Specialization, SpecializationBuilder, SpecializationError};
2424
use crate::types::signatures::{Parameter, ParameterForm};
2525
use crate::types::{
26-
BoundMethodType, ClassLiteral, DataclassParams, IntersectionBuilder, KnownClass,
27-
KnownInstanceType, MethodWrapperKind, PropertyInstanceType, SpecialFormType, TupleType,
28-
TypeMapping, UnionType, WrapperDescriptorKind, ide_support, todo_type,
26+
BoundMethodType, ClassLiteral, DataclassParams, KnownClass, KnownInstanceType,
27+
MethodWrapperKind, PropertyInstanceType, SpecialFormType, TupleType, TypeMapping, UnionType,
28+
WrapperDescriptorKind, ide_support, todo_type,
2929
};
3030
use ruff_db::diagnostic::{Annotation, Diagnostic, Severity, SubDiagnostic};
3131
use ruff_python_ast as ast;
@@ -1176,7 +1176,7 @@ impl<'db> CallableBinding<'db> {
11761176
}
11771177
};
11781178

1179-
let snapshotter = MatchingOverloadsSnapshotter::new(matching_overload_indexes);
1179+
let snapshotter = CallableBindingSnapshotter::new(matching_overload_indexes);
11801180

11811181
// State of the bindings _before_ evaluating (type checking) the matching overloads using
11821182
// the non-expanded argument types.
@@ -1198,11 +1198,6 @@ impl<'db> CallableBinding<'db> {
11981198
}
11991199
MatchingOverloadIndex::Multiple(indexes) => {
12001200
// If two or more candidate overloads remain, proceed to step 4.
1201-
tracing::info!(
1202-
"Multiple overloads match: {:?}, filtering based on Any",
1203-
indexes
1204-
);
1205-
12061201
// TODO: Step 4
12071202

12081203
// Step 5
@@ -1234,7 +1229,7 @@ impl<'db> CallableBinding<'db> {
12341229
// This is the merged state of the bindings after evaluating all of the expanded
12351230
// argument lists. This will be the final state to restore the bindings to if all of
12361231
// the expanded argument lists evaluated successfully.
1237-
let mut merged_evaluation_state: Option<MatchingOverloadsSnapshot<'db>> = None;
1232+
let mut merged_evaluation_state: Option<CallableBindingSnapshot<'db>> = None;
12381233

12391234
let mut return_types = Vec::new();
12401235

@@ -1250,10 +1245,16 @@ impl<'db> CallableBinding<'db> {
12501245
MatchingOverloadIndex::Single(index) => {
12511246
Some(self.overloads[index].return_type())
12521247
}
1253-
MatchingOverloadIndex::Multiple(index) => {
1254-
// TODO: Step 4 and Step 5 goes here... but for now we just use the return
1255-
// type of the first matched overload.
1256-
Some(self.overloads[index[0]].return_type())
1248+
MatchingOverloadIndex::Multiple(matching_overload_indexes) => {
1249+
// TODO: Step 4
1250+
1251+
self.filter_overloads_using_any_or_unknown(
1252+
db,
1253+
expanded_argument_types,
1254+
&matching_overload_indexes,
1255+
);
1256+
1257+
Some(self.return_type())
12571258
}
12581259
};
12591260

@@ -1283,6 +1284,15 @@ impl<'db> CallableBinding<'db> {
12831284
}
12841285

12851286
if return_types.len() == expanded_argument_lists.len() {
1287+
// Restore the bindings state to the one that merges the bindings state evaluating
1288+
// each of the expanded argument list.
1289+
//
1290+
// Note that this needs to happen *before* setting the return type, because this
1291+
// will restore the return type to the one before argument type expansion.
1292+
if let Some(merged_evaluation_state) = merged_evaluation_state {
1293+
snapshotter.restore(self, merged_evaluation_state);
1294+
}
1295+
12861296
// If the number of return types is equal to the number of expanded argument lists,
12871297
// they all evaluated successfully. So, we need to combine their return types by
12881298
// union to determine the final return type.
@@ -1291,12 +1301,6 @@ impl<'db> CallableBinding<'db> {
12911301
UnionType::from_elements(db, return_types),
12921302
));
12931303

1294-
// Restore the bindings state to the one that merges the bindings state evaluating
1295-
// each of the expanded argument list.
1296-
if let Some(merged_evaluation_state) = merged_evaluation_state {
1297-
snapshotter.restore(self, merged_evaluation_state);
1298-
}
1299-
13001304
return;
13011305
}
13021306
}
@@ -1342,32 +1346,33 @@ impl<'db> CallableBinding<'db> {
13421346
self.overloads[*current_index].mark_as_unmatched_overload();
13431347
continue;
13441348
}
1345-
let mut unions = Vec::with_capacity(argument_types.len());
1349+
let mut parameter_types = Vec::with_capacity(argument_types.len());
13461350
for argument_index in 0..argument_types.len() {
1347-
let mut union = vec![];
1351+
// The parameter types at the current argument index.
1352+
let mut current_parameter_types = vec![];
13481353
for overload_index in &matching_overload_indexes[..=upto] {
13491354
let overload = &self.overloads[*overload_index];
13501355
let Some(parameter_index) = overload.argument_parameters[argument_index] else {
13511356
// There is no parameter for this argument in this overload.
13521357
continue;
13531358
};
1354-
// TODO: For an unannotated `self` parameter, the type should be `typing.Self`
1355-
// while for other unannotated parameters, the type should be `Unknown`
1356-
let ty = overload.signature.parameters()[parameter_index]
1359+
// TODO: For an unannotated `self` / `cls` parameter, the type should be
1360+
// `typing.Self` / `type[typing.Self]`
1361+
let parameter_type = overload.signature.parameters()[parameter_index]
13571362
.annotated_type()
13581363
.unwrap_or(Type::unknown());
1359-
union.push(ty);
1364+
current_parameter_types.push(parameter_type);
13601365
}
1361-
if union.is_empty() {
1366+
if current_parameter_types.is_empty() {
13621367
continue;
13631368
}
1364-
unions.push(UnionType::from_elements(db, union));
1369+
parameter_types.push(UnionType::from_elements(db, current_parameter_types));
13651370
}
1366-
if unions.len() != argument_types.len() {
1371+
if parameter_types.len() != argument_types.len() {
13671372
continue;
13681373
}
13691374
if top_materialized_argument_type
1370-
.is_assignable_to(db, TupleType::from_elements(db, unions))
1375+
.is_assignable_to(db, TupleType::from_elements(db, parameter_types))
13711376
{
13721377
filter_remaining_overloads = true;
13731378
}
@@ -1967,6 +1972,7 @@ impl<'db> Binding<'db> {
19671972
.map(|(arg_and_type, _)| arg_and_type)
19681973
}
19691974

1975+
/// Mark this overload binding as an unmatched overload.
19701976
fn mark_as_unmatched_overload(&mut self) {
19711977
self.errors.push(BindingError::UnmatchedOverload);
19721978
}
@@ -2031,23 +2037,27 @@ struct BindingSnapshot<'db> {
20312037
errors: Vec<BindingError<'db>>,
20322038
}
20332039

2034-
/// Represents the snapshot of the matched overload bindings.
2035-
///
2036-
/// The reason that this only contains the matched overloads are:
2037-
/// 1. Avoid creating snapshots for the overloads that have been filtered by the arity check
2038-
/// 2. Avoid duplicating errors when merging the snapshots on a successful evaluation of all the
2039-
/// expanded argument lists
20402040
#[derive(Clone, Debug)]
2041-
struct MatchingOverloadsSnapshot<'db>(Vec<(usize, BindingSnapshot<'db>)>);
2041+
struct CallableBindingSnapshot<'db> {
2042+
overload_return_type: Option<OverloadCallReturnType<'db>>,
20422043

2043-
impl<'db> MatchingOverloadsSnapshot<'db> {
2044+
/// Represents the snapshot of the matched overload bindings.
2045+
///
2046+
/// The reason that this only contains the matched overloads are:
2047+
/// 1. Avoid creating snapshots for the overloads that have been filtered by the arity check
2048+
/// 2. Avoid duplicating errors when merging the snapshots on a successful evaluation of all
2049+
/// the expanded argument lists
2050+
matching_overloads: Vec<(usize, BindingSnapshot<'db>)>,
2051+
}
2052+
2053+
impl<'db> CallableBindingSnapshot<'db> {
20442054
/// Update the state of the matched overload bindings in this snapshot with the current
20452055
/// state in the given `binding`.
20462056
fn update(&mut self, binding: &CallableBinding<'db>) {
20472057
// Here, the `snapshot` is the state of this binding for the previous argument list and
20482058
// `binding` would contain the state after evaluating the current argument list.
20492059
for (snapshot, binding) in self
2050-
.0
2060+
.matching_overloads
20512061
.iter_mut()
20522062
.map(|(index, snapshot)| (snapshot, &binding.overloads[*index]))
20532063
{
@@ -2083,37 +2093,40 @@ impl<'db> MatchingOverloadsSnapshot<'db> {
20832093

20842094
/// A helper to take snapshots of the matched overload bindings for the current state of the
20852095
/// bindings.
2086-
struct MatchingOverloadsSnapshotter(Vec<usize>);
2096+
struct CallableBindingSnapshotter(Vec<usize>);
20872097

2088-
impl MatchingOverloadsSnapshotter {
2098+
impl CallableBindingSnapshotter {
20892099
/// Creates a new snapshotter for the given indexes of the matched overloads.
20902100
fn new(indexes: Vec<usize>) -> Self {
20912101
debug_assert!(indexes.len() > 1);
2092-
MatchingOverloadsSnapshotter(indexes)
2102+
CallableBindingSnapshotter(indexes)
20932103
}
20942104

20952105
/// Takes a snapshot of the current state of the matched overload bindings.
20962106
///
20972107
/// # Panics
20982108
///
20992109
/// Panics if the indexes of the matched overloads are not valid for the given binding.
2100-
fn take<'db>(&self, binding: &CallableBinding<'db>) -> MatchingOverloadsSnapshot<'db> {
2101-
MatchingOverloadsSnapshot(
2102-
self.0
2110+
fn take<'db>(&self, binding: &CallableBinding<'db>) -> CallableBindingSnapshot<'db> {
2111+
CallableBindingSnapshot {
2112+
overload_return_type: binding.overload_call_return_type,
2113+
matching_overloads: self
2114+
.0
21032115
.iter()
21042116
.map(|index| (*index, binding.overloads[*index].snapshot()))
21052117
.collect(),
2106-
)
2118+
}
21072119
}
21082120

21092121
/// Restores the state of the matched overload bindings from the given snapshot.
21102122
fn restore<'db>(
21112123
&self,
21122124
binding: &mut CallableBinding<'db>,
2113-
snapshot: MatchingOverloadsSnapshot<'db>,
2125+
snapshot: CallableBindingSnapshot<'db>,
21142126
) {
2115-
debug_assert_eq!(self.0.len(), snapshot.0.len());
2116-
for (index, snapshot) in snapshot.0 {
2127+
debug_assert_eq!(self.0.len(), snapshot.matching_overloads.len());
2128+
binding.overload_call_return_type = snapshot.overload_return_type;
2129+
for (index, snapshot) in snapshot.matching_overloads {
21172130
binding.overloads[index].restore(snapshot);
21182131
}
21192132
}
@@ -2256,7 +2269,8 @@ pub(crate) enum BindingError<'db> {
22562269
/// We use this variant to report errors in `property.__get__` and `property.__set__`, which
22572270
/// can occur when the call to the underlying getter/setter fails.
22582271
InternalCallError(&'static str),
2259-
/// This overload of the callable does not match the arguments.
2272+
/// This overload binding of the callable does not match the arguments.
2273+
// TODO: We could expand this with an enum to specify why the overload is unmatched.
22602274
UnmatchedOverload,
22612275
}
22622276

0 commit comments

Comments
 (0)