Skip to content

Commit fc725c2

Browse files
committed
State merging
1 parent f9b13c7 commit fc725c2

File tree

1 file changed

+164
-59
lines changed
  • crates/ty_python_semantic/src/types/call

1 file changed

+164
-59
lines changed

crates/ty_python_semantic/src/types/call/bind.rs

Lines changed: 164 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ use crate::types::signatures::{Parameter, ParameterForm};
2323
use crate::types::{
2424
BoundMethodType, DataclassParams, DataclassTransformerParams, FunctionDecorators, FunctionType,
2525
KnownClass, KnownFunction, KnownInstanceType, MethodWrapperKind, PropertyInstanceType,
26-
SpecialFormType, TupleType, TypeMapping, UnionBuilder, UnionType, WrapperDescriptorKind,
27-
ide_support, todo_type,
26+
SpecialFormType, TupleType, TypeMapping, UnionType, WrapperDescriptorKind, ide_support,
27+
todo_type,
2828
};
2929
use ruff_db::diagnostic::{Annotation, Diagnostic, Severity, SubDiagnostic};
3030
use ruff_python_ast as ast;
@@ -1105,32 +1105,12 @@ impl<'db> CallableBinding<'db> {
11051105
}
11061106

11071107
fn check_types(&mut self, db: &'db dyn Db, argument_types: &CallArgumentTypes<'_, 'db>) {
1108-
/// Represents the snapshot of the overload bindings.
1109-
struct OverloadSnapshots<'db>(SmallVec<[BindingSnapshot<'db>; 1]>);
1110-
1111-
/// Takes a snapshot of the current state of the overload bindings.
1112-
fn snapshot<'db>(binding: &CallableBinding<'db>) -> OverloadSnapshots<'db> {
1113-
OverloadSnapshots(binding.overloads.iter().map(Binding::snapshot).collect())
1114-
}
1115-
1116-
/// Restores the state of the overload bindings from the given snapshots.
1117-
fn restore<'db>(binding: &mut CallableBinding<'db>, snapshots: OverloadSnapshots<'db>) {
1118-
debug_assert_eq!(binding.overloads.len(), snapshots.0.len());
1119-
binding
1120-
.overloads
1121-
.iter_mut()
1122-
.zip(snapshots.0)
1123-
.for_each(|(binding, snapshot)| {
1124-
binding.restore(snapshot);
1125-
});
1126-
}
1127-
11281108
// If this callable is a bound method, prepend the self instance onto the arguments list
11291109
// before checking.
11301110
let argument_types = argument_types.with_self(self.bound_type);
11311111

