@@ -12,6 +12,7 @@ use rustc_middle::ty::{self, ConstKind, Instance, InstanceDef, ParamEnv, Ty, TyC
1212use rustc_session:: config:: OptLevel ;
1313use rustc_span:: def_id:: DefId ;
1414use rustc_span:: { hygiene:: ExpnKind , ExpnData , LocalExpnId , Span } ;
15+ use rustc_target:: abi:: VariantIdx ;
1516use rustc_target:: spec:: abi:: Abi ;
1617
1718use super :: simplify:: { remove_dead_blocks, CfgSimplifier } ;
@@ -414,118 +415,60 @@ impl<'tcx> Inliner<'tcx> {
414415 debug ! ( " final inline threshold = {}" , threshold) ;
415416
416417 // FIXME: Give a bonus to functions with only a single caller
417- let mut first_block = true ;
418- let mut cost = 0 ;
418+ let diverges = matches ! (
419+ callee_body. basic_blocks( ) [ START_BLOCK ] . terminator( ) . kind,
420+ TerminatorKind :: Unreachable | TerminatorKind :: Call { target: None , .. }
421+ ) ;
422+ if diverges && !matches ! ( callee_attrs. inline, InlineAttr :: Always ) {
423+ return Err ( "callee diverges unconditionally" ) ;
424+ }
425+
426+ let mut checker = CostChecker {
427+ tcx : self . tcx ,
428+ param_env : self . param_env ,
429+ instance : callsite. callee ,
430+ callee_body,
431+ cost : 0 ,
432+ validation : Ok ( ( ) ) ,
433+ } ;
419434
420- // Traverse the MIR manually so we can account for the effects of
421- // inlining on the CFG.
435+ // Traverse the MIR manually so we can account for the effects of inlining on the CFG.
422436 let mut work_list = vec ! [ START_BLOCK ] ;
423437 let mut visited = BitSet :: new_empty ( callee_body. basic_blocks ( ) . len ( ) ) ;
424438 while let Some ( bb) = work_list. pop ( ) {
425439 if !visited. insert ( bb. index ( ) ) {
426440 continue ;
427441 }
442+
428443 let blk = & callee_body. basic_blocks ( ) [ bb] ;
444+ checker. visit_basic_block_data ( bb, blk) ;
429445
430- for stmt in & blk. statements {
431- // Don't count StorageLive/StorageDead in the inlining cost.
432- match stmt. kind {
433- StatementKind :: StorageLive ( _)
434- | StatementKind :: StorageDead ( _)
435- | StatementKind :: Deinit ( _)
436- | StatementKind :: Nop => { }
437- _ => cost += INSTR_COST ,
438- }
439- }
440446 let term = blk. terminator ( ) ;
441- let mut is_drop = false ;
442- match term. kind {
443- TerminatorKind :: Drop { ref place, target, unwind }
444- | TerminatorKind :: DropAndReplace { ref place, target, unwind, .. } => {
445- is_drop = true ;
446- work_list. push ( target) ;
447- // If the place doesn't actually need dropping, treat it like
448- // a regular goto.
449- let ty = callsite. callee . subst_mir ( self . tcx , & place. ty ( callee_body, tcx) . ty ) ;
450- if ty. needs_drop ( tcx, self . param_env ) {
451- cost += CALL_PENALTY ;
452- if let Some ( unwind) = unwind {
453- cost += LANDINGPAD_PENALTY ;
454- work_list. push ( unwind) ;
455- }
456- } else {
457- cost += INSTR_COST ;
458- }
459- }
460-
461- TerminatorKind :: Unreachable | TerminatorKind :: Call { target : None , .. }
462- if first_block =>
463- {
464- // If the function always diverges, don't inline
465- // unless the cost is zero
466- threshold = 0 ;
467- }
468-
469- TerminatorKind :: Call { func : Operand :: Constant ( ref f) , cleanup, .. } => {
470- if let ty:: FnDef ( def_id, _) =
471- * callsite. callee . subst_mir ( self . tcx , & f. literal . ty ( ) ) . kind ( )
472- {
473- // Don't give intrinsics the extra penalty for calls
474- if tcx. is_intrinsic ( def_id) {
475- cost += INSTR_COST ;
476- } else {
477- cost += CALL_PENALTY ;
478- }
479- } else {
480- cost += CALL_PENALTY ;
481- }
482- if cleanup. is_some ( ) {
483- cost += LANDINGPAD_PENALTY ;
484- }
485- }
486- TerminatorKind :: Assert { cleanup, .. } => {
487- cost += CALL_PENALTY ;
488-
489- if cleanup. is_some ( ) {
490- cost += LANDINGPAD_PENALTY ;
491- }
492- }
493- TerminatorKind :: Resume => cost += RESUME_PENALTY ,
494- TerminatorKind :: InlineAsm { cleanup, .. } => {
495- cost += INSTR_COST ;
447+ if let TerminatorKind :: Drop { ref place, target, unwind }
448+ | TerminatorKind :: DropAndReplace { ref place, target, unwind, .. } = term. kind
449+ {
450+ work_list. push ( target) ;
496451
497- if cleanup. is_some ( ) {
498- cost += LANDINGPAD_PENALTY ;
452+ // If the place doesn't actually need dropping, treat it like a regular goto.
453+ let ty = callsite. callee . subst_mir ( self . tcx , & place. ty ( callee_body, tcx) . ty ) ;
454+ if ty. needs_drop ( tcx, self . param_env ) && let Some ( unwind) = unwind {
455+ work_list. push ( unwind) ;
499456 }
500- }
501- _ => cost += INSTR_COST ,
502- }
503-
504- if !is_drop {
505- for succ in term. successors ( ) {
506- work_list. push ( succ) ;
507- }
457+ } else {
458+ work_list. extend ( term. successors ( ) )
508459 }
509-
510- first_block = false ;
511460 }
512461
513462 // Count up the cost of local variables and temps, if we know the size
514463 // use that, otherwise we use a moderately-large dummy cost.
515-
516- let ptr_size = tcx. data_layout . pointer_size . bytes ( ) ;
517-
518464 for v in callee_body. vars_and_temps_iter ( ) {
519- let ty = callsite. callee . subst_mir ( self . tcx , & callee_body. local_decls [ v] . ty ) ;
520- // Cost of the var is the size in machine-words, if we know
521- // it.
522- if let Some ( size) = type_size_of ( tcx, self . param_env , ty) {
523- cost += ( ( size + ptr_size - 1 ) / ptr_size) as usize ;
524- } else {
525- cost += UNKNOWN_SIZE_COST ;
526- }
465+ checker. visit_local_decl ( v, & callee_body. local_decls [ v] ) ;
527466 }
528467
468+ // Abort if type validation found anything fishy.
469+ checker. validation ?;
470+
471+ let cost = checker. cost ;
529472 if let InlineAttr :: Always = callee_attrs. inline {
530473 debug ! ( "INLINING {:?} because inline(always) [cost={}]" , callsite, cost) ;
531474 Ok ( ( ) )
@@ -799,6 +742,193 @@ fn type_size_of<'tcx>(
799742 tcx. layout_of ( param_env. and ( ty) ) . ok ( ) . map ( |layout| layout. size . bytes ( ) )
800743}
801744
745+ /// Verify that the callee body is compatible with the caller.
746+ ///
747+ /// This visitor mostly computes the inlining cost,
748+ /// but also needs to verify that types match because of normalization failure.
749+ struct CostChecker < ' b , ' tcx > {
750+ tcx : TyCtxt < ' tcx > ,
751+ param_env : ParamEnv < ' tcx > ,
752+ cost : usize ,
753+ callee_body : & ' b Body < ' tcx > ,
754+ instance : ty:: Instance < ' tcx > ,
755+ validation : Result < ( ) , & ' static str > ,
756+ }
757+
758+ impl < ' tcx > Visitor < ' tcx > for CostChecker < ' _ , ' tcx > {
759+ fn visit_statement ( & mut self , statement : & Statement < ' tcx > , location : Location ) {
760+ // Don't count StorageLive/StorageDead in the inlining cost.
761+ match statement. kind {
762+ StatementKind :: StorageLive ( _)
763+ | StatementKind :: StorageDead ( _)
764+ | StatementKind :: Deinit ( _)
765+ | StatementKind :: Nop => { }
766+ _ => self . cost += INSTR_COST ,
767+ }
768+
769+ self . super_statement ( statement, location) ;
770+ }
771+
772+ fn visit_terminator ( & mut self , terminator : & Terminator < ' tcx > , location : Location ) {
773+ let tcx = self . tcx ;
774+ match terminator. kind {
775+ TerminatorKind :: Drop { ref place, unwind, .. }
776+ | TerminatorKind :: DropAndReplace { ref place, unwind, .. } => {
777+ // If the place doesn't actually need dropping, treat it like a regular goto.
778+ let ty = self . instance . subst_mir ( tcx, & place. ty ( self . callee_body , tcx) . ty ) ;
779+ if ty. needs_drop ( tcx, self . param_env ) {
780+ self . cost += CALL_PENALTY ;
781+ if unwind. is_some ( ) {
782+ self . cost += LANDINGPAD_PENALTY ;
783+ }
784+ } else {
785+ self . cost += INSTR_COST ;
786+ }
787+ }
788+ TerminatorKind :: Call { func : Operand :: Constant ( ref f) , cleanup, .. } => {
789+ let fn_ty = self . instance . subst_mir ( tcx, & f. literal . ty ( ) ) ;
790+ self . cost += if let ty:: FnDef ( def_id, _) = * fn_ty. kind ( ) && tcx. is_intrinsic ( def_id) {
791+ // Don't give intrinsics the extra penalty for calls
792+ INSTR_COST
793+ } else {
794+ CALL_PENALTY
795+ } ;
796+ if cleanup. is_some ( ) {
797+ self . cost += LANDINGPAD_PENALTY ;
798+ }
799+ }
800+ TerminatorKind :: Assert { cleanup, .. } => {
801+ self . cost += CALL_PENALTY ;
802+ if cleanup. is_some ( ) {
803+ self . cost += LANDINGPAD_PENALTY ;
804+ }
805+ }
806+ TerminatorKind :: Resume => self . cost += RESUME_PENALTY ,
807+ TerminatorKind :: InlineAsm { cleanup, .. } => {
808+ self . cost += INSTR_COST ;
809+ if cleanup. is_some ( ) {
810+ self . cost += LANDINGPAD_PENALTY ;
811+ }
812+ }
813+ _ => self . cost += INSTR_COST ,
814+ }
815+
816+ self . super_terminator ( terminator, location) ;
817+ }
818+
819+ /// Count up the cost of local variables and temps, if we know the size
820+ /// use that, otherwise we use a moderately-large dummy cost.
821+ fn visit_local_decl ( & mut self , local : Local , local_decl : & LocalDecl < ' tcx > ) {
822+ let tcx = self . tcx ;
823+ let ptr_size = tcx. data_layout . pointer_size . bytes ( ) ;
824+
825+ let ty = self . instance . subst_mir ( tcx, & local_decl. ty ) ;
826+ // Cost of the var is the size in machine-words, if we know
827+ // it.
828+ if let Some ( size) = type_size_of ( tcx, self . param_env , ty) {
829+ self . cost += ( ( size + ptr_size - 1 ) / ptr_size) as usize ;
830+ } else {
831+ self . cost += UNKNOWN_SIZE_COST ;
832+ }
833+
834+ self . super_local_decl ( local, local_decl)
835+ }
836+
837+ /// This method duplicates code from MIR validation in an attempt to detect type mismatches due
838+ /// to normalization failure.
839+ fn visit_projection_elem (
840+ & mut self ,
841+ local : Local ,
842+ proj_base : & [ PlaceElem < ' tcx > ] ,
843+ elem : PlaceElem < ' tcx > ,
844+ context : PlaceContext ,
845+ location : Location ,
846+ ) {
847+ if let ProjectionElem :: Field ( f, ty) = elem {
848+ let parent = Place { local, projection : self . tcx . intern_place_elems ( proj_base) } ;
849+ let parent_ty = parent. ty ( & self . callee_body . local_decls , self . tcx ) ;
850+ let check_equal = |this : & mut Self , f_ty| {
851+ if !equal_up_to_regions ( this. tcx , this. param_env , ty, f_ty) {
852+ trace ! ( ?ty, ?f_ty) ;
853+ this. validation = Err ( "failed to normalize projection type" ) ;
854+ return ;
855+ }
856+ } ;
857+
858+ let kind = match parent_ty. ty . kind ( ) {
859+ & ty:: Opaque ( def_id, substs) => {
860+ self . tcx . bound_type_of ( def_id) . subst ( self . tcx , substs) . kind ( )
861+ }
862+ kind => kind,
863+ } ;
864+
865+ match kind {
866+ ty:: Tuple ( fields) => {
867+ let Some ( f_ty) = fields. get ( f. as_usize ( ) ) else {
868+ self . validation = Err ( "malformed MIR" ) ;
869+ return ;
870+ } ;
871+ check_equal ( self , * f_ty) ;
872+ }
873+ ty:: Adt ( adt_def, substs) => {
874+ let var = parent_ty. variant_index . unwrap_or ( VariantIdx :: from_u32 ( 0 ) ) ;
875+ let Some ( field) = adt_def. variant ( var) . fields . get ( f. as_usize ( ) ) else {
876+ self . validation = Err ( "malformed MIR" ) ;
877+ return ;
878+ } ;
879+ check_equal ( self , field. ty ( self . tcx , substs) ) ;
880+ }
881+ ty:: Closure ( _, substs) => {
882+ let substs = substs. as_closure ( ) ;
883+ let Some ( f_ty) = substs. upvar_tys ( ) . nth ( f. as_usize ( ) ) else {
884+ self . validation = Err ( "malformed MIR" ) ;
885+ return ;
886+ } ;
887+ check_equal ( self , f_ty) ;
888+ }
889+ & ty:: Generator ( def_id, substs, _) => {
890+ let f_ty = if let Some ( var) = parent_ty. variant_index {
891+ let gen_body = if def_id == self . callee_body . source . def_id ( ) {
892+ self . callee_body
893+ } else {
894+ self . tcx . optimized_mir ( def_id)
895+ } ;
896+
897+ let Some ( layout) = gen_body. generator_layout ( ) else {
898+ self . validation = Err ( "malformed MIR" ) ;
899+ return ;
900+ } ;
901+
902+ let Some ( & local) = layout. variant_fields [ var] . get ( f) else {
903+ self . validation = Err ( "malformed MIR" ) ;
904+ return ;
905+ } ;
906+
907+ let Some ( & f_ty) = layout. field_tys . get ( local) else {
908+ self . validation = Err ( "malformed MIR" ) ;
909+ return ;
910+ } ;
911+
912+ f_ty
913+ } else {
914+ let Some ( f_ty) = substs. as_generator ( ) . prefix_tys ( ) . nth ( f. index ( ) ) else {
915+ self . validation = Err ( "malformed MIR" ) ;
916+ return ;
917+ } ;
918+
919+ f_ty
920+ } ;
921+
922+ check_equal ( self , f_ty) ;
923+ }
924+ _ => self . validation = Err ( "malformed MIR" ) ,
925+ }
926+ }
927+
928+ self . super_projection_elem ( local, proj_base, elem, context, location) ;
929+ }
930+ }
931+
802932/**
803933 * Integrator.
804934 *
0 commit comments