@@ -2490,20 +2490,23 @@ struct ArgumentTypeChecker<'a, 'db> {
24902490 argument_matches : & ' a [ MatchedArgument < ' db > ] ,
24912491 parameter_tys : & ' a mut [ Option < Type < ' db > > ] ,
24922492 call_expression_tcx : & ' a TypeContext < ' db > ,
2493+ return_ty : Type < ' db > ,
24932494 errors : & ' a mut Vec < BindingError < ' db > > ,
24942495
24952496 inferable_typevars : InferableTypeVars < ' db , ' db > ,
24962497 specialization : Option < Specialization < ' db > > ,
24972498}
24982499
24992500impl < ' a , ' db > ArgumentTypeChecker < ' a , ' db > {
2501+ #[ expect( clippy:: too_many_arguments) ]
25002502 fn new (
25012503 db : & ' db dyn Db ,
25022504 signature : & ' a Signature < ' db > ,
25032505 arguments : & ' a CallArguments < ' a , ' db > ,
25042506 argument_matches : & ' a [ MatchedArgument < ' db > ] ,
25052507 parameter_tys : & ' a mut [ Option < Type < ' db > > ] ,
25062508 call_expression_tcx : & ' a TypeContext < ' db > ,
2509+ return_ty : Type < ' db > ,
25072510 errors : & ' a mut Vec < BindingError < ' db > > ,
25082511 ) -> Self {
25092512 Self {
@@ -2513,6 +2516,7 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
25132516 argument_matches,
25142517 parameter_tys,
25152518 call_expression_tcx,
2519+ return_ty,
25162520 errors,
25172521 inferable_typevars : InferableTypeVars :: None ,
25182522 specialization : None ,
@@ -2554,25 +2558,6 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
25542558 // TODO: Use the list of inferable typevars from the generic context of the callable.
25552559 let mut builder = SpecializationBuilder :: new ( self . db , self . inferable_typevars ) ;
25562560
2557- // Note that we infer the annotated type _before_ the arguments if this call is part of
2558- // an annotated assignment, to closer match the order of any unions written in the type
2559- // annotation.
2560- if let Some ( return_ty) = self . signature . return_ty
2561- && let Some ( call_expression_tcx) = self . call_expression_tcx . annotation
2562- {
2563- match call_expression_tcx {
2564- // A type variable is not a useful type-context for expression inference, and applying it
2565- // to the return type can lead to confusing unions in nested generic calls.
2566- Type :: TypeVar ( _) => { }
2567-
2568- _ => {
2569- // Ignore any specialization errors here, because the type context is only used as a hint
2570- // to infer a more assignable return type.
2571- let _ = builder. infer ( return_ty, call_expression_tcx) ;
2572- }
2573- }
2574- }
2575-
25762561 let parameters = self . signature . parameters ( ) ;
25772562 for ( argument_index, adjusted_argument_index, _, argument_type) in
25782563 self . enumerate_argument_types ( )
@@ -2597,7 +2582,41 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
25972582 }
25982583 }
25992584
2600- self . specialization = Some ( builder. build ( generic_context, * self . call_expression_tcx ) ) ;
2585+ // Build the specialization first without inferring the type context.
2586+ let isolated_specialization = builder. build ( generic_context, * self . call_expression_tcx ) ;
2587+ let isolated_return_ty = self
2588+ . return_ty
2589+ . apply_specialization ( self . db , isolated_specialization) ;
2590+
2591+ let mut try_infer_tcx = || {
2592+ let return_ty = self . signature . return_ty ?;
2593+ let call_expression_tcx = self . call_expression_tcx . annotation ?;
2594+
2595+ // A type variable is not a useful type-context for expression inference, and applying it
2596+ // to the return type can lead to confusing unions in nested generic calls.
2597+ if call_expression_tcx. is_type_var ( ) {
2598+ return None ;
2599+ }
2600+
2601+ // If the return type is already assignable to the annotated type, we can ignore the
2602+ // type context and prefer the narrower inferred type.
2603+ if isolated_return_ty. is_assignable_to ( self . db , call_expression_tcx) {
2604+ return None ;
2605+ }
2606+
2607+ // TODO: Ideally we would infer the annotated type _before_ the arguments if this call is part of an
2608+ // annotated assignment, to closer match the order of any unions written in the type annotation.
2609+ builder. infer ( return_ty, call_expression_tcx) . ok ( ) ?;
2610+
2611+ // Otherwise, build the specialization again after inferring the type context.
2612+ let specialization = builder. build ( generic_context, * self . call_expression_tcx ) ;
2613+ let return_ty = return_ty. apply_specialization ( self . db , specialization) ;
2614+
2615+ Some ( ( Some ( specialization) , return_ty) )
2616+ } ;
2617+
2618+ ( self . specialization , self . return_ty ) =
2619+ try_infer_tcx ( ) . unwrap_or ( ( Some ( isolated_specialization) , isolated_return_ty) ) ;
26012620 }
26022621
26032622 fn check_argument_type (
@@ -2792,8 +2811,14 @@ impl<'a, 'db> ArgumentTypeChecker<'a, 'db> {
27922811 }
27932812 }
27942813
2795- fn finish ( self ) -> ( InferableTypeVars < ' db , ' db > , Option < Specialization < ' db > > ) {
2796- ( self . inferable_typevars , self . specialization )
2814+ fn finish (
2815+ self ,
2816+ ) -> (
2817+ InferableTypeVars < ' db , ' db > ,
2818+ Option < Specialization < ' db > > ,
2819+ Type < ' db > ,
2820+ ) {
2821+ ( self . inferable_typevars , self . specialization , self . return_ty )
27972822 }
27982823}
27992824
@@ -2950,18 +2975,16 @@ impl<'db> Binding<'db> {
29502975 & self . argument_matches ,
29512976 & mut self . parameter_tys ,
29522977 call_expression_tcx,
2978+ self . return_ty ,
29532979 & mut self . errors ,
29542980 ) ;
29552981
29562982 // If this overload is generic, first see if we can infer a specialization of the function
29572983 // from the arguments that were passed in.
29582984 checker. infer_specialization ( ) ;
2959-
29602985 checker. check_argument_types ( ) ;
2961- ( self . inferable_typevars , self . specialization ) = checker. finish ( ) ;
2962- if let Some ( specialization) = self . specialization {
2963- self . return_ty = self . return_ty . apply_specialization ( db, specialization) ;
2964- }
2986+
2987+ ( self . inferable_typevars , self . specialization , self . return_ty ) = checker. finish ( ) ;
29652988 }
29662989
29672990 pub ( crate ) fn set_return_type ( & mut self , return_ty : Type < ' db > ) {
0 commit comments