Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions cranelift/codegen/src/egraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ where
// The initial best choice is "no simplification, just use the original
// instruction" which has the original instruction's cost.
let mut best = None;
let mut best_cost = cost::Cost::of_skeleton_op(
let mut best_cost = cost::ScalarCost::of_skeleton_op(
ctx.func.dfg.insts[inst].opcode(),
ctx.func.dfg.inst_args(inst).len(),
);
Expand Down Expand Up @@ -682,7 +682,7 @@ where

// Our best simplification is the one with the least cost. Update
// `best` if necessary.
let cost = cost::Cost::of_skeleton_op(
let cost = cost::ScalarCost::of_skeleton_op(
ctx.func.dfg.insts[new_inst].opcode(),
ctx.func.dfg.inst_args(new_inst).len(),
);
Expand Down
211 changes: 152 additions & 59 deletions cranelift/codegen/src/egraph/cost.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,97 @@
//! Cost functions for egraph representation.

use crate::ir::Opcode;
use crate::ir::{DataFlowGraph, Inst, Opcode};
use cranelift_entity::ImmutableEntitySet;

/// The compound cost of an expression.
///
/// Tracks the set instructions that make up this expression and sums their
/// costs, avoiding "double counting" the costs of values that were defined by
/// the same instruction and values that appear multiple times within the
/// expression (i.e. the expression is a DAG and not a tree).
#[derive(Clone, Debug)]
pub(crate) struct ExprCost {
// The total cost of this expression.
total: ScalarCost,
// The set of instructions that must be evaluated to produce the associated
// expression.
insts: ImmutableEntitySet<Inst>,
}

impl Ord for ExprCost {
fn cmp(&self, other: &Self) -> core::cmp::Ordering {
self.total.cmp(&other.total)
}
}

impl PartialOrd for ExprCost {
fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
self.total.partial_cmp(&other.total)
}
}

impl PartialEq for ExprCost {
fn eq(&self, other: &Self) -> bool {
self.total == other.total
}
}

impl Eq for ExprCost {}

impl ExprCost {
/// Create an `ExprCost` with zero total cost and an empty set of
/// instructions.
pub fn zero() -> Self {
Self {
total: ScalarCost::zero(),
insts: ImmutableEntitySet::default(),
}
}

/// Create the cost for just the given instruction.
pub fn for_inst(dfg: &DataFlowGraph, inst: Inst) -> Self {
Self {
total: ScalarCost::of_opcode(dfg.insts[inst].opcode()),
insts: ImmutableEntitySet::unit(inst),
}
}

/// Add the other cost into this cost, unioning its set of instructions into
/// this cost's set, and only incrementing the total cost for new
/// instructions.
pub fn add(&mut self, dfg: &DataFlowGraph, other: &Self) {
match (self.insts.len(), other.insts.len()) {
// Nothing to do in this case.
(_, 0) => {}

// Clone `other` into `self` so that we reuse its set allocations.
(0, _) => {
*self = other.clone();
}

// Commute the addition so that we are (a) iterating over the
// smaller of the two sets, and (b) maximizing reuse of existing set
// allocations.
(a, b) if a < b => {
let mut other = other.clone();
for inst in self.insts.iter() {
if other.insts.insert(inst) {
other.total = other.total + ScalarCost::of_opcode(dfg.insts[inst].opcode());
}
}
*self = other;
}

_ => {
for inst in other.insts.iter() {
if self.insts.insert(inst) {
self.total = self.total + ScalarCost::of_opcode(dfg.insts[inst].opcode());
}
}
}
}
}
}

/// A cost of computing some value in the program.
///
Expand Down Expand Up @@ -31,11 +122,11 @@ use crate::ir::Opcode;
/// that cannot be computed, or otherwise serve as a sentinel when
/// performing search for the lowest-cost representation of a value.
#[derive(Clone, Copy, PartialEq, Eq)]
pub(crate) struct Cost(u32);
pub(crate) struct ScalarCost(u32);

impl core::fmt::Debug for Cost {
impl core::fmt::Debug for ScalarCost {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
if *self == Cost::infinity() {
if *self == ScalarCost::infinity() {
write!(f, "Cost::Infinite")
} else {
f.debug_struct("Cost::Finite")
Expand All @@ -46,7 +137,7 @@ impl core::fmt::Debug for Cost {
}
}

impl Ord for Cost {
impl Ord for ScalarCost {
#[inline]
fn cmp(&self, other: &Self) -> core::cmp::Ordering {
// We make sure that the high bits are the op cost and the low bits are
Expand All @@ -63,38 +154,38 @@ impl Ord for Cost {
}
}

impl PartialOrd for Cost {
impl PartialOrd for ScalarCost {
#[inline]
fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
Some(self.cmp(other))
}
}

impl Cost {
impl ScalarCost {
const DEPTH_BITS: u8 = 8;
const DEPTH_MASK: u32 = (1 << Self::DEPTH_BITS) - 1;
const OP_COST_MASK: u32 = !Self::DEPTH_MASK;
const MAX_OP_COST: u32 = Self::OP_COST_MASK >> Self::DEPTH_BITS;

pub(crate) fn infinity() -> Cost {
pub(crate) fn infinity() -> ScalarCost {
// 2^32 - 1 is, uh, pretty close to infinite... (we use `Cost`
// only for heuristics and always saturate so this suffices!)
Cost(u32::MAX)
ScalarCost(u32::MAX)
}

pub(crate) fn zero() -> Cost {
Cost(0)
pub(crate) fn zero() -> ScalarCost {
ScalarCost(0)
}

/// Construct a new `Cost` from the given parts.
///
/// If the opcode cost is greater than or equal to the maximum representable
/// opcode cost, then the resulting `Cost` saturates to infinity.
fn new(opcode_cost: u32, depth: u8) -> Cost {
fn new(opcode_cost: u32, depth: u8) -> ScalarCost {
if opcode_cost >= Self::MAX_OP_COST {
Self::infinity()
} else {
Cost(opcode_cost << Self::DEPTH_BITS | u32::from(depth))
ScalarCost(opcode_cost << Self::DEPTH_BITS | u32::from(depth))
}
}

Expand All @@ -108,17 +199,17 @@ impl Cost {
}

/// Return the cost of an opcode.
fn of_opcode(op: Opcode) -> Cost {
pub(crate) fn of_opcode(op: Opcode) -> ScalarCost {
match op {
// Constants.
Opcode::Iconst | Opcode::F32const | Opcode::F64const => Cost::new(1, 0),
Opcode::Iconst | Opcode::F32const | Opcode::F64const => ScalarCost::new(1, 0),

// Extends/reduces.
Opcode::Uextend
| Opcode::Sextend
| Opcode::Ireduce
| Opcode::Iconcat
| Opcode::Isplit => Cost::new(1, 0),
| Opcode::Isplit => ScalarCost::new(1, 0),

// "Simple" arithmetic.
Opcode::Iadd
Expand All @@ -129,110 +220,112 @@ impl Cost {
| Opcode::Bnot
| Opcode::Ishl
| Opcode::Ushr
| Opcode::Sshr => Cost::new(3, 0),
| Opcode::Sshr => ScalarCost::new(3, 0),

// "Expensive" arithmetic.
Opcode::Imul => Cost::new(10, 0),
Opcode::Imul => ScalarCost::new(10, 0),

// Everything else.
_ => {
// By default, be slightly more expensive than "simple"
// arithmetic.
let mut c = Cost::new(4, 0);
let mut c = ScalarCost::new(4, 0);

// And then get more expensive as the opcode does more side
// effects.
if op.can_trap() || op.other_side_effects() {
c = c + Cost::new(10, 0);
c = c + ScalarCost::new(10, 0);
}
if op.can_load() {
c = c + Cost::new(20, 0);
c = c + ScalarCost::new(20, 0);
}
if op.can_store() {
c = c + Cost::new(50, 0);
c = c + ScalarCost::new(50, 0);
}

c
}
}
}

/// Compute the cost of the operation and its given operands.
///
/// Caller is responsible for checking that the opcode came from an instruction
/// that satisfies `inst_predicates::is_pure_for_egraph()`.
pub(crate) fn of_pure_op(op: Opcode, operand_costs: impl IntoIterator<Item = Self>) -> Self {
let c = Self::of_opcode(op) + operand_costs.into_iter().sum();
Cost::new(c.op_cost(), c.depth().saturating_add(1))
}

/// Compute the cost of an operation in the side-effectful skeleton.
pub(crate) fn of_skeleton_op(op: Opcode, arity: usize) -> Self {
Cost::of_opcode(op) + Cost::new(u32::try_from(arity).unwrap(), (arity != 0) as _)
ScalarCost::of_opcode(op)
+ ScalarCost::new(u32::try_from(arity).unwrap(), (arity != 0) as _)
}
}

impl core::iter::Sum<Cost> for Cost {
fn sum<I: Iterator<Item = Cost>>(iter: I) -> Self {
impl core::iter::Sum<ScalarCost> for ScalarCost {
fn sum<I: Iterator<Item = ScalarCost>>(iter: I) -> Self {
iter.fold(Self::zero(), |a, b| a + b)
}
}

impl core::default::Default for Cost {
fn default() -> Cost {
Cost::zero()
impl core::default::Default for ScalarCost {
fn default() -> ScalarCost {
ScalarCost::zero()
}
}

impl core::ops::Add<Cost> for Cost {
type Output = Cost;
impl core::ops::Add<ScalarCost> for ScalarCost {
type Output = ScalarCost;

fn add(self, other: Cost) -> Cost {
fn add(self, other: ScalarCost) -> ScalarCost {
let op_cost = self.op_cost().saturating_add(other.op_cost());
let depth = core::cmp::max(self.depth(), other.depth());
Cost::new(op_cost, depth)
ScalarCost::new(op_cost, depth)
}
}

#[cfg(test)]
mod tests {
use super::*;

impl ScalarCost {
fn of_opcode_and_operands(
op: Opcode,
operand_costs: impl IntoIterator<Item = Self>,
) -> Self {
let c = Self::of_opcode(op) + operand_costs.into_iter().sum();
ScalarCost::new(c.op_cost(), c.depth().saturating_add(1))
}
}

#[test]
fn add_cost() {
let a = Cost::new(5, 2);
let b = Cost::new(37, 3);
assert_eq!(a + b, Cost::new(42, 3));
assert_eq!(b + a, Cost::new(42, 3));
let a = ScalarCost::new(5, 2);
let b = ScalarCost::new(37, 3);
assert_eq!(a + b, ScalarCost::new(42, 3));
assert_eq!(b + a, ScalarCost::new(42, 3));
}

#[test]
fn add_infinity() {
let a = Cost::new(5, 2);
let b = Cost::infinity();
assert_eq!(a + b, Cost::infinity());
assert_eq!(b + a, Cost::infinity());
let a = ScalarCost::new(5, 2);
let b = ScalarCost::infinity();
assert_eq!(a + b, ScalarCost::infinity());
assert_eq!(b + a, ScalarCost::infinity());
}

#[test]
fn op_cost_saturates_to_infinity() {
let a = Cost::new(Cost::MAX_OP_COST - 10, 2);
let b = Cost::new(11, 2);
assert_eq!(a + b, Cost::infinity());
assert_eq!(b + a, Cost::infinity());
let a = ScalarCost::new(ScalarCost::MAX_OP_COST - 10, 2);
let b = ScalarCost::new(11, 2);
assert_eq!(a + b, ScalarCost::infinity());
assert_eq!(b + a, ScalarCost::infinity());
}

#[test]
fn depth_saturates_to_max_depth() {
let a = Cost::new(10, u8::MAX);
let b = Cost::new(10, 1);
let a = ScalarCost::new(10, u8::MAX);
let b = ScalarCost::new(10, 1);
assert_eq!(
Cost::of_pure_op(Opcode::Iconst, [a, b]),
Cost::new(21, u8::MAX)
ScalarCost::of_opcode_and_operands(Opcode::Iconst, [a, b]),
ScalarCost::new(21, u8::MAX)
);
assert_eq!(
Cost::of_pure_op(Opcode::Iconst, [b, a]),
Cost::new(21, u8::MAX)
ScalarCost::of_opcode_and_operands(Opcode::Iconst, [b, a]),
ScalarCost::new(21, u8::MAX)
);
}
}
Loading