| 
1 | 1 | use crate::FnCtxt;  | 
2 | 2 | use rustc_hir as hir;  | 
3 | 3 | use rustc_hir::def::Res;  | 
4 |  | -use rustc_middle::ty::{self, DefIdTree, Ty};  | 
 | 4 | +use rustc_hir::def_id::DefId;  | 
 | 5 | +use rustc_infer::traits::ObligationCauseCode;  | 
 | 6 | +use rustc_middle::ty::{self, DefIdTree, Ty, TypeSuperVisitable, TypeVisitable, TypeVisitor};  | 
 | 7 | +use rustc_span::{self, Span};  | 
5 | 8 | use rustc_trait_selection::traits;  | 
6 | 9 | 
 
  | 
 | 10 | +use std::ops::ControlFlow;  | 
 | 11 | + | 
7 | 12 | impl<'a, 'tcx> FnCtxt<'a, 'tcx> {  | 
 | 13 | +    pub fn adjust_fulfillment_error_for_expr_obligation(  | 
 | 14 | +        &self,  | 
 | 15 | +        error: &mut traits::FulfillmentError<'tcx>,  | 
 | 16 | +    ) -> bool {  | 
 | 17 | +        let (traits::ExprItemObligation(def_id, hir_id, idx) | traits::ExprBindingObligation(def_id, _, hir_id, idx))  | 
 | 18 | +            = *error.obligation.cause.code().peel_derives() else { return false; };  | 
 | 19 | +        let hir = self.tcx.hir();  | 
 | 20 | +        let hir::Node::Expr(expr) = hir.get(hir_id) else { return false; };  | 
 | 21 | + | 
 | 22 | +        let Some(unsubstituted_pred) =  | 
 | 23 | +            self.tcx.predicates_of(def_id).instantiate_identity(self.tcx).predicates.into_iter().nth(idx)  | 
 | 24 | +            else { return false; };  | 
 | 25 | + | 
 | 26 | +        let generics = self.tcx.generics_of(def_id);  | 
 | 27 | +        let predicate_substs = match unsubstituted_pred.kind().skip_binder() {  | 
 | 28 | +            ty::PredicateKind::Clause(ty::Clause::Trait(pred)) => pred.trait_ref.substs,  | 
 | 29 | +            ty::PredicateKind::Clause(ty::Clause::Projection(pred)) => pred.projection_ty.substs,  | 
 | 30 | +            _ => ty::List::empty(),  | 
 | 31 | +        };  | 
 | 32 | + | 
 | 33 | +        let find_param_matching = |matches: &dyn Fn(&ty::ParamTy) -> bool| {  | 
 | 34 | +            predicate_substs.types().find_map(|ty| {  | 
 | 35 | +                ty.walk().find_map(|arg| {  | 
 | 36 | +                    if let ty::GenericArgKind::Type(ty) = arg.unpack()  | 
 | 37 | +                        && let ty::Param(param_ty) = ty.kind()  | 
 | 38 | +                        && matches(param_ty)  | 
 | 39 | +                    {  | 
 | 40 | +                        Some(arg)  | 
 | 41 | +                    } else {  | 
 | 42 | +                        None  | 
 | 43 | +                    }  | 
 | 44 | +                })  | 
 | 45 | +            })  | 
 | 46 | +        };  | 
 | 47 | + | 
 | 48 | +        // Prefer generics that are local to the fn item, since these are likely  | 
 | 49 | +        // to be the cause of the unsatisfied predicate.  | 
 | 50 | +        let mut param_to_point_at = find_param_matching(&|param_ty| {  | 
 | 51 | +            self.tcx.parent(generics.type_param(param_ty, self.tcx).def_id) == def_id  | 
 | 52 | +        });  | 
 | 53 | +        // Fall back to generic that isn't local to the fn item. This will come  | 
 | 54 | +        // from a trait or impl, for example.  | 
 | 55 | +        let mut fallback_param_to_point_at = find_param_matching(&|param_ty| {  | 
 | 56 | +            self.tcx.parent(generics.type_param(param_ty, self.tcx).def_id) != def_id  | 
 | 57 | +                && param_ty.name != rustc_span::symbol::kw::SelfUpper  | 
 | 58 | +        });  | 
 | 59 | +        // Finally, the `Self` parameter is possibly the reason that the predicate  | 
 | 60 | +        // is unsatisfied. This is less likely to be true for methods, because  | 
 | 61 | +        // method probe means that we already kinda check that the predicates due  | 
 | 62 | +        // to the `Self` type are true.  | 
 | 63 | +        let mut self_param_to_point_at =  | 
 | 64 | +            find_param_matching(&|param_ty| param_ty.name == rustc_span::symbol::kw::SelfUpper);  | 
 | 65 | + | 
 | 66 | +        // Finally, for ambiguity-related errors, we actually want to look  | 
 | 67 | +        // for a parameter that is the source of the inference type left  | 
 | 68 | +        // over in this predicate.  | 
 | 69 | +        if let traits::FulfillmentErrorCode::CodeAmbiguity = error.code {  | 
 | 70 | +            fallback_param_to_point_at = None;  | 
 | 71 | +            self_param_to_point_at = None;  | 
 | 72 | +            param_to_point_at =  | 
 | 73 | +                self.find_ambiguous_parameter_in(def_id, error.root_obligation.predicate);  | 
 | 74 | +        }  | 
 | 75 | + | 
 | 76 | +        if self.closure_span_overlaps_error(error, expr.span) {  | 
 | 77 | +            return false;  | 
 | 78 | +        }  | 
 | 79 | + | 
 | 80 | +        match &expr.kind {  | 
 | 81 | +            hir::ExprKind::Path(qpath) => {  | 
 | 82 | +                if let hir::Node::Expr(hir::Expr {  | 
 | 83 | +                    kind: hir::ExprKind::Call(callee, args),  | 
 | 84 | +                    hir_id: call_hir_id,  | 
 | 85 | +                    span: call_span,  | 
 | 86 | +                    ..  | 
 | 87 | +                }) = hir.get_parent(expr.hir_id)  | 
 | 88 | +                    && callee.hir_id == expr.hir_id  | 
 | 89 | +                {  | 
 | 90 | +                    if self.closure_span_overlaps_error(error, *call_span) {  | 
 | 91 | +                        return false;  | 
 | 92 | +                    }  | 
 | 93 | + | 
 | 94 | +                    for param in  | 
 | 95 | +                        [param_to_point_at, fallback_param_to_point_at, self_param_to_point_at]  | 
 | 96 | +                        .into_iter()  | 
 | 97 | +                        .flatten()  | 
 | 98 | +                    {  | 
 | 99 | +                        if self.blame_specific_arg_if_possible(  | 
 | 100 | +                                error,  | 
 | 101 | +                                def_id,  | 
 | 102 | +                                param,  | 
 | 103 | +                                *call_hir_id,  | 
 | 104 | +                                callee.span,  | 
 | 105 | +                                None,  | 
 | 106 | +                                args,  | 
 | 107 | +                            )  | 
 | 108 | +                        {  | 
 | 109 | +                            return true;  | 
 | 110 | +                        }  | 
 | 111 | +                    }  | 
 | 112 | +                }  | 
 | 113 | +                // Notably, we only point to params that are local to the  | 
 | 114 | +                // item we're checking, since those are the ones we are able  | 
 | 115 | +                // to look in the final `hir::PathSegment` for. Everything else  | 
 | 116 | +                // would require a deeper search into the `qpath` than I think  | 
 | 117 | +                // is worthwhile.  | 
 | 118 | +                if let Some(param_to_point_at) = param_to_point_at  | 
 | 119 | +                    && self.point_at_path_if_possible(error, def_id, param_to_point_at, qpath)  | 
 | 120 | +                {  | 
 | 121 | +                    return true;  | 
 | 122 | +                }  | 
 | 123 | +            }  | 
 | 124 | +            hir::ExprKind::MethodCall(segment, receiver, args, ..) => {  | 
 | 125 | +                for param in [param_to_point_at, fallback_param_to_point_at, self_param_to_point_at]  | 
 | 126 | +                    .into_iter()  | 
 | 127 | +                    .flatten()  | 
 | 128 | +                {  | 
 | 129 | +                    if self.blame_specific_arg_if_possible(  | 
 | 130 | +                        error,  | 
 | 131 | +                        def_id,  | 
 | 132 | +                        param,  | 
 | 133 | +                        hir_id,  | 
 | 134 | +                        segment.ident.span,  | 
 | 135 | +                        Some(receiver),  | 
 | 136 | +                        args,  | 
 | 137 | +                    ) {  | 
 | 138 | +                        return true;  | 
 | 139 | +                    }  | 
 | 140 | +                }  | 
 | 141 | +                if let Some(param_to_point_at) = param_to_point_at  | 
 | 142 | +                    && self.point_at_generic_if_possible(error, def_id, param_to_point_at, segment)  | 
 | 143 | +                {  | 
 | 144 | +                    return true;  | 
 | 145 | +                }  | 
 | 146 | +            }  | 
 | 147 | +            hir::ExprKind::Struct(qpath, fields, ..) => {  | 
 | 148 | +                if let Res::Def(  | 
 | 149 | +                    hir::def::DefKind::Struct | hir::def::DefKind::Variant,  | 
 | 150 | +                    variant_def_id,  | 
 | 151 | +                ) = self.typeck_results.borrow().qpath_res(qpath, hir_id)  | 
 | 152 | +                {  | 
 | 153 | +                    for param in  | 
 | 154 | +                        [param_to_point_at, fallback_param_to_point_at, self_param_to_point_at]  | 
 | 155 | +                    {  | 
 | 156 | +                        if let Some(param) = param {  | 
 | 157 | +                            let refined_expr = self.point_at_field_if_possible(  | 
 | 158 | +                                def_id,  | 
 | 159 | +                                param,  | 
 | 160 | +                                variant_def_id,  | 
 | 161 | +                                fields,  | 
 | 162 | +                            );  | 
 | 163 | + | 
 | 164 | +                            match refined_expr {  | 
 | 165 | +                                None => {}  | 
 | 166 | +                                Some((refined_expr, _)) => {  | 
 | 167 | +                                    error.obligation.cause.span = refined_expr  | 
 | 168 | +                                        .span  | 
 | 169 | +                                        .find_ancestor_in_same_ctxt(error.obligation.cause.span)  | 
 | 170 | +                                        .unwrap_or(refined_expr.span);  | 
 | 171 | +                                    return true;  | 
 | 172 | +                                }  | 
 | 173 | +                            }  | 
 | 174 | +                        }  | 
 | 175 | +                    }  | 
 | 176 | +                }  | 
 | 177 | +                if let Some(param_to_point_at) = param_to_point_at  | 
 | 178 | +                    && self.point_at_path_if_possible(error, def_id, param_to_point_at, qpath)  | 
 | 179 | +                {  | 
 | 180 | +                    return true;  | 
 | 181 | +                }  | 
 | 182 | +            }  | 
 | 183 | +            _ => {}  | 
 | 184 | +        }  | 
 | 185 | + | 
 | 186 | +        false  | 
 | 187 | +    }  | 
 | 188 | + | 
 | 189 | +    fn point_at_path_if_possible(  | 
 | 190 | +        &self,  | 
 | 191 | +        error: &mut traits::FulfillmentError<'tcx>,  | 
 | 192 | +        def_id: DefId,  | 
 | 193 | +        param: ty::GenericArg<'tcx>,  | 
 | 194 | +        qpath: &hir::QPath<'tcx>,  | 
 | 195 | +    ) -> bool {  | 
 | 196 | +        match qpath {  | 
 | 197 | +            hir::QPath::Resolved(_, path) => {  | 
 | 198 | +                if let Some(segment) = path.segments.last()  | 
 | 199 | +                    && self.point_at_generic_if_possible(error, def_id, param, segment)  | 
 | 200 | +                {  | 
 | 201 | +                    return true;  | 
 | 202 | +                }  | 
 | 203 | +            }  | 
 | 204 | +            hir::QPath::TypeRelative(_, segment) => {  | 
 | 205 | +                if self.point_at_generic_if_possible(error, def_id, param, segment) {  | 
 | 206 | +                    return true;  | 
 | 207 | +                }  | 
 | 208 | +            }  | 
 | 209 | +            _ => {}  | 
 | 210 | +        }  | 
 | 211 | + | 
 | 212 | +        false  | 
 | 213 | +    }  | 
 | 214 | + | 
 | 215 | +    fn point_at_generic_if_possible(  | 
 | 216 | +        &self,  | 
 | 217 | +        error: &mut traits::FulfillmentError<'tcx>,  | 
 | 218 | +        def_id: DefId,  | 
 | 219 | +        param_to_point_at: ty::GenericArg<'tcx>,  | 
 | 220 | +        segment: &hir::PathSegment<'tcx>,  | 
 | 221 | +    ) -> bool {  | 
 | 222 | +        let own_substs = self  | 
 | 223 | +            .tcx  | 
 | 224 | +            .generics_of(def_id)  | 
 | 225 | +            .own_substs(ty::InternalSubsts::identity_for_item(self.tcx, def_id));  | 
 | 226 | +        let Some((index, _)) = own_substs  | 
 | 227 | +            .iter()  | 
 | 228 | +            .filter(|arg| matches!(arg.unpack(), ty::GenericArgKind::Type(_)))  | 
 | 229 | +            .enumerate()  | 
 | 230 | +            .find(|(_, arg)| **arg == param_to_point_at) else { return false };  | 
 | 231 | +        let Some(arg) = segment  | 
 | 232 | +            .args()  | 
 | 233 | +            .args  | 
 | 234 | +            .iter()  | 
 | 235 | +            .filter(|arg| matches!(arg, hir::GenericArg::Type(_)))  | 
 | 236 | +            .nth(index) else { return false; };  | 
 | 237 | +        error.obligation.cause.span = arg  | 
 | 238 | +            .span()  | 
 | 239 | +            .find_ancestor_in_same_ctxt(error.obligation.cause.span)  | 
 | 240 | +            .unwrap_or(arg.span());  | 
 | 241 | +        true  | 
 | 242 | +    }  | 
 | 243 | + | 
 | 244 | +    fn find_ambiguous_parameter_in<T: TypeVisitable<'tcx>>(  | 
 | 245 | +        &self,  | 
 | 246 | +        item_def_id: DefId,  | 
 | 247 | +        t: T,  | 
 | 248 | +    ) -> Option<ty::GenericArg<'tcx>> {  | 
 | 249 | +        struct FindAmbiguousParameter<'a, 'tcx>(&'a FnCtxt<'a, 'tcx>, DefId);  | 
 | 250 | +        impl<'tcx> TypeVisitor<'tcx> for FindAmbiguousParameter<'_, 'tcx> {  | 
 | 251 | +            type BreakTy = ty::GenericArg<'tcx>;  | 
 | 252 | +            fn visit_ty(&mut self, ty: Ty<'tcx>) -> std::ops::ControlFlow<Self::BreakTy> {  | 
 | 253 | +                if let Some(origin) = self.0.type_var_origin(ty)  | 
 | 254 | +                    && let rustc_infer::infer::type_variable::TypeVariableOriginKind::TypeParameterDefinition(_, Some(def_id)) =  | 
 | 255 | +                        origin.kind  | 
 | 256 | +                    && let generics = self.0.tcx.generics_of(self.1)  | 
 | 257 | +                    && let Some(index) = generics.param_def_id_to_index(self.0.tcx, def_id)  | 
 | 258 | +                    && let Some(subst) = ty::InternalSubsts::identity_for_item(self.0.tcx, self.1)  | 
 | 259 | +                        .get(index as usize)  | 
 | 260 | +                {  | 
 | 261 | +                    ControlFlow::Break(*subst)  | 
 | 262 | +                } else {  | 
 | 263 | +                    ty.super_visit_with(self)  | 
 | 264 | +                }  | 
 | 265 | +            }  | 
 | 266 | +        }  | 
 | 267 | +        t.visit_with(&mut FindAmbiguousParameter(self, item_def_id)).break_value()  | 
 | 268 | +    }  | 
 | 269 | + | 
 | 270 | +    fn closure_span_overlaps_error(  | 
 | 271 | +        &self,  | 
 | 272 | +        error: &traits::FulfillmentError<'tcx>,  | 
 | 273 | +        span: Span,  | 
 | 274 | +    ) -> bool {  | 
 | 275 | +        if let traits::FulfillmentErrorCode::CodeSelectionError(  | 
 | 276 | +            traits::SelectionError::OutputTypeParameterMismatch(_, expected, _),  | 
 | 277 | +        ) = error.code  | 
 | 278 | +            && let ty::Closure(def_id, _) | ty::Generator(def_id, ..) = expected.skip_binder().self_ty().kind()  | 
 | 279 | +            && span.overlaps(self.tcx.def_span(*def_id))  | 
 | 280 | +        {  | 
 | 281 | +            true  | 
 | 282 | +        } else {  | 
 | 283 | +            false  | 
 | 284 | +        }  | 
 | 285 | +    }  | 
 | 286 | + | 
 | 287 | +    fn point_at_field_if_possible(  | 
 | 288 | +        &self,  | 
 | 289 | +        def_id: DefId,  | 
 | 290 | +        param_to_point_at: ty::GenericArg<'tcx>,  | 
 | 291 | +        variant_def_id: DefId,  | 
 | 292 | +        expr_fields: &[hir::ExprField<'tcx>],  | 
 | 293 | +    ) -> Option<(&'tcx hir::Expr<'tcx>, Ty<'tcx>)> {  | 
 | 294 | +        let def = self.tcx.adt_def(def_id);  | 
 | 295 | + | 
 | 296 | +        let identity_substs = ty::InternalSubsts::identity_for_item(self.tcx, def_id);  | 
 | 297 | +        let fields_referencing_param: Vec<_> = def  | 
 | 298 | +            .variant_with_id(variant_def_id)  | 
 | 299 | +            .fields  | 
 | 300 | +            .iter()  | 
 | 301 | +            .filter(|field| {  | 
 | 302 | +                let field_ty = field.ty(self.tcx, identity_substs);  | 
 | 303 | +                Self::find_param_in_ty(field_ty.into(), param_to_point_at)  | 
 | 304 | +            })  | 
 | 305 | +            .collect();  | 
 | 306 | + | 
 | 307 | +        if let [field] = fields_referencing_param.as_slice() {  | 
 | 308 | +            for expr_field in expr_fields {  | 
 | 309 | +                // Look for the ExprField that matches the field, using the  | 
 | 310 | +                // same rules that check_expr_struct uses for macro hygiene.  | 
 | 311 | +                if self.tcx.adjust_ident(expr_field.ident, variant_def_id) == field.ident(self.tcx)  | 
 | 312 | +                {  | 
 | 313 | +                    return Some((expr_field.expr, self.tcx.type_of(field.did)));  | 
 | 314 | +                }  | 
 | 315 | +            }  | 
 | 316 | +        }  | 
 | 317 | + | 
 | 318 | +        None  | 
 | 319 | +    }  | 
 | 320 | + | 
 | 321 | +    /// - `blame_specific_*` means that the function will recursively traverse the expression,  | 
 | 322 | +    /// looking for the most-specific-possible span to blame.  | 
 | 323 | +    ///  | 
 | 324 | +    /// - `point_at_*` means that the function will only go "one level", pointing at the specific  | 
 | 325 | +    /// expression mentioned.  | 
 | 326 | +    ///  | 
 | 327 | +    /// `blame_specific_arg_if_possible` will find the most-specific expression anywhere inside  | 
 | 328 | +    /// the provided function call expression, and mark it as responsible for the fullfillment  | 
 | 329 | +    /// error.  | 
 | 330 | +    fn blame_specific_arg_if_possible(  | 
 | 331 | +        &self,  | 
 | 332 | +        error: &mut traits::FulfillmentError<'tcx>,  | 
 | 333 | +        def_id: DefId,  | 
 | 334 | +        param_to_point_at: ty::GenericArg<'tcx>,  | 
 | 335 | +        call_hir_id: hir::HirId,  | 
 | 336 | +        callee_span: Span,  | 
 | 337 | +        receiver: Option<&'tcx hir::Expr<'tcx>>,  | 
 | 338 | +        args: &'tcx [hir::Expr<'tcx>],  | 
 | 339 | +    ) -> bool {  | 
 | 340 | +        let ty = self.tcx.type_of(def_id);  | 
 | 341 | +        if !ty.is_fn() {  | 
 | 342 | +            return false;  | 
 | 343 | +        }  | 
 | 344 | +        let sig = ty.fn_sig(self.tcx).skip_binder();  | 
 | 345 | +        let args_referencing_param: Vec<_> = sig  | 
 | 346 | +            .inputs()  | 
 | 347 | +            .iter()  | 
 | 348 | +            .enumerate()  | 
 | 349 | +            .filter(|(_, ty)| Self::find_param_in_ty((**ty).into(), param_to_point_at))  | 
 | 350 | +            .collect();  | 
 | 351 | +        // If there's one field that references the given generic, great!  | 
 | 352 | +        if let [(idx, _)] = args_referencing_param.as_slice()  | 
 | 353 | +            && let Some(arg) = receiver  | 
 | 354 | +                .map_or(args.get(*idx), |rcvr| if *idx == 0 { Some(rcvr) } else { args.get(*idx - 1) }) {  | 
 | 355 | + | 
 | 356 | +            error.obligation.cause.span = arg.span.find_ancestor_in_same_ctxt(error.obligation.cause.span).unwrap_or(arg.span);  | 
 | 357 | + | 
 | 358 | +            if let hir::Node::Expr(arg_expr) = self.tcx.hir().get(arg.hir_id) {  | 
 | 359 | +                // This is more specific than pointing at the entire argument.  | 
 | 360 | +                self.blame_specific_expr_if_possible(error, arg_expr)  | 
 | 361 | +            }  | 
 | 362 | + | 
 | 363 | +            error.obligation.cause.map_code(|parent_code| {  | 
 | 364 | +                ObligationCauseCode::FunctionArgumentObligation {  | 
 | 365 | +                    arg_hir_id: arg.hir_id,  | 
 | 366 | +                    call_hir_id,  | 
 | 367 | +                    parent_code,  | 
 | 368 | +                }  | 
 | 369 | +            });  | 
 | 370 | +            return true;  | 
 | 371 | +        } else if args_referencing_param.len() > 0 {  | 
 | 372 | +            // If more than one argument applies, then point to the callee span at least...  | 
 | 373 | +            // We have chance to fix this up further in `point_at_generics_if_possible`  | 
 | 374 | +            error.obligation.cause.span = callee_span;  | 
 | 375 | +        }  | 
 | 376 | + | 
 | 377 | +        false  | 
 | 378 | +    }  | 
 | 379 | + | 
8 | 380 |     /**  | 
9 | 381 |      * Recursively searches for the most-specific blamable expression.  | 
10 | 382 |      * For example, if you have a chain of constraints like:  | 
 | 
0 commit comments