From 1834f5a272d567a714f78c7f48c0d3ae4a6238bb Mon Sep 17 00:00:00 2001 From: Camille GILLOT Date: Wed, 26 Jun 2024 19:13:38 +0000 Subject: [PATCH 1/2] Swap encapsulation of DCP state. --- .../rustc_mir_dataflow/src/value_analysis.rs | 149 ++++++++++-------- 1 file changed, 81 insertions(+), 68 deletions(-) diff --git a/compiler/rustc_mir_dataflow/src/value_analysis.rs b/compiler/rustc_mir_dataflow/src/value_analysis.rs index bfbfff7e25944..0364c23bfcbe4 100644 --- a/compiler/rustc_mir_dataflow/src/value_analysis.rs +++ b/compiler/rustc_mir_dataflow/src/value_analysis.rs @@ -39,7 +39,7 @@ use std::ops::Range; use rustc_data_structures::fx::FxHashMap; use rustc_data_structures::stack::ensure_sufficient_stack; use rustc_index::bit_set::BitSet; -use rustc_index::{IndexSlice, IndexVec}; +use rustc_index::IndexVec; use rustc_middle::bug; use rustc_middle::mir::visit::{MutatingUseContext, PlaceContext, Visitor}; use rustc_middle::mir::*; @@ -336,14 +336,14 @@ impl<'tcx, T: ValueAnalysis<'tcx>> AnalysisDomain<'tcx> for ValueAnalysisWrapper const NAME: &'static str = T::NAME; fn bottom_value(&self, _body: &Body<'tcx>) -> Self::Domain { - State(StateData::Unreachable) + State::Unreachable } fn initialize_start_block(&self, body: &Body<'tcx>, state: &mut Self::Domain) { // The initial state maps all tracked places of argument projections to ⊤ and the rest to ⊥. - assert!(matches!(state.0, StateData::Unreachable)); - let values = IndexVec::from_elem_n(T::Value::BOTTOM, self.0.map().value_count); - *state = State(StateData::Reachable(values)); + assert!(matches!(state, State::Unreachable)); + let values = StateData::from_elem_n(T::Value::BOTTOM, self.0.map().value_count); + *state = State::Reachable(values); for arg in body.args_iter() { state.flood(PlaceRef { local: arg, projection: &[] }, self.0.map()); } @@ -415,27 +415,30 @@ rustc_index::newtype_index!( /// See [`State`]. #[derive(PartialEq, Eq, Debug)] -enum StateData { - Reachable(IndexVec), - Unreachable, +struct StateData { + map: IndexVec, +} + +impl StateData { + fn from_elem_n(elem: V, n: usize) -> StateData { + StateData { map: IndexVec::from_elem_n(elem, n) } + } } impl Clone for StateData { fn clone(&self) -> Self { - match self { - Self::Reachable(x) => Self::Reachable(x.clone()), - Self::Unreachable => Self::Unreachable, - } + StateData { map: self.map.clone() } } fn clone_from(&mut self, source: &Self) { - match (&mut *self, source) { - (Self::Reachable(x), Self::Reachable(y)) => { - // We go through `raw` here, because `IndexVec` currently has a naive `clone_from`. - x.raw.clone_from(&y.raw); - } - _ => *self = source.clone(), - } + // We go through `raw` here, because `IndexVec` currently has a naive `clone_from`. + self.map.raw.clone_from(&source.map.raw) + } +} + +impl JoinSemiLattice for StateData { + fn join(&mut self, other: &Self) -> bool { + self.map.join(&other.map) } } @@ -450,33 +453,43 @@ impl Clone for StateData { /// /// Flooding means assigning a value (by default `⊤`) to all tracked projections of a given place. #[derive(PartialEq, Eq, Debug)] -pub struct State(StateData); +pub enum State { + Unreachable, + Reachable(StateData), +} impl Clone for State { fn clone(&self) -> Self { - Self(self.0.clone()) + match self { + Self::Reachable(x) => Self::Reachable(x.clone()), + Self::Unreachable => Self::Unreachable, + } } fn clone_from(&mut self, source: &Self) { - self.0.clone_from(&source.0); + match (&mut *self, source) { + (Self::Reachable(x), Self::Reachable(y)) => { + x.clone_from(&y); + } + _ => *self = source.clone(), + } } } impl State { pub fn new(init: V, map: &Map) -> State { - let values = IndexVec::from_elem_n(init, map.value_count); - State(StateData::Reachable(values)) + State::Reachable(StateData::from_elem_n(init, map.value_count)) } pub fn all(&self, f: impl Fn(&V) -> bool) -> bool { - match self.0 { - StateData::Unreachable => true, - StateData::Reachable(ref values) => values.iter().all(f), + match self { + State::Unreachable => true, + State::Reachable(ref values) => values.map.iter().all(f), } } fn is_reachable(&self) -> bool { - matches!(&self.0, StateData::Reachable(_)) + matches!(self, State::Reachable(_)) } /// Assign `value` to all places that are contained in `place` or may alias one. @@ -519,9 +532,9 @@ impl State { map: &Map, value: V, ) { - let StateData::Reachable(values) = &mut self.0 else { return }; + let State::Reachable(values) = self else { return }; map.for_each_aliasing_place(place, tail_elem, &mut |vi| { - values[vi] = value.clone(); + values.map[vi] = value.clone(); }); } @@ -541,9 +554,9 @@ impl State { /// /// The target place must have been flooded before calling this method. pub fn insert_value_idx(&mut self, target: PlaceIndex, value: V, map: &Map) { - let StateData::Reachable(values) = &mut self.0 else { return }; + let State::Reachable(values) = self else { return }; if let Some(value_index) = map.places[target].value_index { - values[value_index] = value; + values.map[value_index] = value; } } @@ -555,14 +568,14 @@ impl State { /// /// The target place must have been flooded before calling this method. pub fn insert_place_idx(&mut self, target: PlaceIndex, source: PlaceIndex, map: &Map) { - let StateData::Reachable(values) = &mut self.0 else { return }; + let State::Reachable(values) = self else { return }; // If both places are tracked, we copy the value to the target. // If the target is tracked, but the source is not, we do nothing, as invalidation has // already been performed. if let Some(target_value) = map.places[target].value_index { if let Some(source_value) = map.places[source].value_index { - values[target_value] = values[source_value].clone(); + values.map[target_value] = values.map[source_value].clone(); } } for target_child in map.children(target) { @@ -616,11 +629,11 @@ impl State { /// Retrieve the value stored for a place index, or `None` if it is not tracked. pub fn try_get_idx(&self, place: PlaceIndex, map: &Map) -> Option { - match &self.0 { - StateData::Reachable(values) => { - map.places[place].value_index.map(|v| values[v].clone()) + match self { + State::Reachable(values) => { + map.places[place].value_index.map(|v| values.map[v].clone()) } - StateData::Unreachable => None, + State::Unreachable => None, } } @@ -631,10 +644,10 @@ impl State { where V: HasBottom + HasTop, { - match &self.0 { - StateData::Reachable(_) => self.try_get(place, map).unwrap_or(V::TOP), + match self { + State::Reachable(_) => self.try_get(place, map).unwrap_or(V::TOP), // Because this is unreachable, we can return any value we want. - StateData::Unreachable => V::BOTTOM, + State::Unreachable => V::BOTTOM, } } @@ -645,10 +658,10 @@ impl State { where V: HasBottom + HasTop, { - match &self.0 { - StateData::Reachable(_) => self.try_get_discr(place, map).unwrap_or(V::TOP), + match self { + State::Reachable(_) => self.try_get_discr(place, map).unwrap_or(V::TOP), // Because this is unreachable, we can return any value we want. - StateData::Unreachable => V::BOTTOM, + State::Unreachable => V::BOTTOM, } } @@ -659,10 +672,10 @@ impl State { where V: HasBottom + HasTop, { - match &self.0 { - StateData::Reachable(_) => self.try_get_len(place, map).unwrap_or(V::TOP), + match self { + State::Reachable(_) => self.try_get_len(place, map).unwrap_or(V::TOP), // Because this is unreachable, we can return any value we want. - StateData::Unreachable => V::BOTTOM, + State::Unreachable => V::BOTTOM, } } @@ -673,11 +686,11 @@ impl State { where V: HasBottom + HasTop, { - match &self.0 { - StateData::Reachable(values) => { - map.places[place].value_index.map(|v| values[v].clone()).unwrap_or(V::TOP) + match self { + State::Reachable(values) => { + map.places[place].value_index.map(|v| values.map[v].clone()).unwrap_or(V::TOP) } - StateData::Unreachable => { + State::Unreachable => { // Because this is unreachable, we can return any value we want. V::BOTTOM } @@ -687,13 +700,13 @@ impl State { impl JoinSemiLattice for State { fn join(&mut self, other: &Self) -> bool { - match (&mut self.0, &other.0) { - (_, StateData::Unreachable) => false, - (StateData::Unreachable, _) => { + match (&mut *self, other) { + (_, State::Unreachable) => false, + (State::Unreachable, _) => { *self = other.clone(); true } - (StateData::Reachable(this), StateData::Reachable(other)) => this.join(other), + (State::Reachable(this), State::Reachable(ref other)) => this.join(other), } } } @@ -1194,9 +1207,9 @@ where T::Value: Debug, { fn fmt_with(&self, ctxt: &ValueAnalysisWrapper, f: &mut Formatter<'_>) -> std::fmt::Result { - match &self.0 { - StateData::Reachable(values) => debug_with_context(values, None, ctxt.0.map(), f), - StateData::Unreachable => write!(f, "unreachable"), + match self { + State::Reachable(values) => debug_with_context(values, None, ctxt.0.map(), f), + State::Unreachable => write!(f, "unreachable"), } } @@ -1206,8 +1219,8 @@ where ctxt: &ValueAnalysisWrapper, f: &mut Formatter<'_>, ) -> std::fmt::Result { - match (&self.0, &old.0) { - (StateData::Reachable(this), StateData::Reachable(old)) => { + match (self, old) { + (State::Reachable(this), State::Reachable(old)) => { debug_with_context(this, Some(old), ctxt.0.map(), f) } _ => Ok(()), // Consider printing something here. @@ -1218,18 +1231,18 @@ where fn debug_with_context_rec( place: PlaceIndex, place_str: &str, - new: &IndexSlice, - old: Option<&IndexSlice>, + new: &StateData, + old: Option<&StateData>, map: &Map, f: &mut Formatter<'_>, ) -> std::fmt::Result { if let Some(value) = map.places[place].value_index { match old { - None => writeln!(f, "{}: {:?}", place_str, new[value])?, + None => writeln!(f, "{}: {:?}", place_str, new.map[value])?, Some(old) => { - if new[value] != old[value] { - writeln!(f, "\u{001f}-{}: {:?}", place_str, old[value])?; - writeln!(f, "\u{001f}+{}: {:?}", place_str, new[value])?; + if new.map[value] != old.map[value] { + writeln!(f, "\u{001f}-{}: {:?}", place_str, old.map[value])?; + writeln!(f, "\u{001f}+{}: {:?}", place_str, new.map[value])?; } } } @@ -1262,8 +1275,8 @@ fn debug_with_context_rec( } fn debug_with_context( - new: &IndexSlice, - old: Option<&IndexSlice>, + new: &StateData, + old: Option<&StateData>, map: &Map, f: &mut Formatter<'_>, ) -> std::fmt::Result { From 76244d4dbc768e15e429c1f66ec021884f369f5f Mon Sep 17 00:00:00 2001 From: Camille GILLOT Date: Wed, 26 Jun 2024 21:57:59 +0000 Subject: [PATCH 2/2] Make jump threading state sparse. --- .../src/framework/lattice.rs | 14 +++ .../rustc_mir_dataflow/src/value_analysis.rs | 91 ++++++++++++------- .../rustc_mir_transform/src/jump_threading.rs | 19 +++- 3 files changed, 86 insertions(+), 38 deletions(-) diff --git a/compiler/rustc_mir_dataflow/src/framework/lattice.rs b/compiler/rustc_mir_dataflow/src/framework/lattice.rs index 1c2b475a43c96..23738f7a4a538 100644 --- a/compiler/rustc_mir_dataflow/src/framework/lattice.rs +++ b/compiler/rustc_mir_dataflow/src/framework/lattice.rs @@ -76,6 +76,8 @@ pub trait MeetSemiLattice: Eq { /// A set that has a "bottom" element, which is less than or equal to any other element. pub trait HasBottom { const BOTTOM: Self; + + fn is_bottom(&self) -> bool; } /// A set that has a "top" element, which is greater than or equal to any other element. @@ -114,6 +116,10 @@ impl MeetSemiLattice for bool { impl HasBottom for bool { const BOTTOM: Self = false; + + fn is_bottom(&self) -> bool { + !self + } } impl HasTop for bool { @@ -267,6 +273,10 @@ impl MeetSemiLattice for FlatSet { impl HasBottom for FlatSet { const BOTTOM: Self = Self::Bottom; + + fn is_bottom(&self) -> bool { + matches!(self, Self::Bottom) + } } impl HasTop for FlatSet { @@ -291,6 +301,10 @@ impl MaybeReachable { impl HasBottom for MaybeReachable { const BOTTOM: Self = MaybeReachable::Unreachable; + + fn is_bottom(&self) -> bool { + matches!(self, Self::Unreachable) + } } impl HasTop for MaybeReachable { diff --git a/compiler/rustc_mir_dataflow/src/value_analysis.rs b/compiler/rustc_mir_dataflow/src/value_analysis.rs index 0364c23bfcbe4..7c1ff6fda53f7 100644 --- a/compiler/rustc_mir_dataflow/src/value_analysis.rs +++ b/compiler/rustc_mir_dataflow/src/value_analysis.rs @@ -36,7 +36,7 @@ use std::collections::VecDeque; use std::fmt::{Debug, Formatter}; use std::ops::Range; -use rustc_data_structures::fx::FxHashMap; +use rustc_data_structures::fx::{FxHashMap, StdEntry}; use rustc_data_structures::stack::ensure_sufficient_stack; use rustc_index::bit_set::BitSet; use rustc_index::IndexVec; @@ -342,8 +342,7 @@ impl<'tcx, T: ValueAnalysis<'tcx>> AnalysisDomain<'tcx> for ValueAnalysisWrapper fn initialize_start_block(&self, body: &Body<'tcx>, state: &mut Self::Domain) { // The initial state maps all tracked places of argument projections to ⊤ and the rest to ⊥. assert!(matches!(state, State::Unreachable)); - let values = StateData::from_elem_n(T::Value::BOTTOM, self.0.map().value_count); - *state = State::Reachable(values); + *state = State::new_reachable(); for arg in body.args_iter() { state.flood(PlaceRef { local: arg, projection: &[] }, self.0.map()); } @@ -415,30 +414,54 @@ rustc_index::newtype_index!( /// See [`State`]. #[derive(PartialEq, Eq, Debug)] -struct StateData { - map: IndexVec, +pub struct StateData { + bottom: V, + /// This map only contains values that are not `⊥`. + map: FxHashMap, } -impl StateData { - fn from_elem_n(elem: V, n: usize) -> StateData { - StateData { map: IndexVec::from_elem_n(elem, n) } +impl StateData { + fn new() -> StateData { + StateData { bottom: V::BOTTOM, map: FxHashMap::default() } + } + + fn get(&self, idx: ValueIndex) -> &V { + self.map.get(&idx).unwrap_or(&self.bottom) + } + + fn insert(&mut self, idx: ValueIndex, elem: V) { + if elem.is_bottom() { + self.map.remove(&idx); + } else { + self.map.insert(idx, elem); + } } } impl Clone for StateData { fn clone(&self) -> Self { - StateData { map: self.map.clone() } + StateData { bottom: self.bottom.clone(), map: self.map.clone() } } fn clone_from(&mut self, source: &Self) { - // We go through `raw` here, because `IndexVec` currently has a naive `clone_from`. - self.map.raw.clone_from(&source.map.raw) + self.map.clone_from(&source.map) } } -impl JoinSemiLattice for StateData { +impl JoinSemiLattice for StateData { fn join(&mut self, other: &Self) -> bool { - self.map.join(&other.map) + let mut changed = false; + #[allow(rustc::potential_query_instability)] + for (i, v) in other.map.iter() { + match self.map.entry(*i) { + StdEntry::Vacant(e) => { + e.insert(v.clone()); + changed = true + } + StdEntry::Occupied(e) => changed |= e.into_mut().join(v), + } + } + changed } } @@ -476,15 +499,19 @@ impl Clone for State { } } -impl State { - pub fn new(init: V, map: &Map) -> State { - State::Reachable(StateData::from_elem_n(init, map.value_count)) +impl State { + pub fn new_reachable() -> State { + State::Reachable(StateData::new()) } - pub fn all(&self, f: impl Fn(&V) -> bool) -> bool { + pub fn all_bottom(&self) -> bool { match self { - State::Unreachable => true, - State::Reachable(ref values) => values.map.iter().all(f), + State::Unreachable => false, + State::Reachable(ref values) => + { + #[allow(rustc::potential_query_instability)] + values.map.values().all(V::is_bottom) + } } } @@ -533,9 +560,7 @@ impl State { value: V, ) { let State::Reachable(values) = self else { return }; - map.for_each_aliasing_place(place, tail_elem, &mut |vi| { - values.map[vi] = value.clone(); - }); + map.for_each_aliasing_place(place, tail_elem, &mut |vi| values.insert(vi, value.clone())); } /// Low-level method that assigns to a place. @@ -556,7 +581,7 @@ impl State { pub fn insert_value_idx(&mut self, target: PlaceIndex, value: V, map: &Map) { let State::Reachable(values) = self else { return }; if let Some(value_index) = map.places[target].value_index { - values.map[value_index] = value; + values.insert(value_index, value) } } @@ -575,7 +600,7 @@ impl State { // already been performed. if let Some(target_value) = map.places[target].value_index { if let Some(source_value) = map.places[source].value_index { - values.map[target_value] = values.map[source_value].clone(); + values.insert(target_value, values.get(source_value).clone()); } } for target_child in map.children(target) { @@ -631,7 +656,7 @@ impl State { pub fn try_get_idx(&self, place: PlaceIndex, map: &Map) -> Option { match self { State::Reachable(values) => { - map.places[place].value_index.map(|v| values.map[v].clone()) + map.places[place].value_index.map(|v| values.get(v).clone()) } State::Unreachable => None, } @@ -688,7 +713,7 @@ impl State { { match self { State::Reachable(values) => { - map.places[place].value_index.map(|v| values.map[v].clone()).unwrap_or(V::TOP) + map.places[place].value_index.map(|v| values.get(v).clone()).unwrap_or(V::TOP) } State::Unreachable => { // Because this is unreachable, we can return any value we want. @@ -698,7 +723,7 @@ impl State { } } -impl JoinSemiLattice for State { +impl JoinSemiLattice for State { fn join(&mut self, other: &Self) -> bool { match (&mut *self, other) { (_, State::Unreachable) => false, @@ -1228,7 +1253,7 @@ where } } -fn debug_with_context_rec( +fn debug_with_context_rec( place: PlaceIndex, place_str: &str, new: &StateData, @@ -1238,11 +1263,11 @@ fn debug_with_context_rec( ) -> std::fmt::Result { if let Some(value) = map.places[place].value_index { match old { - None => writeln!(f, "{}: {:?}", place_str, new.map[value])?, + None => writeln!(f, "{}: {:?}", place_str, new.get(value))?, Some(old) => { - if new.map[value] != old.map[value] { - writeln!(f, "\u{001f}-{}: {:?}", place_str, old.map[value])?; - writeln!(f, "\u{001f}+{}: {:?}", place_str, new.map[value])?; + if new.get(value) != old.get(value) { + writeln!(f, "\u{001f}-{}: {:?}", place_str, old.get(value))?; + writeln!(f, "\u{001f}+{}: {:?}", place_str, new.get(value))?; } } } @@ -1274,7 +1299,7 @@ fn debug_with_context_rec( Ok(()) } -fn debug_with_context( +fn debug_with_context( new: &StateData, old: Option<&StateData>, map: &Map, diff --git a/compiler/rustc_mir_transform/src/jump_threading.rs b/compiler/rustc_mir_transform/src/jump_threading.rs index 27e506a920bce..85a61f95ca36f 100644 --- a/compiler/rustc_mir_transform/src/jump_threading.rs +++ b/compiler/rustc_mir_transform/src/jump_threading.rs @@ -47,6 +47,7 @@ use rustc_middle::mir::visit::Visitor; use rustc_middle::mir::*; use rustc_middle::ty::layout::LayoutOf; use rustc_middle::ty::{self, ScalarInt, TyCtxt}; +use rustc_mir_dataflow::lattice::HasBottom; use rustc_mir_dataflow::value_analysis::{Map, PlaceIndex, State, TrackElem}; use rustc_span::DUMMY_SP; use rustc_target::abi::{TagEncoding, Variants}; @@ -158,9 +159,17 @@ impl Condition { } } -#[derive(Copy, Clone, Debug, Default)] +#[derive(Copy, Clone, Debug)] struct ConditionSet<'a>(&'a [Condition]); +impl HasBottom for ConditionSet<'_> { + const BOTTOM: Self = ConditionSet(&[]); + + fn is_bottom(&self) -> bool { + self.0.is_empty() + } +} + impl<'a> ConditionSet<'a> { fn iter(self) -> impl Iterator + 'a { self.0.iter().copied() @@ -177,7 +186,7 @@ impl<'a> ConditionSet<'a> { impl<'tcx, 'a> TOFinder<'tcx, 'a> { fn is_empty(&self, state: &State>) -> bool { - state.all(|cs| cs.0.is_empty()) + state.all_bottom() } /// Recursion entry point to find threading opportunities. @@ -198,7 +207,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { debug!(?discr); let cost = CostChecker::new(self.tcx, self.param_env, None, self.body); - let mut state = State::new(ConditionSet::default(), self.map); + let mut state = State::new_reachable(); let conds = if let Some((value, then, else_)) = targets.as_static_if() { let value = ScalarInt::try_from_uint(value, discr_layout.size)?; @@ -255,7 +264,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { // _1 = 5 // Whatever happens here, it won't change the result of a `SwitchInt`. // _1 = 6 if let Some((lhs, tail)) = self.mutated_statement(stmt) { - state.flood_with_tail_elem(lhs.as_ref(), tail, self.map, ConditionSet::default()); + state.flood_with_tail_elem(lhs.as_ref(), tail, self.map, ConditionSet::BOTTOM); } } @@ -609,7 +618,7 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> { // We can recurse through this terminator. let mut state = state(); if let Some(place_to_flood) = place_to_flood { - state.flood_with(place_to_flood.as_ref(), self.map, ConditionSet::default()); + state.flood_with(place_to_flood.as_ref(), self.map, ConditionSet::BOTTOM); } self.find_opportunity(bb, state, cost.clone(), depth + 1); }