@@ -2524,20 +2524,23 @@ struct ArgumentTypeChecker<'a, 'db> {
25242524 argument_matches : & ' a [ MatchedArgument < ' db > ] ,
25252525 parameter_tys : & ' a mut [ Option < Type < ' db > > ] ,
25262526 call_expression_tcx : & ' a TypeContext < ' db > ,
2527+ return_ty : Type < ' db > ,
25272528 errors : & ' a mut Vec < BindingError < ' db > > ,
25282529
25292530 inferable_typevars : InferableTypeVars < ' db , ' db > ,
25302531 specialization : Option < Specialization < ' db > > ,
25312532}
25322533
25332534impl < ' a , ' db > ArgumentTypeChecker < ' a , ' db > {
2535+ #[ expect( clippy:: too_many_arguments) ]
25342536 fn new (
25352537 db : & ' db dyn Db ,
25362538 signature : & ' a Signature < ' db > ,
25372539 arguments : & ' a CallArguments < ' a , ' db > ,
25382540 argument_matches : & ' a [ MatchedArgument < ' db > ] ,
25392541 parameter_tys : & ' a mut [ Option < Type < ' db > > ] ,
25402542 call_expression_tcx : & ' a TypeContext < ' db > ,
2543+ return_ty : Type < ' db > ,
25412544 errors : & ' a mut Vec < BindingError < ' db > > ,
25422545 ) -> Self {
25432546 Self {
@@ -2547,6 +2550,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
25472550 argument_matches,
25482551 parameter_tys,
25492552 call_expression_tcx,
2553+ return_ty,
25502554 errors,
25512555 inferable_typevars : InferableTypeVars :: None ,
25522556 specialization : None ,
@@ -2588,25 +2592,6 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
25882592 // TODO: Use the list of inferable typevars from the generic context of the callable.
25892593 let mut builder = SpecializationBuilder :: new ( self . db , self . inferable_typevars ) ;
25902594
2591- // Note that we infer the annotated type _before_ the arguments if this call is part of
2592- // an annotated assignment, to closer match the order of any unions written in the type
2593- // annotation.
2594- if let Some ( return_ty) = self . signature . return_ty
2595- && let Some ( call_expression_tcx) = self . call_expression_tcx . annotation
2596- {
2597- match call_expression_tcx {
2598- // A type variable is not a useful type-context for expression inference, and applying it
2599- // to the return type can lead to confusing unions in nested generic calls.
2600- Type :: TypeVar ( _) => { }
2601-
2602- _ => {
2603- // Ignore any specialization errors here, because the type context is only used as a hint
2604- // to infer a more assignable return type.
2605- let _ = builder. infer ( return_ty, call_expression_tcx) ;
2606- }
2607- }
2608- }
2609-
26102595 let parameters = self . signature . parameters ( ) ;
26112596 for ( argument_index, adjusted_argument_index, _, argument_type) in
26122597 self . enumerate_argument_types ( )
@@ -2631,7 +2616,41 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
26312616 }
26322617 }
26332618
2634- self . specialization = Some ( builder. build ( generic_context, * self . call_expression_tcx ) ) ;
2619+ // Build the specialization first without inferring the type context.
2620+ let isolated_specialization = builder. build ( generic_context, * self . call_expression_tcx ) ;
2621+ let isolated_return_ty = self
2622+ . return_ty
2623+ . apply_specialization ( self . db , isolated_specialization) ;
2624+
2625+ let mut try_infer_tcx = || {
2626+ let return_ty = self . signature . return_ty ?;
2627+ let call_expression_tcx = self . call_expression_tcx . annotation ?;
2628+
2629+ // A type variable is not a useful type-context for expression inference, and applying it
2630+ // to the return type can lead to confusing unions in nested generic calls.
2631+ if call_expression_tcx. is_type_var ( ) {
2632+ return None ;
2633+ }
2634+
2635+ // If the return type is already assignable to the annotated type, we can ignore the
2636+ // type context and prefer the narrower inferred type.
2637+ if isolated_return_ty. is_assignable_to ( self . db , call_expression_tcx) {
2638+ return None ;
2639+ }
2640+
2641+ // TODO: Ideally we would infer the annotated type _before_ the arguments if this call is part of an
2642+ // annotated assignment, to closer match the order of any unions written in the type annotation.
2643+ builder. infer ( return_ty, call_expression_tcx) . ok ( ) ?;
2644+
2645+ // Otherwise, build the specialization again after inferring the type context.
2646+ let specialization = builder. build ( generic_context, * self . call_expression_tcx ) ;
2647+ let return_ty = return_ty. apply_specialization ( self . db , specialization) ;
2648+
2649+ Some ( ( Some ( specialization) , return_ty) )
2650+ } ;
2651+
2652+ ( self . specialization , self . return_ty ) =
2653+ try_infer_tcx ( ) . unwrap_or ( ( Some ( isolated_specialization) , isolated_return_ty) ) ;
26352654 }
26362655
26372656 fn check_argument_type (
@@ -2826,8 +2845,14 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
28262845 }
28272846 }
28282847
2829- fn finish ( self ) -> ( InferableTypeVars < ' db , ' db > , Option < Specialization < ' db > > ) {
2830- ( self . inferable_typevars , self . specialization )
2848+ fn finish (
2849+ self ,
2850+ ) -> (
2851+ InferableTypeVars < ' db , ' db > ,
2852+ Option < Specialization < ' db > > ,
2853+ Type < ' db > ,
2854+ ) {
2855+ ( self . inferable_typevars , self . specialization , self . return_ty )
28312856 }
28322857}
28332858
@@ -2985,18 +3010,16 @@ impl<'db> Binding<'db> {
29853010 & self . argument_matches ,
29863011 & mut self . parameter_tys ,
29873012 call_expression_tcx,
3013+ self . return_ty ,
29883014 & mut self . errors ,
29893015 ) ;
29903016
29913017 // If this overload is generic, first see if we can infer a specialization of the function
29923018 // from the arguments that were passed in.
29933019 checker. infer_specialization ( ) ;
2994-
29953020 checker. check_argument_types ( ) ;
2996- ( self . inferable_typevars , self . specialization ) = checker. finish ( ) ;
2997- if let Some ( specialization) = self . specialization {
2998- self . return_ty = self . return_ty . apply_specialization ( db, specialization) ;
2999- }
3021+
3022+ ( self . inferable_typevars , self . specialization , self . return_ty ) = checker. finish ( ) ;
30003023 }
30013024
30023025 pub ( crate ) fn set_return_type ( & mut self , return_ty : Type < ' db > ) {
0 commit comments