Skip to content

Commit 6061b8c

Browse files
committed
DeepRejectCtxt: match over lhs only
1 parent 35ca8f8 commit 6061b8c

File tree

1 file changed

+159
-117
lines changed

1 file changed

+159
-117
lines changed

compiler/rustc_type_ir/src/fast_reject.rs

Lines changed: 159 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -230,143 +230,189 @@ impl<I: Interner, const TREAT_LHS_PARAMS: bool, const TREAT_RHS_PARAMS: bool>
230230
}
231231

232232
pub fn types_may_unify(self, lhs: I::Ty, rhs: I::Ty) -> bool {
233-
match (lhs.kind(), rhs.kind()) {
234-
(ty::Ref(_, lhs_ty, lhs_mutbl), ty::Ref(_, rhs_ty, rhs_mutbl)) => {
235-
lhs_mutbl == rhs_mutbl && self.types_may_unify(lhs_ty, rhs_ty)
236-
}
233+
if let ty::Infer(var) = rhs.kind() {
234+
return self.var_and_ty_may_unify(var, lhs);
235+
}
237236

238-
(ty::Adt(lhs_def, lhs_args), ty::Adt(rhs_def, rhs_args)) => {
239-
lhs_def == rhs_def && self.args_may_unify(lhs_args, rhs_args)
240-
}
237+
if !Self::type_is_rigid::<TREAT_RHS_PARAMS>(rhs) {
238+
return true;
239+
}
241240

242-
(ty::Infer(var), _) => self.var_and_ty_may_unify(var, rhs),
243-
(_, ty::Infer(var)) => self.var_and_ty_may_unify(var, lhs),
241+
match lhs.kind() {
242+
ty::Ref(_, lhs_ty, lhs_mutbl) => match rhs.kind() {
243+
ty::Ref(_, rhs_ty, rhs_mutbl) => {
244+
lhs_mutbl == rhs_mutbl && self.types_may_unify(lhs_ty, rhs_ty)
245+
}
246+
_ => false,
247+
},
248+
249+
ty::Adt(lhs_def, lhs_args) => match rhs.kind() {
250+
ty::Adt(rhs_def, rhs_args) => {
251+
lhs_def == rhs_def && self.args_may_unify(lhs_args, rhs_args)
252+
}
253+
_ => false,
254+
},
244255

245-
(ty::Int(_), ty::Int(_)) | (ty::Uint(_), ty::Uint(_)) => lhs == rhs,
256+
ty::Param(lhs) => match rhs.kind() {
257+
ty::Param(rhs) => match (TREAT_LHS_PARAMS, TREAT_RHS_PARAMS) {
258+
(false, false) => lhs == rhs,
259+
(true, _) | (_, true) => true,
260+
},
261+
_ => TREAT_LHS_PARAMS,
262+
},
246263

247-
(ty::Param(lhs), ty::Param(rhs)) => match (TREAT_LHS_PARAMS, TREAT_RHS_PARAMS) {
248-
(false, false) => lhs == rhs,
249-
(true, _) | (_, true) => true,
264+
ty::Placeholder(lhs) => match rhs.kind() {
265+
ty::Placeholder(rhs) => lhs == rhs,
266+
_ => false,
250267
},
251268

269+
ty::Infer(var) => self.var_and_ty_may_unify(var, rhs),
270+
252271
// As we're walking the whole type, it may encounter projections
253272
// inside of binders and what not, so we're just going to assume that
254273
// projections can unify with other stuff.
255274
//
256275
// Looking forward to lazy normalization this is the safer strategy anyways.
257-
(ty::Alias(..), _) | (_, ty::Alias(..)) => true,
276+
ty::Alias(..) => true,
277+
278+
ty::Uint(_)
279+
| ty::Int(_)
280+
| ty::Float(_)
281+
| ty::Str
282+
| ty::Bool
283+
| ty::Char
284+
| ty::Never
285+
| ty::Foreign(_) => lhs == rhs,
286+
287+
ty::Tuple(lhs) => match rhs.kind() {
288+
ty::Tuple(rhs) => {
289+
lhs.len() == rhs.len()
290+
&& iter::zip(lhs.iter(), rhs.iter())
291+
.all(|(lhs, rhs)| self.types_may_unify(lhs, rhs))
292+
}
293+
_ => false,
294+
},
258295

259-
(ty::Bound(..), _) | (_, ty::Bound(..)) => true,
296+
ty::Array(lhs_ty, lhs_len) => match rhs.kind() {
297+
ty::Array(rhs_ty, rhs_len) => {
298+
self.types_may_unify(lhs_ty, rhs_ty) && self.consts_may_unify(lhs_len, rhs_len)
299+
}
300+
_ => false,
301+
},
260302

261-
(ty::Param(_), _) => TREAT_LHS_PARAMS,
262-
(_, ty::Param(_)) => TREAT_RHS_PARAMS,
303+
ty::RawPtr(lhs_ty, lhs_mutbl) => match rhs.kind() {
304+
ty::RawPtr(rhs_ty, rhs_mutbl) => {
305+
lhs_mutbl == rhs_mutbl && self.types_may_unify(lhs_ty, rhs_ty)
306+
}
307+
_ => false,
308+
},
263309

264-
(ty::Tuple(lhs), ty::Tuple(rhs)) => {
265-
lhs.len() == rhs.len()
266-
&& iter::zip(lhs.iter(), rhs.iter())
267-
.all(|(lhs, rhs)| self.types_may_unify(lhs, rhs))
268-
}
310+
ty::Slice(lhs_ty) => match rhs.kind() {
311+
ty::Slice(rhs_ty) => self.types_may_unify(lhs_ty, rhs_ty),
312+
_ => false,
313+
},
269314

270-
(ty::Array(lhs_ty, lhs_len), ty::Array(rhs_ty, rhs_len)) => {
271-
self.types_may_unify(lhs_ty, rhs_ty) && self.consts_may_unify(lhs_len, rhs_len)
272-
}
315+
ty::Dynamic(lhs_preds, ..) => match rhs.kind() {
316+
ty::Dynamic(rhs_preds, ..) => {
317+
// Ideally we would walk the existential predicates here or at least
318+
// compare their length. But considering that the relevant `Relate` impl
319+
// actually sorts and deduplicates these, that doesn't work.
320+
lhs_preds.principal_def_id() == rhs_preds.principal_def_id()
321+
}
322+
_ => false,
323+
},
273324

274-
(ty::RawPtr(lhs_ty, lhs_mutbl), ty::RawPtr(rhs_ty, rhs_mutbl)) => {
275-
lhs_mutbl == rhs_mutbl && self.types_may_unify(lhs_ty, rhs_ty)
276-
}
325+
ty::FnPtr(lhs_sig_tys, lhs_hdr) => match rhs.kind() {
326+
ty::FnPtr(rhs_sig_tys, rhs_hdr) => {
327+
let lhs_sig_tys = lhs_sig_tys.skip_binder().inputs_and_output;
328+
let rhs_sig_tys = rhs_sig_tys.skip_binder().inputs_and_output;
277329

278-
(ty::Slice(lhs_ty), ty::Slice(rhs_ty)) => self.types_may_unify(lhs_ty, rhs_ty),
330+
lhs_hdr == rhs_hdr
331+
&& lhs_sig_tys.len() == rhs_sig_tys.len()
332+
&& iter::zip(lhs_sig_tys.iter(), rhs_sig_tys.iter())
333+
.all(|(lhs, rhs)| self.types_may_unify(lhs, rhs))
334+
}
335+
_ => false,
336+
},
279337

280-
(ty::Float(_), ty::Float(_))
281-
| (ty::Str, ty::Str)
282-
| (ty::Bool, ty::Bool)
283-
| (ty::Char, ty::Char)
284-
| (ty::Never, ty::Never)
285-
| (ty::Foreign(_), ty::Foreign(_)) => lhs == rhs,
338+
ty::Bound(..) => true,
286339

287-
(ty::Dynamic(lhs_preds, ..), ty::Dynamic(rhs_preds, ..)) => {
288-
// Ideally we would walk the existential predicates here or at least
289-
// compare their length. But considering that the relevant `Relate` impl
290-
// actually sorts and deduplicates these, that doesn't work.
291-
lhs_preds.principal_def_id() == rhs_preds.principal_def_id()
292-
}
340+
ty::FnDef(lhs_def_id, lhs_args) => match rhs.kind() {
341+
ty::FnDef(rhs_def_id, rhs_args) => {
342+
lhs_def_id == rhs_def_id && self.args_may_unify(lhs_args, rhs_args)
343+
}
344+
_ => false,
345+
},
293346

294-
// Placeholder types don't unify with anything on their own.
295-
(ty::Placeholder(lhs), ty::Placeholder(rhs)) => lhs == rhs,
347+
ty::Closure(lhs_def_id, lhs_args) => match rhs.kind() {
348+
ty::Closure(rhs_def_id, rhs_args) => {
349+
lhs_def_id == rhs_def_id && self.args_may_unify(lhs_args, rhs_args)
350+
}
351+
_ => false,
352+
},
296353

297-
(ty::FnPtr(lhs_sig_tys, lhs_hdr), ty::FnPtr(rhs_sig_tys, rhs_hdr)) => {
298-
let lhs_sig_tys = lhs_sig_tys.skip_binder().inputs_and_output;
299-
let rhs_sig_tys = rhs_sig_tys.skip_binder().inputs_and_output;
354+
ty::CoroutineClosure(lhs_def_id, lhs_args) => match rhs.kind() {
355+
ty::CoroutineClosure(rhs_def_id, rhs_args) => {
356+
lhs_def_id == rhs_def_id && self.args_may_unify(lhs_args, rhs_args)
357+
}
358+
_ => false,
359+
},
300360

301-
lhs_hdr == rhs_hdr
302-
&& lhs_sig_tys.len() == rhs_sig_tys.len()
303-
&& iter::zip(lhs_sig_tys.iter(), rhs_sig_tys.iter())
304-
.all(|(lhs, rhs)| self.types_may_unify(lhs, rhs))
305-
}
361+
ty::Coroutine(lhs_def_id, lhs_args) => match rhs.kind() {
362+
ty::Coroutine(rhs_def_id, rhs_args) => {
363+
lhs_def_id == rhs_def_id && self.args_may_unify(lhs_args, rhs_args)
364+
}
365+
_ => false,
366+
},
306367

307-
(ty::FnDef(lhs_def_id, lhs_args), ty::FnDef(rhs_def_id, rhs_args))
308-
| (ty::Closure(lhs_def_id, lhs_args), ty::Closure(rhs_def_id, rhs_args))
309-
| (
310-
ty::CoroutineClosure(lhs_def_id, lhs_args),
311-
ty::CoroutineClosure(rhs_def_id, rhs_args),
312-
)
313-
| (ty::Coroutine(lhs_def_id, lhs_args), ty::Coroutine(rhs_def_id, rhs_args))
314-
| (
315-
ty::CoroutineWitness(lhs_def_id, lhs_args),
316-
ty::CoroutineWitness(rhs_def_id, rhs_args),
317-
) => lhs_def_id == rhs_def_id && self.args_may_unify(lhs_args, rhs_args),
318-
319-
(ty::Pat(lhs_ty, _), ty::Pat(rhs_ty, _)) => {
320-
// FIXME(pattern_types): take pattern into account
321-
self.types_may_unify(lhs_ty, rhs_ty)
322-
}
368+
ty::CoroutineWitness(lhs_def_id, lhs_args) => match rhs.kind() {
369+
ty::CoroutineWitness(rhs_def_id, rhs_args) => {
370+
lhs_def_id == rhs_def_id && self.args_may_unify(lhs_args, rhs_args)
371+
}
372+
_ => false,
373+
},
374+
375+
ty::Pat(lhs_ty, _) => match rhs.kind() {
376+
ty::Pat(rhs_ty, _) => {
377+
// FIXME(pattern_types): take pattern into account
378+
self.types_may_unify(lhs_ty, rhs_ty)
379+
}
380+
_ => false,
381+
},
323382

324-
(ty::Error(..), _)
325-
| (_, ty::Error(..))
326-
| (ty::Placeholder(_), _)
327-
| (_, ty::Placeholder(_))
328-
| (ty::Bool, _)
329-
| (_, ty::Bool)
330-
| (ty::Char, _)
331-
| (_, ty::Char)
332-
| (ty::Int(_), _)
333-
| (_, ty::Int(_))
334-
| (ty::Uint(_), _)
335-
| (_, ty::Uint(_))
336-
| (ty::Float(_), _)
337-
| (_, ty::Float(_))
338-
| (ty::Str, _)
339-
| (_, ty::Str)
340-
| (ty::Never, _)
341-
| (_, ty::Never)
342-
| (ty::Foreign(_), _)
343-
| (_, ty::Foreign(_))
344-
| (ty::Ref(..), _)
345-
| (_, ty::Ref(..))
346-
| (ty::Adt(..), _)
347-
| (_, ty::Adt(..))
348-
| (ty::Pat(..), _)
349-
| (_, ty::Pat(..))
350-
| (ty::Slice(_), _)
351-
| (_, ty::Slice(_))
352-
| (ty::Array(..), _)
353-
| (_, ty::Array(..))
354-
| (ty::Tuple(_), _)
355-
| (_, ty::Tuple(_))
356-
| (ty::RawPtr(..), _)
357-
| (_, ty::RawPtr(..))
358-
| (ty::Dynamic(..), _)
359-
| (_, ty::Dynamic(..))
360-
| (ty::FnPtr(..), _)
361-
| (_, ty::FnPtr(..))
362-
| (ty::FnDef(..), _)
363-
| (_, ty::FnDef(..))
364-
| (ty::Closure(..), _)
365-
| (_, ty::Closure(..))
366-
| (ty::CoroutineClosure(..), _)
367-
| (_, ty::CoroutineClosure(..))
368-
| (ty::Coroutine(..), _)
369-
| (_, ty::Coroutine(..)) => false,
383+
ty::Error(..) => true,
384+
}
385+
}
386+
387+
fn type_is_rigid<const TREAT_PARAM: bool>(ty: I::Ty) -> bool {
388+
match ty.kind() {
389+
ty::Bool
390+
| ty::Char
391+
| ty::Int(_)
392+
| ty::Uint(_)
393+
| ty::Float(_)
394+
| ty::Adt(_, _)
395+
| ty::Foreign(_)
396+
| ty::Str
397+
| ty::Array(_, _)
398+
| ty::Pat(_, _)
399+
| ty::Slice(_)
400+
| ty::RawPtr(_, _)
401+
| ty::Ref(_, _, _)
402+
| ty::FnDef(_, _)
403+
| ty::FnPtr(..)
404+
| ty::Dynamic(_, _, _)
405+
| ty::Closure(_, _)
406+
| ty::CoroutineClosure(_, _)
407+
| ty::Coroutine(_, _)
408+
| ty::CoroutineWitness(..)
409+
| ty::Never
410+
| ty::Tuple(_)
411+
| ty::Placeholder(_) => true,
412+
413+
ty::Param(_) => !TREAT_PARAM,
414+
415+
ty::Error(_) | ty::Infer(_) | ty::Alias(_, _) | ty::Bound(_, _) => false,
370416
}
371417
}
372418

@@ -389,10 +435,6 @@ impl<I: Interner, const TREAT_LHS_PARAMS: bool, const TREAT_RHS_PARAMS: bool>
389435
}
390436

391437
fn var_and_ty_may_unify(self, var: ty::InferTy, ty: I::Ty) -> bool {
392-
if !ty.is_known_rigid() {
393-
return true;
394-
}
395-
396438
match var {
397439
ty::IntVar(_) => ty.is_integral(),
398440
ty::FloatVar(_) => ty.is_floating_point(),

0 commit comments

Comments
 (0)