@@ -718,7 +718,9 @@ impl<'a, 'gcx, 'tcx> InferCtxt<'a, 'gcx, 'tcx> {
718
718
return ;
719
719
}
720
720
let expected_trait_ty = expected_trait_ref. self_ty ( ) ;
721
- let found_span = expected_trait_ty. ty_to_def_id ( ) . and_then ( |did| {
721
+
722
+ let found_did = expected_trait_ty. ty_to_def_id ( ) ;
723
+ let found_span = found_did. and_then ( |did| {
722
724
self . tcx . hir . span_if_local ( did)
723
725
} ) ;
724
726
@@ -727,23 +729,57 @@ impl<'a, 'gcx, 'tcx> InferCtxt<'a, 'gcx, 'tcx> {
727
729
ty:: TyTuple ( ref tys, _) => tys. len ( ) ,
728
730
_ => 1 ,
729
731
} ;
730
- let arg_ty_count =
732
+ let ( arg_tys , arg_ty_count) =
731
733
match actual_trait_ref. skip_binder ( ) . substs . type_at ( 1 ) . sty {
732
- ty:: TyTuple ( ref tys, _) => tys. len ( ) ,
733
- _ => 1 ,
734
+ ty:: TyTuple ( ref tys, _) =>
735
+ ( tys. iter ( ) . map ( |t| & t. sty ) . collect ( ) , tys. len ( ) ) ,
736
+ ref sty => ( vec ! [ sty] , 1 ) ,
734
737
} ;
735
738
if self_ty_count == arg_ty_count {
736
739
self . report_closure_arg_mismatch ( span,
737
740
found_span,
738
741
expected_trait_ref,
739
742
actual_trait_ref)
740
743
} else {
741
- // Expected `|| { }`, found `|x, y| { }`
742
- // Expected `fn(x) -> ()`, found `|| { }`
744
+ let arg_tuple = if arg_ty_count == 1 {
745
+ arg_tys. first ( ) . and_then ( |t| {
746
+ if let & & ty:: TyTuple ( ref tuptys, _) = t {
747
+ Some ( tuptys. len ( ) )
748
+ } else {
749
+ None
750
+ }
751
+ } )
752
+ } else {
753
+ None
754
+ } ;
755
+
756
+ // FIXME(#44150): Expand this to "N args expected bug a N-tuple found".
757
+ // Type of the 1st expected argument is somehow provided as type of a
758
+ // found one in that case.
759
+ //
760
+ // ```
761
+ // [1i32, 2, 3].sort_by(|(a, b)| ..)
762
+ // // ^^^^^^^^
763
+ // // actual_trait_ref: std::ops::FnMut<(&i32, &i32)>
764
+ // // expected_trait_ref: std::ops::FnMut<(&i32,)>
765
+ // ```
766
+
767
+ let closure_args_span = found_did. and_then ( |did| self . tcx . hir . get_if_local ( did) )
768
+ . and_then ( |node| {
769
+ if let hir:: map:: NodeExpr (
770
+ & hir:: Expr { node : hir:: ExprClosure ( _, _, _, span, _) , .. } ) = node
771
+ {
772
+ Some ( span)
773
+ } else {
774
+ None
775
+ }
776
+ } ) ;
777
+
743
778
self . report_arg_count_mismatch (
744
779
span,
745
- found_span,
780
+ closure_args_span . or ( found_span) ,
746
781
arg_ty_count,
782
+ arg_tuple,
747
783
self_ty_count,
748
784
expected_trait_ty. is_closure ( )
749
785
)
@@ -771,28 +807,42 @@ impl<'a, 'gcx, 'tcx> InferCtxt<'a, 'gcx, 'tcx> {
771
807
span : Span ,
772
808
found_span : Option < Span > ,
773
809
expected : usize ,
810
+ expected_tuple : Option < usize > ,
774
811
found : usize ,
775
812
is_closure : bool )
776
813
-> DiagnosticBuilder < ' tcx >
777
814
{
815
+ let kind = if is_closure { "closure" } else { "function" } ;
816
+
817
+ let tuple_or_args = |tuple, args| if let Some ( n) = tuple {
818
+ format ! ( "a {}-tuple" , n)
819
+ } else {
820
+ format ! (
821
+ "{} argument{}" ,
822
+ args,
823
+ if args == 1 { "" } else { "s" }
824
+ )
825
+ } ;
826
+
827
+ let found_str = tuple_or_args ( None , found) ;
828
+ let expected_str = tuple_or_args ( expected_tuple, expected) ;
829
+
778
830
let mut err = struct_span_err ! ( self . tcx. sess, span, E0593 ,
779
- "{} takes {} argument{} but {} argument{} {} required" ,
780
- if is_closure { "closure" } else { "function" } ,
781
- found,
782
- if found == 1 { "" } else { "s" } ,
783
- expected,
784
- if expected == 1 { "" } else { "s" } ,
785
- if expected == 1 { "is" } else { "are" } ) ;
786
-
787
- err. span_label ( span, format ! ( "expected {} that takes {} argument{}" ,
788
- if is_closure { "closure" } else { "function" } ,
789
- expected,
790
- if expected == 1 { "" } else { "s" } ) ) ;
831
+ "{} takes {} but {} {} required" ,
832
+ kind,
833
+ found_str,
834
+ expected_str,
835
+ if expected_tuple. is_some( ) || expected == 1 { "is" } else { "are" } ) ;
836
+
837
+ err. span_label (
838
+ span,
839
+ format ! ( "expected {} that takes {}" , kind, expected_str)
840
+ ) ;
841
+
791
842
if let Some ( span) = found_span {
792
- err. span_label ( span, format ! ( "takes {} argument{}" ,
793
- found,
794
- if found == 1 { "" } else { "s" } ) ) ;
843
+ err. span_label ( span, format ! ( "takes {}" , found_str) ) ;
795
844
}
845
+
796
846
err
797
847
}
798
848
0 commit comments