Skip to content

Commit a353b5d

Browse files
committed
Cranelift: Use a fixpoint loop to compute the best value for each eclass
Fixes #7857
1 parent 02d1005 commit a353b5d

File tree

4 files changed

+189
-55
lines changed

4 files changed

+189
-55
lines changed

cranelift/codegen/src/egraph.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -701,4 +701,5 @@ pub(crate) struct Stats {
701701
pub(crate) elaborate_func: u64,
702702
pub(crate) elaborate_func_pre_insts: u64,
703703
pub(crate) elaborate_func_post_insts: u64,
704+
pub(crate) elaborate_best_cost_fixpoint_iters: u64,
704705
}

cranelift/codegen/src/egraph/cost.rs

Lines changed: 61 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ impl Cost {
7474
const DEPTH_BITS: u8 = 8;
7575
const DEPTH_MASK: u32 = (1 << Self::DEPTH_BITS) - 1;
7676
const OP_COST_MASK: u32 = !Self::DEPTH_MASK;
77-
const MAX_OP_COST: u32 = (Self::OP_COST_MASK >> Self::DEPTH_BITS) - 1;
77+
const MAX_OP_COST: u32 = Self::OP_COST_MASK >> Self::DEPTH_BITS;
7878

7979
pub(crate) fn infinity() -> Cost {
8080
// 2^32 - 1 is, uh, pretty close to infinite... (we use `Cost`
@@ -86,14 +86,24 @@ impl Cost {
8686
Cost(0)
8787
}
8888

89-
/// Construct a new finite cost from the given parts.
89+
pub(crate) fn is_infinite(&self) -> bool {
90+
*self == Cost::infinity()
91+
}
92+
93+
pub(crate) fn is_finite(&self) -> bool {
94+
!self.is_infinite()
95+
}
96+
97+
/// Construct a new `Cost` from the given parts.
9098
///
91-
/// The opcode cost is clamped to the maximum value representable.
92-
fn new_finite(opcode_cost: u32, depth: u8) -> Cost {
93-
let opcode_cost = std::cmp::min(opcode_cost, Self::MAX_OP_COST);
94-
let cost = Cost((opcode_cost << Self::DEPTH_BITS) | u32::from(depth));
95-
debug_assert_ne!(cost, Cost::infinity());
96-
cost
99+
/// If the opcode cost is greater than or equal to the maximum representable
100+
/// opcode cost, then the resulting `Cost` saturates to infinity.
101+
fn new(opcode_cost: u32, depth: u8) -> Cost {
102+
if opcode_cost >= Self::MAX_OP_COST {
103+
Self::infinity()
104+
} else {
105+
Cost(opcode_cost << Self::DEPTH_BITS | u32::from(depth))
106+
}
97107
}
98108

99109
fn depth(&self) -> u8 {
@@ -111,7 +121,7 @@ impl Cost {
111121
/// that satisfies `inst_predicates::is_pure_for_egraph()`.
112122
pub(crate) fn of_pure_op(op: Opcode, operand_costs: impl IntoIterator<Item = Self>) -> Self {
113123
let c = pure_op_cost(op) + operand_costs.into_iter().sum();
114-
Cost::new_finite(c.op_cost(), c.depth().saturating_add(1))
124+
Cost::new(c.op_cost(), c.depth().saturating_add(1))
115125
}
116126
}
117127

@@ -131,12 +141,9 @@ impl std::ops::Add<Cost> for Cost {
131141
type Output = Cost;
132142

133143
fn add(self, other: Cost) -> Cost {
134-
let op_cost = std::cmp::min(
135-
self.op_cost().saturating_add(other.op_cost()),
136-
Self::MAX_OP_COST,
137-
);
144+
let op_cost = self.op_cost().saturating_add(other.op_cost());
138145
let depth = std::cmp::max(self.depth(), other.depth());
139-
Cost::new_finite(op_cost, depth)
146+
Cost::new(op_cost, depth)
140147
}
141148
}
142149

@@ -147,11 +154,11 @@ impl std::ops::Add<Cost> for Cost {
147154
fn pure_op_cost(op: Opcode) -> Cost {
148155
match op {
149156
// Constants.
150-
Opcode::Iconst | Opcode::F32const | Opcode::F64const => Cost::new_finite(1, 0),
157+
Opcode::Iconst | Opcode::F32const | Opcode::F64const => Cost::new(1, 0),
151158

152159
// Extends/reduces.
153160
Opcode::Uextend | Opcode::Sextend | Opcode::Ireduce | Opcode::Iconcat | Opcode::Isplit => {
154-
Cost::new_finite(2, 0)
161+
Cost::new(2, 0)
155162
}
156163

157164
// "Simple" arithmetic.
@@ -163,9 +170,45 @@ fn pure_op_cost(op: Opcode) -> Cost {
163170
| Opcode::Bnot
164171
| Opcode::Ishl
165172
| Opcode::Ushr
166-
| Opcode::Sshr => Cost::new_finite(3, 0),
173+
| Opcode::Sshr => Cost::new(3, 0),
167174

168175
// Everything else (pure.)
169-
_ => Cost::new_finite(4, 0),
176+
_ => Cost::new(4, 0),
177+
}
178+
}
179+
180+
#[cfg(test)]
181+
mod tests {
182+
use super::*;
183+
184+
#[test]
185+
fn add_cost() {
186+
let a = Cost::new(5, 2);
187+
let b = Cost::new(37, 3);
188+
assert_eq!(a + b, Cost::new(42, 3));
189+
}
190+
191+
#[test]
192+
fn add_infinity() {
193+
let a = Cost::new(5, 2);
194+
let b = Cost::infinity();
195+
assert_eq!(a + b, Cost::infinity());
196+
}
197+
198+
#[test]
199+
fn op_cost_saturates_to_infinity() {
200+
let a = Cost::new(Cost::MAX_OP_COST - 10, 2);
201+
let b = Cost::new(11, 2);
202+
assert_eq!(a + b, Cost::infinity());
203+
}
204+
205+
#[test]
206+
fn depth_saturates_to_max_depth() {
207+
let a = Cost::new(10, u8::MAX);
208+
let b = Cost::new(10, 1);
209+
assert_eq!(
210+
Cost::of_pure_op(Opcode::Iconst, [a, b]),
211+
Cost::new(21, u8::MAX)
212+
);
170213
}
171214
}

cranelift/codegen/src/egraph/elaborate.rs

Lines changed: 90 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use super::Stats;
77
use crate::dominator_tree::DominatorTree;
88
use crate::fx::{FxHashMap, FxHashSet};
99
use crate::hash_map::Entry as HashEntry;
10+
use crate::inst_predicates::is_pure_for_egraph;
1011
use crate::ir::{Block, Function, Inst, Value, ValueDef};
1112
use crate::loop_analysis::{Loop, LoopAnalysis};
1213
use crate::scoped_hash_map::ScopedHashMap;
@@ -216,46 +217,92 @@ impl<'a> Elaborator<'a> {
216217

217218
fn compute_best_values(&mut self) {
218219
let best = &mut self.value_to_best_value;
219-
for (value, def) in self.func.dfg.values_and_defs() {
220-
trace!("computing best for value {:?} def {:?}", value, def);
221-
match def {
222-
ValueDef::Union(x, y) => {
223-
// Pick the best of the two options based on
224-
// min-cost. This works because each element of `best`
225-
// is a `(cost, value)` tuple; `cost` comes first so
226-
// the natural comparison works based on cost, and
227-
// breaks ties based on value number.
228-
trace!(" -> best of {:?} and {:?}", best[x], best[y]);
229-
best[value] = std::cmp::min(best[x], best[y]);
230-
trace!(" -> {:?}", best[value]);
231-
}
232-
ValueDef::Param(_, _) => {
233-
best[value] = BestEntry(Cost::zero(), value);
220+
221+
// Do a fixpoint loop to compute the best value for each eclass.
222+
//
223+
// The maximum number of iterations is the length of the longest chain
224+
// of `vNN -> vMM` edges in the dataflow graph where `NN < MM`, so this
225+
// is *technically* quadratic, but `cranelift-frontend` won't construct
226+
// any such edges. NaN canonicalization will introduce some of these
227+
// edges, but they are chains of only two or three edges. So in
228+
// practice, we *never* do more than a handful of iterations here unless
229+
// (a) we parsed the CLIF from text and the text was funkily numbered,
230+
// which we don't really care about, or (b) the CLIF producer did
231+
// something weird, in which case it is their responsibility to stop
232+
// doing that.
233+
trace!("Entering fixpoint loop to compute the best values for each eclass");
234+
let mut keep_going = true;
235+
while keep_going {
236+
keep_going = false;
237+
trace!(
238+
"fixpoint iteration {}",
239+
self.stats.elaborate_best_cost_fixpoint_iters
240+
);
241+
self.stats.elaborate_best_cost_fixpoint_iters += 1;
242+
243+
for (value, def) in self.func.dfg.values_and_defs() {
244+
// If the cost of this value is finite, then we've already found
245+
// its final cost.
246+
if best[value].0.is_finite() {
247+
continue;
234248
}
235-
// If the Inst is inserted into the layout (which is,
236-
// at this point, only the side-effecting skeleton),
237-
// then it must be computed and thus we give it zero
238-
// cost.
239-
ValueDef::Result(inst, _) => {
240-
if let Some(_) = self.func.layout.inst_block(inst) {
241-
best[value] = BestEntry(Cost::zero(), value);
242-
} else {
243-
trace!(" -> value {}: result, computing cost", value);
244-
let inst_data = &self.func.dfg.insts[inst];
245-
// N.B.: at this point we know that the opcode is
246-
// pure, so `pure_op_cost`'s precondition is
247-
// satisfied.
248-
let cost = Cost::of_pure_op(
249-
inst_data.opcode(),
250-
self.func.dfg.inst_values(inst).map(|value| best[value].0),
249+
250+
trace!("computing best for value {:?} def {:?}", value, def);
251+
let orig_best_value = best[value];
252+
253+
match def {
254+
ValueDef::Union(x, y) => {
255+
// Pick the best of the two options based on
256+
// min-cost. This works because each element of `best`
257+
// is a `(cost, value)` tuple; `cost` comes first so
258+
// the natural comparison works based on cost, and
259+
// breaks ties based on value number.
260+
best[value] = std::cmp::min(best[x], best[y]);
261+
trace!(
262+
" -> best of union({:?}, {:?}) = {:?}",
263+
best[x],
264+
best[y],
265+
best[value]
251266
);
252-
best[value] = BestEntry(cost, value);
253267
}
254-
}
255-
};
256-
debug_assert_ne!(best[value].0, Cost::infinity());
257-
debug_assert_ne!(best[value].1, Value::reserved_value());
258-
trace!("best for eclass {:?}: {:?}", value, best[value]);
268+
ValueDef::Param(_, _) => {
269+
best[value] = BestEntry(Cost::zero(), value);
270+
}
271+
// If the Inst is inserted into the layout (which is,
272+
// at this point, only the side-effecting skeleton),
273+
// then it must be computed and thus we give it zero
274+
// cost.
275+
ValueDef::Result(inst, _) => {
276+
if let Some(_) = self.func.layout.inst_block(inst) {
277+
best[value] = BestEntry(Cost::zero(), value);
278+
} else {
279+
let inst_data = &self.func.dfg.insts[inst];
280+
// N.B.: at this point we know that the opcode is
281+
// pure, so `pure_op_cost`'s precondition is
282+
// satisfied.
283+
let cost = Cost::of_pure_op(
284+
inst_data.opcode(),
285+
self.func.dfg.inst_values(inst).map(|value| best[value].0),
286+
);
287+
best[value] = BestEntry(cost, value);
288+
trace!(" -> cost of value {} = {:?}", value, cost);
289+
}
290+
}
291+
};
292+
293+
// Keep on iterating the fixpoint loop while we are finding new
294+
// best values.
295+
keep_going |= orig_best_value != best[value];
296+
}
297+
}
298+
299+
if cfg!(any(feature = "trace-log", debug_assertions)) {
300+
trace!("finished fixpoint loop to compute best value for each eclass");
301+
for value in self.func.dfg.values() {
302+
debug_assert_ne!(best[value].0, Cost::infinity());
303+
debug_assert_ne!(best[value].1, Value::reserved_value());
304+
trace!("-> best for eclass {:?}: {:?}", value, best[value]);
305+
}
259306
}
260307
}
261308

@@ -606,7 +653,13 @@ impl<'a> Elaborator<'a> {
606653
}
607654
inst
608655
};
656+
609657
// Place the inst just before `before`.
658+
debug_assert!(
659+
is_pure_for_egraph(self.func, inst),
660+
"something has gone very wrong if we are elaborating effectful \
661+
instructions, they should have remained in the skeleton"
662+
);
610663
self.func.layout.insert_inst(inst, before);
611664

612665
// Update the inst's arguments.
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
test optimize
2+
set enable_verifier=true
3+
set opt_level=speed
4+
target x86_64
5+
6+
;; This test case should optimize just fine, and should definitely not produce
7+
;; CLIF that has verifier errors like
8+
;;
9+
;; error: inst10 (v12 = select.f32 v11, v4, v10 ; v11 = 1): uses value arg
10+
;; from non-dominating block4
11+
12+
function %foo() {
13+
block0:
14+
v0 = iconst.i64 0
15+
v2 = f32const 0.0
16+
v9 = f32const 0.0
17+
v20 = fneg v2
18+
v18 = fcmp eq v20, v20
19+
v4 = select v18, v2, v20
20+
v8 = iconst.i32 0
21+
v11 = iconst.i32 1
22+
brif v0, block2, block3
23+
24+
block2:
25+
brif.i32 v8, block4(v2), block4(v9)
26+
27+
block4(v10: f32):
28+
v12 = select.f32 v11, v4, v10
29+
v13 = bitcast.i32 v12
30+
store v13, v0
31+
trap user0
32+
33+
block3:
34+
v15 = bitcast.i32 v4
35+
store v15, v0
36+
return
37+
}

0 commit comments

Comments
 (0)