11321112
// Step 1: Check the result of the arity check which is done by `match_parameters`
1133-
match self.matching_overload_index() {
1113+
let matching_overload_indexes = match self.matching_overload_index() {
11341114
MatchingOverloadIndex::None => {
11351115
// If no candidate overloads remain from the arity check, we can stop here. We
11361116
// still perform type checking for non-overloaded function to provide better user
@@ -1142,19 +1122,27 @@ impl<'db> CallableBinding<'db> {
11421122
}
11431123
MatchingOverloadIndex::Single(index) => {
11441124
// If only one candidate overload remains, it is the winning match.
1125+
// TODO: Evaluate it as a regular (non-overloaded) call. This means that any
1126+
// diagnostics reported in this check should be reported directly instead of
1127+
// reporting it as `no-matching-overload`.
11451128
self.overloads[index].check_types(
11461129
db,
11471130
argument_types.as_ref(),
11481131
argument_types.types(),
11491132
);
11501133
return;
11511134
}
1152-
MatchingOverloadIndex::Multiple(_) => {
1135+
MatchingOverloadIndex::Multiple(indexes) => {
11531136
// If two or more candidate overloads remain, proceed to step 2.
1137+
indexes
11541138
}
1155-
}
1139+
};
11561140

1157-
let pre_evaluation_snapshot = snapshot(self);
1141+
let snapshotter = Snapshotter::new(matching_overload_indexes);
1142+
1143+
// State of the bindings _before_ evaluating (type checking) the matching overloads using
1144+
// the non-expanded argument types.
1145+
let pre_evaluation_snapshot = snapshotter.take(self);
11581146

11591147
// Step 2: Evaluate each remaining overload as a regular (non-overloaded) call to determine
11601148
// whether it is compatible with the supplied argument list.
@@ -1180,23 +1168,31 @@ impl<'db> CallableBinding<'db> {
11801168

11811169
// Step 3: Perform "argument type expansion". Reference:
11821170
// https://typing.python.org/en/latest/spec/overload.html#argument-type-expansion
1183-
11841171
let mut expansions = argument_types.expand(db).peekable();
11851172

11861173
if expansions.peek().is_none() {
1187-
// Return early if there are no argument types to expand. This is especially useful to
1188-
// avoid restoring the bindings state.
1174+
// Return early if there are no argument types to expand.
11891175
return;
11901176
}
11911177

1192-
restore(self, pre_evaluation_snapshot);
1178+
// State of the bindings _after_ evaluating (type checking) the matching overloads using
1179+
// the non-expanded argument types.
1180+
let post_evaluation_snapshot = snapshotter.take(self);
1181+
1182+
// Restore the bindings state to the one prior to the type checking step in preparation
1183+
// for evaluating the expanded argument lists.
1184+
snapshotter.restore(self, pre_evaluation_snapshot);
11931185

11941186
for expanded_argument_lists in expansions {
1195-
let mut all_evaluated_successfully = true;
1196-
let mut union_builder = UnionBuilder::new(db);
1187+
// This is the merged state of the bindings after evaluating all of the expanded
1188+
// argument lists. This will be the final state to restore the bindings to if all of
1189+
// the expanded argument lists evaluated successfully.
1190+
let mut merged_evaluation_state: Option<MatchingOverloadsSnapshot<'db>> = None;
1191+
1192+
let mut return_types = Vec::new();
11971193

11981194
for expanded_argument_types in &expanded_argument_lists {
1199-
let pre_evaluation_snapshot = snapshot(self);
1195+
let pre_evaluation_snapshot = snapshotter.take(self);
12001196

12011197
for (_, overload) in self.matching_overloads_mut() {
12021198
overload.check_types(db, argument_types.as_ref(), expanded_argument_types);
@@ -1210,30 +1206,45 @@ impl<'db> CallableBinding<'db> {
12101206
MatchingOverloadIndex::Multiple(index) => {
12111207
// TODO: Step 4 and Step 5 goes here... but for now we just use the return
12121208
// type of the first matched overload.
1213-
Some(self.overloads[index].return_type())
1209+
Some(self.overloads[index[0]].return_type())
12141210
}
12151211
};
12161212

1213+
if let Some(merged_evaluation_state) = merged_evaluation_state.as_mut() {
1214+
// Update the evaluation state with the state of the bindings after evaluating
1215+
// the current argument list.
1216+
merged_evaluation_state.update(db, self);
1217+
} else {
1218+
// Initialize the merged evaluation state with the state of the bindings after
1219+
// evaluating the _first_ argument list. It only contains the snapshots for the
1220+
// matching overloads from the pre-evaluation snapshot.
1221+
merged_evaluation_state = Some(snapshotter.take(self));
1222+
}
1223+
12171224
// Restore the bindings state before evaluating the next argument list.
1218-
restore(self, pre_evaluation_snapshot);
1225+
snapshotter.restore(self, pre_evaluation_snapshot);
12191226

12201227
if let Some(return_type) = return_type {
1221-
union_builder.add_in_place(return_type);
1228+
return_types.push(return_type);
12221229
} else {
12231230
// No need to check the remaining argument lists if the current argument list
12241231
// doesn't evaluate successfully. Move on to expanding the next argument type.
1225-
all_evaluated_successfully = false;
12261232
break;
12271233
}
12281234
}
12291235

1230-
if all_evaluated_successfully {
1236+
if return_types.len() == expanded_argument_lists.len() {
12311237
// If the number of return types is equal to the number of expanded argument lists,
12321238
// they all evaluated successfully. So, we need to combine their return types by
12331239
// union to determine the final return type.
1234-
//
1235-
// TODO: What should be the state of the bindings at this point?
1236-
self.return_type = Some(union_builder.build());
1240+
self.return_type = Some(UnionType::from_elements(db, return_types));
1241+
1242+
// Restore the bindings state to the one that merges the bindings state evaluating
1243+
// each of the expanded argument list.
1244+
if let Some(merged_evaluation_state) = merged_evaluation_state {
1245+
snapshotter.restore(self, merged_evaluation_state);
1246+
}
1247+
12371248
return;
12381249
}
12391250
}
@@ -1245,9 +1256,7 @@ impl<'db> CallableBinding<'db> {
12451256
//
12461257
// This will be skipped if there are no argument types to expand because of the early
12471258
// return.
1248-
for (_, overload) in self.matching_overloads_mut() {
1249-
overload.check_types(db, argument_types.as_ref(), argument_types.types());
1250-
}
1259+
snapshotter.restore(self, post_evaluation_snapshot);
12511260
}
12521261

12531262
fn as_result(&self) -> Result<(), CallErrorKind> {
@@ -1281,11 +1290,15 @@ impl<'db> CallableBinding<'db> {
12811290
let mut matching_overloads = self.matching_overloads();
12821291
match matching_overloads.next() {
12831292
None => MatchingOverloadIndex::None,
1284-
Some((index, _)) => {
1285-
if matching_overloads.next().is_some() {
1286-
MatchingOverloadIndex::Multiple(index)
1293+
Some((first, _)) => {
1294+
if let Some((second, _)) = matching_overloads.next() {
1295+
let mut indexes = vec![first, second];
1296+
for (index, _) in matching_overloads {
1297+
indexes.push(index);
1298+
}
1299+
MatchingOverloadIndex::Multiple(indexes)
12871300
} else {
1288-
MatchingOverloadIndex::Single(index)
1301+
MatchingOverloadIndex::Single(first)
12891302
}
12901303
}
12911304
}
@@ -1469,15 +1482,11 @@ enum MatchingOverloadIndex {
14691482
/// No matching overloads found.
14701483
None,
14711484

1472-
/// Exactly one matching overload found.
1473-
///
1474-
/// The index is the position of the matching overload.
1485+
/// Exactly one matching overload found at the given index.
14751486
Single(usize),
14761487

1477-
/// Multiple matching overloads found.
1478-
///
1479-
/// The index is the position of the first matching overload.
1480-
Multiple(usize),
1488+
/// Multiple matching overloads found at the given indexes.
1489+
Multiple(Vec<usize>),
14811490
}
14821491

14831492
/// Binding information for one of the overloads of a callable.
@@ -1829,7 +1838,7 @@ impl<'db> Binding<'db> {
18291838
inherited_specialization: self.inherited_specialization,
18301839
argument_parameters: self.argument_parameters.clone(),
18311840
parameter_tys: self.parameter_tys.clone(),
1832-
errors_position: self.errors.len(),
1841+
errors: self.errors.clone(),
18331842
}
18341843
}
18351844

@@ -1840,25 +1849,121 @@ impl<'db> Binding<'db> {
18401849
inherited_specialization,
18411850
argument_parameters,
18421851
parameter_tys,
1843-
errors_position,
1852+
errors,
18441853
} = snapshot;
18451854

18461855
self.return_ty = return_ty;
18471856
self.specialization = specialization;
18481857
self.inherited_specialization = inherited_specialization;
18491858
self.argument_parameters = argument_parameters;
18501859
self.parameter_tys = parameter_tys;
1851-
self.errors.truncate(errors_position);
1860+
self.errors = errors;
18521861
}
18531862
}
18541863

1864+
#[derive(Clone, Debug)]
18551865
struct BindingSnapshot<'db> {
18561866
return_ty: Type<'db>,
18571867
specialization: Option<Specialization<'db>>,
18581868
inherited_specialization: Option<Specialization<'db>>,
18591869
argument_parameters: Box<[Option<usize>]>,
18601870
parameter_tys: Box<[Option<Type<'db>>]>,
1861-
errors_position: usize,
1871+
errors: Vec<BindingError<'db>>,
1872+
}
1873+
1874+
/// Represents the snapshot of the matched overload bindings.
1875+
///
1876+
/// The reason that this only contains the matched overloads are:
1877+
/// 1. Avoid creating snapshots for the overloads that have been filtered by the arity check
1878+
/// 2. Avoid duplicating errors when merging the snapshots on a successful evaluation of all the
1879+
/// expanded argument lists
1880+
#[derive(Clone, Debug)]
1881+
struct MatchingOverloadsSnapshot<'db>(Vec<(usize, BindingSnapshot<'db>)>);
1882+
1883+
impl<'db> MatchingOverloadsSnapshot<'db> {
1884+
fn update(&mut self, db: &'db dyn Db, binding: &CallableBinding<'db>) {
1885+
fn combine_specializations<'db>(
1886+
db: &'db dyn Db,
1887+
s1: Option<Specialization<'db>>,
1888+
s2: Option<Specialization<'db>>,
1889+
) -> Option<Specialization<'db>> {
1890+
match (s1, s2) {
1891+
(None, None) => None,
1892+
(Some(s), None) | (None, Some(s)) => Some(s),
1893+
(Some(s1), Some(s2)) => Some(s1.combine(db, s2)),
1894+
}
1895+
}
1896+
1897+
for (snapshot, binding) in self
1898+
.0
1899+
.iter_mut()
1900+
.map(|(index, snapshot)| (snapshot, &binding.overloads[*index]))
1901+
{
1902+
snapshot.return_ty = binding.return_ty;
1903+
snapshot.specialization =
1904+
combine_specializations(db, snapshot.specialization, binding.specialization);
1905+
snapshot.inherited_specialization = combine_specializations(
1906+
db,
1907+
snapshot.inherited_specialization,
1908+
binding.inherited_specialization,
1909+
);
1910+
snapshot
1911+
.argument_parameters
1912+
.clone_from(&binding.argument_parameters);
1913+
snapshot.parameter_tys.clone_from(&binding.parameter_tys);
1914+
1915+
if binding.errors.is_empty() {
1916+
// If the binding has no errors, this means that the current argument list
1917+
// was evaluated successfully and this is the matched overload. Clear the
1918+
// errors from the snapshot of this overload to signal this change.
1919+
snapshot.errors.clear();
1920+
} else if !snapshot.errors.is_empty() {
1921+
// If the errors in the snapshot was empty, then this binding is the
1922+
// matched overload for a previously evaluated argument list.
1923+
//
1924+
// If it does have errors, we just extend it with the errors from
1925+
// evaluating the current argument list.
1926+
snapshot.errors.extend_from_slice(&binding.errors);
1927+
}
1928+
}
1929+
}
1930+
}
1931+
1932+
/// A helper to take snapshots of the matched overload bindings for the current state of the
1933+
/// bindings.
1934+
struct Snapshotter(Vec<usize>);
1935+
1936+
impl Snapshotter {
1937+
fn new(indexes: Vec<usize>) -> Self {
1938+
debug_assert!(indexes.len() > 1);
1939+
Snapshotter(indexes)
1940+
}
1941+
1942+
/// Takes a snapshot of the current state of the matched overload bindings.
1943+
///
1944+
/// # Panics
1945+
///
1946+
/// Panics if the indexes of the matched overloads are not valid for the given binding.
1947+
fn take<'db>(&self, binding: &CallableBinding<'db>) -> MatchingOverloadsSnapshot<'db> {
1948+
MatchingOverloadsSnapshot(
1949+
self.0
1950+
.iter()
1951+
.map(|index| (*index, binding.overloads[*index].snapshot()))
1952+
.collect(),
1953+
)
1954+
}
1955+
1956+
/// Restores the state of the matched overload bindings from the given snapshot.
1957+
fn restore<'db>(
1958+
&self,
1959+
binding: &mut CallableBinding<'db>,
1960+
snapshot: MatchingOverloadsSnapshot<'db>,
1961+
) {
1962+
debug_assert_eq!(self.0.len(), snapshot.0.len());
1963+
for (index, snapshot) in snapshot.0 {
1964+
binding.overloads[index].restore(snapshot);
1965+
}
1966+
}
18621967
}
18631968

18641969
/// Describes a callable for the purposes of diagnostics.

0 commit comments

Comments
 (0)