Skip to content

Commit

Permalink
Give inlining bonuses to things that optimize out
Browse files Browse the repository at this point in the history
  • Loading branch information
scottmcm committed Jun 18, 2024
1 parent 916e0d2 commit 9afee03
Show file tree
Hide file tree
Showing 7 changed files with 355 additions and 64 deletions.
83 changes: 63 additions & 20 deletions compiler/rustc_mir_transform/src/cost_checker.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use rustc_middle::bug;
use rustc_middle::mir::visit::*;
use rustc_middle::mir::*;
use rustc_middle::ty::{self, ParamEnv, Ty, TyCtxt};
Expand All @@ -6,6 +7,8 @@ const INSTR_COST: usize = 5;
const CALL_PENALTY: usize = 25;
const LANDINGPAD_PENALTY: usize = 50;
const RESUME_PENALTY: usize = 45;
const LARGE_SWITCH_PENALTY: usize = 20;
const CONST_SWITCH_BONUS: usize = 10;

/// Verify that the callee body is compatible with the caller.
#[derive(Clone)]
Expand Down Expand Up @@ -42,36 +45,49 @@ impl<'b, 'tcx> CostChecker<'b, 'tcx> {
}

impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
fn visit_statement(&mut self, statement: &Statement<'tcx>, _: Location) {
// Don't count StorageLive/StorageDead in the inlining cost.
fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
// Most costs are in rvalues and terminators, not in statements.
match statement.kind {
StatementKind::StorageLive(_)
| StatementKind::StorageDead(_)
| StatementKind::Deinit(_)
| StatementKind::Nop => {}
StatementKind::Intrinsic(ref ndi) => {
self.penalty += match **ndi {
NonDivergingIntrinsic::Assume(..) => INSTR_COST,
NonDivergingIntrinsic::CopyNonOverlapping(..) => CALL_PENALTY,
};
}
_ => self.super_statement(statement, location),
}
}

fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, _location: Location) {
match rvalue {
Rvalue::NullaryOp(NullOp::UbChecks, ..) if !self.tcx.sess.ub_checks() => {
// If this is in optimized MIR it's because it's used later,
// so if we don't need UB checks this session, give a bonus
// here to offset the cost of the call later.
self.bonus += CALL_PENALTY;
}
// These are essentially constants that didn't end up in an Operand,
// so treat them as also being free.
Rvalue::NullaryOp(..) => {}
_ => self.penalty += INSTR_COST,
}
}

fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, _: Location) {
let tcx = self.tcx;
match terminator.kind {
TerminatorKind::Drop { ref place, unwind, .. } => {
match &terminator.kind {
TerminatorKind::Drop { place, unwind, .. } => {
// If the place doesn't actually need dropping, treat it like a regular goto.
let ty = self.instantiate_ty(place.ty(self.callee_body, tcx).ty);
if ty.needs_drop(tcx, self.param_env) {
let ty = self.instantiate_ty(place.ty(self.callee_body, self.tcx).ty);
if ty.needs_drop(self.tcx, self.param_env) {
self.penalty += CALL_PENALTY;
if let UnwindAction::Cleanup(_) = unwind {
self.penalty += LANDINGPAD_PENALTY;
}
} else {
self.penalty += INSTR_COST;
}
}
TerminatorKind::Call { func: Operand::Constant(ref f), unwind, .. } => {
let fn_ty = self.instantiate_ty(f.const_.ty());
self.penalty += if let ty::FnDef(def_id, _) = *fn_ty.kind()
&& tcx.intrinsic(def_id).is_some()
TerminatorKind::Call { func, unwind, .. } => {
self.penalty += if let Some((def_id, ..)) = func.const_fn_def()
&& self.tcx.intrinsic(def_id).is_some()
{
// Don't give intrinsics the extra penalty for calls
INSTR_COST
Expand All @@ -82,8 +98,25 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
self.penalty += LANDINGPAD_PENALTY;
}
}
TerminatorKind::Assert { unwind, .. } => {
self.penalty += CALL_PENALTY;
TerminatorKind::SwitchInt { discr, targets } => {
if discr.constant().is_some() {
// Not only will this become a `Goto`, but likely other
// things will be removable as unreachable.
self.bonus += CONST_SWITCH_BONUS;
} else if targets.all_targets().len() > 3 {
// More than false/true/unreachable gets extra cost.
self.penalty += LARGE_SWITCH_PENALTY;
} else {
self.penalty += INSTR_COST;
}
}
TerminatorKind::Assert { unwind, msg, .. } => {
self.penalty +=
if msg.is_optional_overflow_check() && !self.tcx.sess.overflow_checks() {
INSTR_COST
} else {
CALL_PENALTY
};
if let UnwindAction::Cleanup(_) = unwind {
self.penalty += LANDINGPAD_PENALTY;
}
Expand All @@ -95,7 +128,17 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
self.penalty += LANDINGPAD_PENALTY;
}
}
_ => self.penalty += INSTR_COST,
TerminatorKind::Unreachable => {
self.bonus += INSTR_COST;
}
TerminatorKind::Goto { .. } | TerminatorKind::Return => {}
TerminatorKind::UnwindTerminate(..) => {}
kind @ (TerminatorKind::FalseUnwind { .. }
| TerminatorKind::FalseEdge { .. }
| TerminatorKind::Yield { .. }
| TerminatorKind::CoroutineDrop) => {
bug!("{kind:?} should not be in runtime MIR");
}
}
}
}
6 changes: 5 additions & 1 deletion tests/codegen/issues/issue-112509-slice-get-andthen-get.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@

// CHECK-LABEL: @write_u8_variant_a
// CHECK-NEXT: {{.*}}:
// CHECK-NEXT: getelementptr
// CHECK-NEXT: icmp ugt
// CHECK-NEXT: getelementptr
// CHECK-NEXT: select i1 {{.+}} null
// CHECK-NEXT: insertvalue
// CHECK-NEXT: insertvalue
// CHECK-NEXT: ret
#[no_mangle]
pub fn write_u8_variant_a(bytes: &mut [u8], buf: u8, offset: usize) -> Option<&mut [u8]> {
let buf = buf.to_le_bytes();
Expand Down
4 changes: 4 additions & 0 deletions tests/crashes/123893.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,14 @@ fn generic_impl<T>() {
impl<T> MagicTrait for T {
const IS_BIG: bool = true;
}
more_cost();
if T::IS_BIG {
big_impl::<i32>();
}
}

#[inline(never)]
fn big_impl<T>() {}

#[inline(never)]
fn more_cost() {}
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,87 @@ fn step_forward(_1: u16, _2: usize) -> u16 {
debug x => _1;
debug n => _2;
let mut _0: u16;
scope 1 (inlined <u16 as Step>::forward) {
let mut _8: u16;
scope 2 {
}
scope 3 (inlined <u16 as Step>::forward_checked) {
scope 4 {
scope 6 (inlined core::num::<impl u16>::checked_add) {
let mut _7: bool;
scope 7 {
}
scope 8 (inlined core::num::<impl u16>::overflowing_add) {
let mut _5: (u16, bool);
let _6: bool;
scope 9 {
}
}
}
}
scope 5 (inlined convert::num::ptr_try_from_impls::<impl TryFrom<usize> for u16>::try_from) {
let mut _3: bool;
let mut _4: u16;
}
}
scope 10 (inlined Option::<u16>::is_none) {
scope 11 (inlined Option::<u16>::is_some) {
}
}
scope 12 (inlined core::num::<impl u16>::wrapping_add) {
}
}

bb0: {
_0 = <u16 as Step>::forward(move _1, move _2) -> [return: bb1, unwind unreachable];
StorageLive(_4);
StorageLive(_3);
_3 = Gt(_2, const 65535_usize);
switchInt(move _3) -> [0: bb1, otherwise: bb5];
}

bb1: {
_4 = _2 as u16 (IntToInt);
StorageDead(_3);
StorageLive(_6);
StorageLive(_5);
_5 = AddWithOverflow(_1, _4);
_6 = (_5.1: bool);
StorageDead(_5);
StorageLive(_7);
_7 = unlikely(move _6) -> [return: bb2, unwind unreachable];
}

bb2: {
switchInt(move _7) -> [0: bb3, otherwise: bb4];
}

bb3: {
StorageDead(_7);
StorageDead(_6);
goto -> bb7;
}

bb4: {
StorageDead(_7);
StorageDead(_6);
goto -> bb6;
}

bb5: {
StorageDead(_3);
goto -> bb6;
}

bb6: {
assert(!const true, "attempt to compute `{} + {}`, which would overflow", const core::num::<impl u16>::MAX, const 1_u16) -> [success: bb7, unwind unreachable];
}

bb7: {
StorageLive(_8);
_8 = _2 as u16 (IntToInt);
_0 = Add(_1, _8);
StorageDead(_8);
StorageDead(_4);
return;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,87 @@ fn step_forward(_1: u16, _2: usize) -> u16 {
debug x => _1;
debug n => _2;
let mut _0: u16;
scope 1 (inlined <u16 as Step>::forward) {
let mut _8: u16;
scope 2 {
}
scope 3 (inlined <u16 as Step>::forward_checked) {
scope 4 {
scope 6 (inlined core::num::<impl u16>::checked_add) {
let mut _7: bool;
scope 7 {
}
scope 8 (inlined core::num::<impl u16>::overflowing_add) {
let mut _5: (u16, bool);
let _6: bool;
scope 9 {
}
}
}
}
scope 5 (inlined convert::num::ptr_try_from_impls::<impl TryFrom<usize> for u16>::try_from) {
let mut _3: bool;
let mut _4: u16;
}
}
scope 10 (inlined Option::<u16>::is_none) {
scope 11 (inlined Option::<u16>::is_some) {
}
}
scope 12 (inlined core::num::<impl u16>::wrapping_add) {
}
}

bb0: {
_0 = <u16 as Step>::forward(move _1, move _2) -> [return: bb1, unwind continue];
StorageLive(_4);
StorageLive(_3);
_3 = Gt(_2, const 65535_usize);
switchInt(move _3) -> [0: bb1, otherwise: bb5];
}

bb1: {
_4 = _2 as u16 (IntToInt);
StorageDead(_3);
StorageLive(_6);
StorageLive(_5);
_5 = AddWithOverflow(_1, _4);
_6 = (_5.1: bool);
StorageDead(_5);
StorageLive(_7);
_7 = unlikely(move _6) -> [return: bb2, unwind unreachable];
}

bb2: {
switchInt(move _7) -> [0: bb3, otherwise: bb4];
}

bb3: {
StorageDead(_7);
StorageDead(_6);
goto -> bb7;
}

bb4: {
StorageDead(_7);
StorageDead(_6);
goto -> bb6;
}

bb5: {
StorageDead(_3);
goto -> bb6;
}

bb6: {
assert(!const true, "attempt to compute `{} + {}`, which would overflow", const core::num::<impl u16>::MAX, const 1_u16) -> [success: bb7, unwind continue];
}

bb7: {
StorageLive(_8);
_8 = _2 as u16 (IntToInt);
_0 = Add(_1, _8);
StorageDead(_8);
StorageDead(_4);
return;
}
}
Loading

0 comments on commit 9afee03

Please sign in to comment.