@@ -56,15 +56,16 @@ pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
5656 } ;
5757
5858 let type_ref = & ret_type. ty ( ) ?;
59- let ty = ctx. sema . resolve_type ( type_ref) ?. as_adt ( ) ;
59+ let ty = ctx. sema . resolve_type ( type_ref) ?;
60+ let ty_adt = ty. as_adt ( ) ;
6061 let famous_defs = FamousDefs ( & ctx. sema , ctx. sema . scope ( type_ref. syntax ( ) ) ?. krate ( ) ) ;
6162
6263 for kind in WrapperKind :: ALL {
6364 let Some ( core_wrapper) = kind. core_type ( & famous_defs) else {
6465 continue ;
6566 } ;
6667
67- if matches ! ( ty , Some ( hir:: Adt :: Enum ( ret_type) ) if ret_type == core_wrapper) {
68+ if matches ! ( ty_adt , Some ( hir:: Adt :: Enum ( ret_type) ) if ret_type == core_wrapper) {
6869 // The return type is already wrapped
6970 cov_mark:: hit!( wrap_return_type_simple_return_type_already_wrapped) ;
7071 continue ;
@@ -78,10 +79,23 @@ pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
7879 |builder| {
7980 let mut editor = builder. make_editor ( & parent) ;
8081 let make = SyntaxFactory :: with_mappings ( ) ;
81- let alias = wrapper_alias ( ctx, & make, & core_wrapper, type_ref, kind. symbol ( ) ) ;
82- let new_return_ty = alias. unwrap_or_else ( || match kind {
83- WrapperKind :: Option => make. ty_option ( type_ref. clone ( ) ) ,
84- WrapperKind :: Result => make. ty_result ( type_ref. clone ( ) , make. ty_infer ( ) . into ( ) ) ,
82+ let alias = wrapper_alias ( ctx, & make, core_wrapper, type_ref, & ty, kind. symbol ( ) ) ;
83+ let ( ast_new_return_ty, semantic_new_return_ty) = alias. unwrap_or_else ( || {
84+ let ( ast_ty, ty_constructor) = match kind {
85+ WrapperKind :: Option => {
86+ ( make. ty_option ( type_ref. clone ( ) ) , famous_defs. core_option_Option ( ) )
87+ }
88+ WrapperKind :: Result => (
89+ make. ty_result ( type_ref. clone ( ) , make. ty_infer ( ) . into ( ) ) ,
90+ famous_defs. core_result_Result ( ) ,
91+ ) ,
92+ } ;
93+ let semantic_ty = ty_constructor
94+ . map ( |ty_constructor| {
95+ hir:: Adt :: from ( ty_constructor) . ty_with_args ( ctx. db ( ) , [ ty. clone ( ) ] )
96+ } )
97+ . unwrap_or_else ( || ty. clone ( ) ) ;
98+ ( ast_ty, semantic_ty)
8599 } ) ;
86100
87101 let mut exprs_to_wrap = Vec :: new ( ) ;
@@ -96,19 +110,30 @@ pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
96110 for_each_tail_expr ( & body_expr, tail_cb) ;
97111
98112 for ret_expr_arg in exprs_to_wrap {
113+ if let Some ( ty) = ctx. sema . type_of_expr ( & ret_expr_arg) {
114+ if ty. adjusted ( ) . could_unify_with ( ctx. db ( ) , & semantic_new_return_ty) {
115+ // The type is already correct, don't wrap it.
116+ // We deliberately don't use `could_unify_with_deeply()`, because as long as the outer
117+ // enum matches it's okay for us, as we don't trigger the assist if the return type
118+ // is already `Option`/`Result`, so mismatched exact type is more likely a mistake
119+ // than something intended.
120+ continue ;
121+ }
122+ }
123+
99124 let happy_wrapped = make. expr_call (
100125 make. expr_path ( make. ident_path ( kind. happy_ident ( ) ) ) ,
101126 make. arg_list ( iter:: once ( ret_expr_arg. clone ( ) ) ) ,
102127 ) ;
103128 editor. replace ( ret_expr_arg. syntax ( ) , happy_wrapped. syntax ( ) ) ;
104129 }
105130
106- editor. replace ( type_ref. syntax ( ) , new_return_ty . syntax ( ) ) ;
131+ editor. replace ( type_ref. syntax ( ) , ast_new_return_ty . syntax ( ) ) ;
107132
108133 if let WrapperKind :: Result = kind {
109134 // Add a placeholder snippet at the first generic argument that doesn't equal the return type.
110135 // This is normally the error type, but that may not be the case when we inserted a type alias.
111- let args = new_return_ty
136+ let args = ast_new_return_ty
112137 . path ( )
113138 . unwrap ( )
114139 . segment ( )
@@ -188,35 +213,36 @@ impl WrapperKind {
188213}
189214
190215// Try to find an wrapper type alias in the current scope (shadowing the default).
191- fn wrapper_alias (
192- ctx : & AssistContext < ' _ > ,
216+ fn wrapper_alias < ' db > (
217+ ctx : & AssistContext < ' db > ,
193218 make : & SyntaxFactory ,
194- core_wrapper : & hir:: Enum ,
195- ret_type : & ast:: Type ,
219+ core_wrapper : hir:: Enum ,
220+ ast_ret_type : & ast:: Type ,
221+ semantic_ret_type : & hir:: Type < ' db > ,
196222 wrapper : hir:: Symbol ,
197- ) -> Option < ast:: PathType > {
223+ ) -> Option < ( ast:: PathType , hir :: Type < ' db > ) > {
198224 let wrapper_path = hir:: ModPath :: from_segments (
199225 hir:: PathKind :: Plain ,
200226 iter:: once ( hir:: Name :: new_symbol_root ( wrapper) ) ,
201227 ) ;
202228
203- ctx. sema . resolve_mod_path ( ret_type . syntax ( ) , & wrapper_path) . and_then ( |def| {
229+ ctx. sema . resolve_mod_path ( ast_ret_type . syntax ( ) , & wrapper_path) . and_then ( |def| {
204230 def. filter_map ( |def| match def. into_module_def ( ) {
205231 hir:: ModuleDef :: TypeAlias ( alias) => {
206232 let enum_ty = alias. ty ( ctx. db ( ) ) . as_adt ( ) ?. as_enum ( ) ?;
207- ( & enum_ty == core_wrapper) . then_some ( alias)
233+ ( enum_ty == core_wrapper) . then_some ( ( alias, enum_ty ) )
208234 }
209235 _ => None ,
210236 } )
211- . find_map ( |alias| {
237+ . find_map ( |( alias, enum_ty ) | {
212238 let mut inserted_ret_type = false ;
213239 let generic_args =
214240 alias. source ( ctx. db ( ) ) ?. value . generic_param_list ( ) ?. generic_params ( ) . map ( |param| {
215241 match param {
216242 // Replace the very first type parameter with the function's return type.
217243 ast:: GenericParam :: TypeParam ( _) if !inserted_ret_type => {
218244 inserted_ret_type = true ;
219- make. type_arg ( ret_type . clone ( ) ) . into ( )
245+ make. type_arg ( ast_ret_type . clone ( ) ) . into ( )
220246 }
221247 ast:: GenericParam :: LifetimeParam ( _) => {
222248 make. lifetime_arg ( make. lifetime ( "'_" ) ) . into ( )
@@ -231,7 +257,10 @@ fn wrapper_alias(
231257 make. path_segment_generics ( make. name_ref ( name. as_str ( ) ) , generic_arg_list) ,
232258 ) ;
233259
234- Some ( make. ty_path ( path) )
260+ let new_ty =
261+ hir:: Adt :: from ( enum_ty) . ty_with_args ( ctx. db ( ) , [ semantic_ret_type. clone ( ) ] ) ;
262+
263+ Some ( ( make. ty_path ( path) , new_ty) )
235264 } )
236265 } )
237266}
@@ -605,29 +634,39 @@ fn foo() -> Option<i32> {
605634 check_assist_by_label (
606635 wrap_return_type,
607636 r#"
608- //- minicore: option
637+ //- minicore: option, future
638+ struct F(i32);
639+ impl core::future::Future for F {
640+ type Output = i32;
641+ fn poll(self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Self::Output> { 0 }
642+ }
609643async fn foo() -> i$032 {
610644 if true {
611645 if false {
612- 1 .await
646+ F(1) .await
613647 } else {
614- 2 .await
648+ F(2) .await
615649 }
616650 } else {
617- 24i32.await
651+ F( 24i32) .await
618652 }
619653}
620654"# ,
621655 r#"
656+ struct F(i32);
657+ impl core::future::Future for F {
658+ type Output = i32;
659+ fn poll(self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Self::Output> { 0 }
660+ }
622661async fn foo() -> Option<i32> {
623662 if true {
624663 if false {
625- Some(1 .await)
664+ Some(F(1) .await)
626665 } else {
627- Some(2 .await)
666+ Some(F(2) .await)
628667 }
629668 } else {
630- Some(24i32.await)
669+ Some(F( 24i32) .await)
631670 }
632671}
633672"# ,
@@ -1666,29 +1705,39 @@ fn foo() -> Result<i32, ${0:_}> {
16661705 check_assist_by_label (
16671706 wrap_return_type,
16681707 r#"
1669- //- minicore: result
1708+ //- minicore: result, future
1709+ struct F(i32);
1710+ impl core::future::Future for F {
1711+ type Output = i32;
1712+ fn poll(self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Self::Output> { 0 }
1713+ }
16701714async fn foo() -> i$032 {
16711715 if true {
16721716 if false {
1673- 1 .await
1717+ F(1) .await
16741718 } else {
1675- 2 .await
1719+ F(2) .await
16761720 }
16771721 } else {
1678- 24i32.await
1722+ F( 24i32) .await
16791723 }
16801724}
16811725"# ,
16821726 r#"
1727+ struct F(i32);
1728+ impl core::future::Future for F {
1729+ type Output = i32;
1730+ fn poll(self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Self::Output> { 0 }
1731+ }
16831732async fn foo() -> Result<i32, ${0:_}> {
16841733 if true {
16851734 if false {
1686- Ok(1 .await)
1735+ Ok(F(1) .await)
16871736 } else {
1688- Ok(2 .await)
1737+ Ok(F(2) .await)
16891738 }
16901739 } else {
1691- Ok(24i32.await)
1740+ Ok(F( 24i32) .await)
16921741 }
16931742}
16941743"# ,
@@ -2455,6 +2504,56 @@ type Result<T, const N: usize> = core::result::Result<Foo<T>, Bar<N>>;
24552504
24562505fn foo() -> Result<i32, ${0:_}> {
24572506 Ok(0)
2507+ }
2508+ "# ,
2509+ WrapperKind :: Result . label ( ) ,
2510+ ) ;
2511+ }
2512+
2513+ #[ test]
2514+ fn already_wrapped ( ) {
2515+ check_assist_by_label (
2516+ wrap_return_type,
2517+ r#"
2518+ //- minicore: option
2519+ fn foo() -> i32$0 {
2520+ if false {
2521+ 0
2522+ } else {
2523+ Some(1)
2524+ }
2525+ }
2526+ "# ,
2527+ r#"
2528+ fn foo() -> Option<i32> {
2529+ if false {
2530+ Some(0)
2531+ } else {
2532+ Some(1)
2533+ }
2534+ }
2535+ "# ,
2536+ WrapperKind :: Option . label ( ) ,
2537+ ) ;
2538+ check_assist_by_label (
2539+ wrap_return_type,
2540+ r#"
2541+ //- minicore: result
2542+ fn foo() -> i32$0 {
2543+ if false {
2544+ 0
2545+ } else {
2546+ Ok(1)
2547+ }
2548+ }
2549+ "# ,
2550+ r#"
2551+ fn foo() -> Result<i32, ${0:_}> {
2552+ if false {
2553+ Ok(0)
2554+ } else {
2555+ Ok(1)
2556+ }
24582557}
24592558 "# ,
24602559 WrapperKind :: Result . label ( ) ,
0 commit comments