@@ -23,8 +23,8 @@ use crate::types::signatures::{Parameter, ParameterForm};
2323use crate :: types:: {
2424 BoundMethodType , DataclassParams , DataclassTransformerParams , FunctionDecorators , FunctionType ,
2525 KnownClass , KnownFunction , KnownInstanceType , MethodWrapperKind , PropertyInstanceType ,
26- SpecialFormType , TupleType , TypeMapping , UnionBuilder , UnionType , WrapperDescriptorKind ,
27- ide_support , todo_type,
26+ SpecialFormType , TupleType , TypeMapping , UnionType , WrapperDescriptorKind , ide_support ,
27+ todo_type,
2828} ;
2929use ruff_db:: diagnostic:: { Annotation , Diagnostic , Severity , SubDiagnostic } ;
3030use ruff_python_ast as ast;
@@ -1105,32 +1105,12 @@ impl<'db> CallableBinding<'db> {
11051105 }
11061106
11071107 fn check_types ( & mut self , db : & ' db dyn Db , argument_types : & CallArgumentTypes < ' _ , ' db > ) {
1108- /// Represents the snapshot of the overload bindings.
1109- struct OverloadSnapshots < ' db > ( SmallVec < [ BindingSnapshot < ' db > ; 1 ] > ) ;
1110-
1111- /// Takes a snapshot of the current state of the overload bindings.
1112- fn snapshot < ' db > ( binding : & CallableBinding < ' db > ) -> OverloadSnapshots < ' db > {
1113- OverloadSnapshots ( binding. overloads . iter ( ) . map ( Binding :: snapshot) . collect ( ) )
1114- }
1115-
1116- /// Restores the state of the overload bindings from the given snapshots.
1117- fn restore < ' db > ( binding : & mut CallableBinding < ' db > , snapshots : OverloadSnapshots < ' db > ) {
1118- debug_assert_eq ! ( binding. overloads. len( ) , snapshots. 0 . len( ) ) ;
1119- binding
1120- . overloads
1121- . iter_mut ( )
1122- . zip ( snapshots. 0 )
1123- . for_each ( |( binding, snapshot) | {
1124- binding. restore ( snapshot) ;
1125- } ) ;
1126- }
1127-
11281108 // If this callable is a bound method, prepend the self instance onto the arguments list
11291109 // before checking.
11301110 let argument_types = argument_types. with_self ( self . bound_type ) ;
11311111
11321112 // Step 1: Check the result of the arity check which is done by `match_parameters`
1133- match self . matching_overload_index ( ) {
1113+ let matching_overload_indexes = match self . matching_overload_index ( ) {
11341114 MatchingOverloadIndex :: None => {
11351115 // If no candidate overloads remain from the arity check, we can stop here. We
11361116 // still perform type checking for non-overloaded function to provide better user
@@ -1142,19 +1122,27 @@ impl<'db> CallableBinding<'db> {
11421122 }
11431123 MatchingOverloadIndex :: Single ( index) => {
11441124 // If only one candidate overload remains, it is the winning match.
1125+ // TODO: Evaluate it as a regular (non-overloaded) call. This means that any
1126+ // diagnostics reported in this check should be reported directly instead of
1127+ // reporting it as `no-matching-overload`.
11451128 self . overloads [ index] . check_types (
11461129 db,
11471130 argument_types. as_ref ( ) ,
11481131 argument_types. types ( ) ,
11491132 ) ;
11501133 return ;
11511134 }
1152- MatchingOverloadIndex :: Multiple ( _ ) => {
1135+ MatchingOverloadIndex :: Multiple ( indexes ) => {
11531136 // If two or more candidate overloads remain, proceed to step 2.
1137+ indexes
11541138 }
1155- }
1139+ } ;
11561140
1157- let pre_evaluation_snapshot = snapshot ( self ) ;
1141+ let snapshotter = Snapshotter :: new ( matching_overload_indexes) ;
1142+
1143+ // State of the bindings _before_ evaluating (type checking) the matching overloads using
1144+ // the non-expanded argument types.
1145+ let pre_evaluation_snapshot = snapshotter. take ( self ) ;
11581146
11591147 // Step 2: Evaluate each remaining overload as a regular (non-overloaded) call to determine
11601148 // whether it is compatible with the supplied argument list.
@@ -1180,23 +1168,31 @@ impl<'db> CallableBinding<'db> {
11801168
11811169 // Step 3: Perform "argument type expansion". Reference:
11821170 // https://typing.python.org/en/latest/spec/overload.html#argument-type-expansion
1183-
11841171 let mut expansions = argument_types. expand ( db) . peekable ( ) ;
11851172
11861173 if expansions. peek ( ) . is_none ( ) {
1187- // Return early if there are no argument types to expand. This is especially useful to
1188- // avoid restoring the bindings state.
1174+ // Return early if there are no argument types to expand.
11891175 return ;
11901176 }
11911177
1192- restore ( self , pre_evaluation_snapshot) ;
1178+ // State of the bindings _after_ evaluating (type checking) the matching overloads using
1179+ // the non-expanded argument types.
1180+ let post_evaluation_snapshot = snapshotter. take ( self ) ;
1181+
1182+ // Restore the bindings state to the one prior to the type checking step in preparation
1183+ // for evaluating the expanded argument lists.
1184+ snapshotter. restore ( self , pre_evaluation_snapshot) ;
11931185
11941186 for expanded_argument_lists in expansions {
1195- let mut all_evaluated_successfully = true ;
1196- let mut union_builder = UnionBuilder :: new ( db) ;
1187+ // This is the merged state of the bindings after evaluating all of the expanded
1188+ // argument lists. This will be the final state to restore the bindings to if all of
1189+ // the expanded argument lists evaluated successfully.
1190+ let mut merged_evaluation_state: Option < MatchingOverloadsSnapshot < ' db > > = None ;
1191+
1192+ let mut return_types = Vec :: new ( ) ;
11971193
11981194 for expanded_argument_types in & expanded_argument_lists {
1199- let pre_evaluation_snapshot = snapshot ( self ) ;
1195+ let pre_evaluation_snapshot = snapshotter . take ( self ) ;
12001196
12011197 for ( _, overload) in self . matching_overloads_mut ( ) {
12021198 overload. check_types ( db, argument_types. as_ref ( ) , expanded_argument_types) ;
@@ -1210,30 +1206,45 @@ impl<'db> CallableBinding<'db> {
12101206 MatchingOverloadIndex :: Multiple ( index) => {
12111207 // TODO: Step 4 and Step 5 goes here... but for now we just use the return
12121208 // type of the first matched overload.
1213- Some ( self . overloads [ index] . return_type ( ) )
1209+ Some ( self . overloads [ index[ 0 ] ] . return_type ( ) )
12141210 }
12151211 } ;
12161212
1213+ if let Some ( merged_evaluation_state) = merged_evaluation_state. as_mut ( ) {
1214+ // Update the evaluation state with the state of the bindings after evaluating
1215+ // the current argument list.
1216+ merged_evaluation_state. update ( db, self ) ;
1217+ } else {
1218+ // Initialize the merged evaluation state with the state of the bindings after
1219+ // evaluating the _first_ argument list. It only contains the snapshots for the
1220+ // matching overloads from the pre-evaluation snapshot.
1221+ merged_evaluation_state = Some ( snapshotter. take ( self ) ) ;
1222+ }
1223+
12171224 // Restore the bindings state before evaluating the next argument list.
1218- restore ( self , pre_evaluation_snapshot) ;
1225+ snapshotter . restore ( self , pre_evaluation_snapshot) ;
12191226
12201227 if let Some ( return_type) = return_type {
1221- union_builder . add_in_place ( return_type) ;
1228+ return_types . push ( return_type) ;
12221229 } else {
12231230 // No need to check the remaining argument lists if the current argument list
12241231 // doesn't evaluate successfully. Move on to expanding the next argument type.
1225- all_evaluated_successfully = false ;
12261232 break ;
12271233 }
12281234 }
12291235
1230- if all_evaluated_successfully {
1236+ if return_types . len ( ) == expanded_argument_lists . len ( ) {
12311237 // If the number of return types is equal to the number of expanded argument lists,
12321238 // they all evaluated successfully. So, we need to combine their return types by
12331239 // union to determine the final return type.
1234- //
1235- // TODO: What should be the state of the bindings at this point?
1236- self . return_type = Some ( union_builder. build ( ) ) ;
1240+ self . return_type = Some ( UnionType :: from_elements ( db, return_types) ) ;
1241+
1242+ // Restore the bindings state to the one that merges the bindings state evaluating
1243+ // each of the expanded argument list.
1244+ if let Some ( merged_evaluation_state) = merged_evaluation_state {
1245+ snapshotter. restore ( self , merged_evaluation_state) ;
1246+ }
1247+
12371248 return ;
12381249 }
12391250 }
@@ -1245,9 +1256,7 @@ impl<'db> CallableBinding<'db> {
12451256 //
12461257 // This will be skipped if there are no argument types to expand because of the early
12471258 // return.
1248- for ( _, overload) in self . matching_overloads_mut ( ) {
1249- overload. check_types ( db, argument_types. as_ref ( ) , argument_types. types ( ) ) ;
1250- }
1259+ snapshotter. restore ( self , post_evaluation_snapshot) ;
12511260 }
12521261
12531262 fn as_result ( & self ) -> Result < ( ) , CallErrorKind > {
@@ -1281,11 +1290,15 @@ impl<'db> CallableBinding<'db> {
12811290 let mut matching_overloads = self . matching_overloads ( ) ;
12821291 match matching_overloads. next ( ) {
12831292 None => MatchingOverloadIndex :: None ,
1284- Some ( ( index, _) ) => {
1285- if matching_overloads. next ( ) . is_some ( ) {
1286- MatchingOverloadIndex :: Multiple ( index)
1293+ Some ( ( first, _) ) => {
1294+ if let Some ( ( second, _) ) = matching_overloads. next ( ) {
1295+ let mut indexes = vec ! [ first, second] ;
1296+ for ( index, _) in matching_overloads {
1297+ indexes. push ( index) ;
1298+ }
1299+ MatchingOverloadIndex :: Multiple ( indexes)
12871300 } else {
1288- MatchingOverloadIndex :: Single ( index )
1301+ MatchingOverloadIndex :: Single ( first )
12891302 }
12901303 }
12911304 }
@@ -1469,15 +1482,11 @@ enum MatchingOverloadIndex {
14691482 /// No matching overloads found.
14701483 None ,
14711484
1472- /// Exactly one matching overload found.
1473- ///
1474- /// The index is the position of the matching overload.
1485+ /// Exactly one matching overload found at the given index.
14751486 Single ( usize ) ,
14761487
1477- /// Multiple matching overloads found.
1478- ///
1479- /// The index is the position of the first matching overload.
1480- Multiple ( usize ) ,
1488+ /// Multiple matching overloads found at the given indexes.
1489+ Multiple ( Vec < usize > ) ,
14811490}
14821491
14831492/// Binding information for one of the overloads of a callable.
@@ -1829,7 +1838,7 @@ impl<'db> Binding<'db> {
18291838 inherited_specialization : self . inherited_specialization ,
18301839 argument_parameters : self . argument_parameters . clone ( ) ,
18311840 parameter_tys : self . parameter_tys . clone ( ) ,
1832- errors_position : self . errors . len ( ) ,
1841+ errors : self . errors . clone ( ) ,
18331842 }
18341843 }
18351844
@@ -1840,25 +1849,121 @@ impl<'db> Binding<'db> {
18401849 inherited_specialization,
18411850 argument_parameters,
18421851 parameter_tys,
1843- errors_position ,
1852+ errors ,
18441853 } = snapshot;
18451854
18461855 self . return_ty = return_ty;
18471856 self . specialization = specialization;
18481857 self . inherited_specialization = inherited_specialization;
18491858 self . argument_parameters = argument_parameters;
18501859 self . parameter_tys = parameter_tys;
1851- self . errors . truncate ( errors_position ) ;
1860+ self . errors = errors ;
18521861 }
18531862}
18541863
1864+ #[ derive( Clone , Debug ) ]
18551865struct BindingSnapshot < ' db > {
18561866 return_ty : Type < ' db > ,
18571867 specialization : Option < Specialization < ' db > > ,
18581868 inherited_specialization : Option < Specialization < ' db > > ,
18591869 argument_parameters : Box < [ Option < usize > ] > ,
18601870 parameter_tys : Box < [ Option < Type < ' db > > ] > ,
1861- errors_position : usize ,
1871+ errors : Vec < BindingError < ' db > > ,
1872+ }
1873+
1874+ /// Represents the snapshot of the matched overload bindings.
1875+ ///
1876+ /// The reason that this only contains the matched overloads are:
1877+ /// 1. Avoid creating snapshots for the overloads that have been filtered by the arity check
1878+ /// 2. Avoid duplicating errors when merging the snapshots on a successful evaluation of all the
1879+ /// expanded argument lists
1880+ #[ derive( Clone , Debug ) ]
1881+ struct MatchingOverloadsSnapshot < ' db > ( Vec < ( usize , BindingSnapshot < ' db > ) > ) ;
1882+
1883+ impl < ' db > MatchingOverloadsSnapshot < ' db > {
1884+ fn update ( & mut self , db : & ' db dyn Db , binding : & CallableBinding < ' db > ) {
1885+ fn combine_specializations < ' db > (
1886+ db : & ' db dyn Db ,
1887+ s1 : Option < Specialization < ' db > > ,
1888+ s2 : Option < Specialization < ' db > > ,
1889+ ) -> Option < Specialization < ' db > > {
1890+ match ( s1, s2) {
1891+ ( None , None ) => None ,
1892+ ( Some ( s) , None ) | ( None , Some ( s) ) => Some ( s) ,
1893+ ( Some ( s1) , Some ( s2) ) => Some ( s1. combine ( db, s2) ) ,
1894+ }
1895+ }
1896+
1897+ for ( snapshot, binding) in self
1898+ . 0
1899+ . iter_mut ( )
1900+ . map ( |( index, snapshot) | ( snapshot, & binding. overloads [ * index] ) )
1901+ {
1902+ snapshot. return_ty = binding. return_ty ;
1903+ snapshot. specialization =
1904+ combine_specializations ( db, snapshot. specialization , binding. specialization ) ;
1905+ snapshot. inherited_specialization = combine_specializations (
1906+ db,
1907+ snapshot. inherited_specialization ,
1908+ binding. inherited_specialization ,
1909+ ) ;
1910+ snapshot
1911+ . argument_parameters
1912+ . clone_from ( & binding. argument_parameters ) ;
1913+ snapshot. parameter_tys . clone_from ( & binding. parameter_tys ) ;
1914+
1915+ if binding. errors . is_empty ( ) {
1916+ // If the binding has no errors, this means that the current argument list
1917+ // was evaluated successfully and this is the matched overload. Clear the
1918+ // errors from the snapshot of this overload to signal this change.
1919+ snapshot. errors . clear ( ) ;
1920+ } else if !snapshot. errors . is_empty ( ) {
1921+ // If the errors in the snapshot was empty, then this binding is the
1922+ // matched overload for a previously evaluated argument list.
1923+ //
1924+ // If it does have errors, we just extend it with the errors from
1925+ // evaluating the current argument list.
1926+ snapshot. errors . extend_from_slice ( & binding. errors ) ;
1927+ }
1928+ }
1929+ }
1930+ }
1931+
1932+ /// A helper to take snapshots of the matched overload bindings for the current state of the
1933+ /// bindings.
1934+ struct Snapshotter ( Vec < usize > ) ;
1935+
1936+ impl Snapshotter {
1937+ fn new ( indexes : Vec < usize > ) -> Self {
1938+ debug_assert ! ( indexes. len( ) > 1 ) ;
1939+ Snapshotter ( indexes)
1940+ }
1941+
1942+ /// Takes a snapshot of the current state of the matched overload bindings.
1943+ ///
1944+ /// # Panics
1945+ ///
1946+ /// Panics if the indexes of the matched overloads are not valid for the given binding.
1947+ fn take < ' db > ( & self , binding : & CallableBinding < ' db > ) -> MatchingOverloadsSnapshot < ' db > {
1948+ MatchingOverloadsSnapshot (
1949+ self . 0
1950+ . iter ( )
1951+ . map ( |index| ( * index, binding. overloads [ * index] . snapshot ( ) ) )
1952+ . collect ( ) ,
1953+ )
1954+ }
1955+
1956+ /// Restores the state of the matched overload bindings from the given snapshot.
1957+ fn restore < ' db > (
1958+ & self ,
1959+ binding : & mut CallableBinding < ' db > ,
1960+ snapshot : MatchingOverloadsSnapshot < ' db > ,
1961+ ) {
1962+ debug_assert_eq ! ( self . 0 . len( ) , snapshot. 0 . len( ) ) ;
1963+ for ( index, snapshot) in snapshot. 0 {
1964+ binding. overloads [ index] . restore ( snapshot) ;
1965+ }
1966+ }
18621967}
18631968
18641969/// Describes a callable for the purposes of diagnostics.
0 commit comments