Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

coverage: Use a separate counter type and simplification step during counter creation #133849

Merged
merged 6 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 146 additions & 34 deletions compiler/rustc_mir_transform/src/coverage/counters.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::cmp::Ordering;
use std::fmt::{self, Debug};

use rustc_data_structures::captures::Captures;
Expand All @@ -10,9 +11,12 @@ use tracing::{debug, debug_span, instrument};

use crate::coverage::graph::{BasicCoverageBlock, CoverageGraph, TraverseCoverageGraphWithLoops};

#[cfg(test)]
mod tests;

/// The coverage counter or counter expression associated with a particular
/// BCB node or BCB edge.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum BcbCounter {
Counter { id: CounterId },
Expression { id: ExpressionId },
Expand Down Expand Up @@ -43,8 +47,9 @@ struct BcbExpression {
rhs: BcbCounter,
}

#[derive(Debug)]
pub(super) enum CounterIncrementSite {
/// Enum representing either a node or an edge in the coverage graph.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(super) enum Site {
Node { bcb: BasicCoverageBlock },
Edge { from_bcb: BasicCoverageBlock, to_bcb: BasicCoverageBlock },
}
Expand All @@ -54,7 +59,7 @@ pub(super) enum CounterIncrementSite {
pub(super) struct CoverageCounters {
/// List of places where a counter-increment statement should be injected
/// into MIR, each with its corresponding counter ID.
counter_increment_sites: IndexVec<CounterId, CounterIncrementSite>,
counter_increment_sites: IndexVec<CounterId, Site>,

/// Coverage counters/expressions that are associated with individual BCBs.
node_counters: IndexVec<BasicCoverageBlock, Option<BcbCounter>>,
Expand Down Expand Up @@ -83,7 +88,7 @@ impl CoverageCounters {
let mut builder = CountersBuilder::new(graph, bcb_needs_counter);
builder.make_bcb_counters();

builder.counters
builder.into_coverage_counters()
}

fn with_num_bcbs(num_bcbs: usize) -> Self {
Expand All @@ -96,27 +101,12 @@ impl CoverageCounters {
}
}

/// Shared helper used by [`Self::make_phys_node_counter`] and
/// [`Self::make_phys_edge_counter`]. Don't call this directly.
fn make_counter_inner(&mut self, site: CounterIncrementSite) -> BcbCounter {
/// Creates a new physical counter for a BCB node or edge.
fn make_phys_counter(&mut self, site: Site) -> BcbCounter {
let id = self.counter_increment_sites.push(site);
BcbCounter::Counter { id }
}

/// Creates a new physical counter for a BCB node.
fn make_phys_node_counter(&mut self, bcb: BasicCoverageBlock) -> BcbCounter {
self.make_counter_inner(CounterIncrementSite::Node { bcb })
}

/// Creates a new physical counter for a BCB edge.
fn make_phys_edge_counter(
&mut self,
from_bcb: BasicCoverageBlock,
to_bcb: BasicCoverageBlock,
) -> BcbCounter {
self.make_counter_inner(CounterIncrementSite::Edge { from_bcb, to_bcb })
}

fn make_expression(&mut self, lhs: BcbCounter, op: Op, rhs: BcbCounter) -> BcbCounter {
let new_expr = BcbExpression { lhs, op, rhs };
*self
Expand Down Expand Up @@ -182,6 +172,12 @@ impl CoverageCounters {
.reduce(|accum, counter| self.make_expression(accum, Op::Add, counter))
}

/// Builds a counter equal to `lhs` minus the sum of every counter in `rhs`.
///
/// When `rhs` yields no sum (i.e. `make_sum` returns `None`), nothing needs
/// to be subtracted, so `lhs` is returned as-is and no expression is created.
fn make_subtracted_sum(&mut self, lhs: BcbCounter, rhs: &[BcbCounter]) -> BcbCounter {
    match self.make_sum(rhs) {
        Some(rhs_sum) => self.make_expression(lhs, Op::Subtract, rhs_sum),
        None => lhs,
    }
}

pub(super) fn num_counters(&self) -> usize {
self.counter_increment_sites.len()
}
Expand Down Expand Up @@ -218,8 +214,8 @@ impl CoverageCounters {
/// each site's corresponding counter ID.
pub(super) fn counter_increment_sites(
&self,
) -> impl Iterator<Item = (CounterId, &CounterIncrementSite)> {
self.counter_increment_sites.iter_enumerated()
) -> impl Iterator<Item = (CounterId, Site)> + Captures<'_> {
self.counter_increment_sites.iter_enumerated().map(|(id, &site)| (id, site))
}

/// Returns an iterator over the subset of BCB nodes that have been associated
Expand Down Expand Up @@ -338,24 +334,18 @@ impl<'a> CountersBuilder<'a> {
};

// For each out-edge other than the one that was chosen to get an expression,
// ensure that it has a counter (existing counter/expression or a new counter),
// and accumulate the corresponding counters into a single sum expression.
// ensure that it has a counter (existing counter/expression or a new counter).
let other_out_edge_counters = successors
.iter()
.copied()
// Skip the chosen edge, since we'll calculate its count from this sum.
.filter(|&edge_target_bcb| edge_target_bcb != target_bcb)
.map(|to_bcb| self.get_or_make_edge_counter(from_bcb, to_bcb))
.collect::<Vec<_>>();
let Some(sum_of_all_other_out_edges) = self.counters.make_sum(&other_out_edge_counters)
else {
return;
};

// Now create an expression for the chosen edge, by taking the counter
// for its source node and subtracting the sum of its sibling out-edges.
let expression =
self.counters.make_expression(node_counter, Op::Subtract, sum_of_all_other_out_edges);
let expression = self.counters.make_subtracted_sum(node_counter, &other_out_edge_counters);

debug!("{target_bcb:?} gets an expression: {expression:?}");
self.counters.set_edge_counter(from_bcb, target_bcb, expression);
Expand Down Expand Up @@ -390,7 +380,7 @@ impl<'a> CountersBuilder<'a> {
// leading to infinite recursion.
if predecessors.len() <= 1 || predecessors.contains(&bcb) {
debug!(?bcb, ?predecessors, "node has <=1 predecessors or is its own predecessor");
let counter = self.counters.make_phys_node_counter(bcb);
let counter = self.counters.make_phys_counter(Site::Node { bcb });
debug!(?bcb, ?counter, "node gets a physical counter");
return counter;
}
Expand Down Expand Up @@ -447,7 +437,7 @@ impl<'a> CountersBuilder<'a> {
}

// Make a new counter to count this edge.
let counter = self.counters.make_phys_edge_counter(from_bcb, to_bcb);
let counter = self.counters.make_phys_counter(Site::Edge { from_bcb, to_bcb });
debug!(?from_bcb, ?to_bcb, ?counter, "edge gets a physical counter");
counter
}
Expand Down Expand Up @@ -510,4 +500,126 @@ impl<'a> CountersBuilder<'a> {

None
}

/// Consumes the builder and produces the final `CoverageCounters` by running
/// its contents through a `Transcriber`.
fn into_coverage_counters(self) -> CoverageCounters {
    let transcriber = Transcriber::new(&self);
    transcriber.transcribe_counters()
}
}

/// Helper struct for converting `CountersBuilder` into a final `CoverageCounters`.
///
/// It reads the counters recorded in `old` and rebuilds them inside `new`,
/// which is what `transcribe_counters` ultimately returns.
struct Transcriber<'a> {
/// The builder whose counters are being transcribed (held by shared reference).
old: &'a CountersBuilder<'a>,
/// The fresh `CoverageCounters` being populated.
new: CoverageCounters,
/// Physical counters already created in `new`, keyed by site, so that
/// `ensure_phys_counter` creates at most one physical counter per site.
phys_counter_for_site: FxHashMap<Site, BcbCounter>,
}

impl<'a> Transcriber<'a> {
/// Creates a transcriber that reads from `old` and starts with an empty
/// `CoverageCounters` sized for the same number of graph nodes.
fn new(old: &'a CountersBuilder<'a>) -> Self {
    let num_bcbs = old.graph.num_nodes();
    let new = CoverageCounters::with_num_bcbs(num_bcbs);
    let phys_counter_for_site = FxHashMap::default();
    Self { old, new, phys_counter_for_site }
}

fn transcribe_counters(mut self) -> CoverageCounters {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this is basically a normalization if there are no old counters and a system for keeping the diff small if there are old counters?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is kind of tricky to explain, but it's also the whole crux of this PR, so I should try.

We can think of this code as going through a series of refactoring stages:

  • Graph traversal → CoverageCounters (status quo)
  • Graph traversal → CoverageCounters → Transcriber → simplified CoverageCounters
  • Graph traversal → FxHashMap<Site, SiteCounter> → Transcriber → simplified CoverageCounters

The main goal of introducing Transcriber as a middle layer is so that the part before Transcriber can be changed to not be tied to CoverageCounters. To make that feasible, we need to go through the intermediate step of having two different CoverageCounters (old and new), so that we can then replace the first one with something else.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fact that the final CoverageCounters is simpler than the original one starts off as being a bonus extra, but it also lets the earlier steps not care so much about producing “optimal” results in a single pass. I expect that to be a big help in future changes to how counter creation works.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense, thanks!

for bcb in self.old.bcb_needs_counter.iter() {
let site = Site::Node { bcb };
let Some(old_counter) = self.old.counters.node_counters[bcb] else { continue };

// Resolve the old counter into flat lists of nodes/edges whose
// physical counts contribute to the counter for this node.
// Distinguish between counts that will be added vs subtracted.
let mut pos = vec![];
let mut neg = vec![];
self.push_resolved_sites(old_counter, &mut pos, &mut neg);

// Simplify by cancelling out sites that appear on both sides.
let (mut pos, mut neg) = sort_and_cancel(pos, neg);

if pos.is_empty() {
// If we somehow end up with no positive terms after cancellation,
// fall back to creating a physical counter. There's no known way
// for this to happen, but it's hard to confidently rule it out.
debug_assert!(false, "{site:?} has no positive counter terms");
pos = vec![Some(site)];
neg = vec![];
}

let mut new_counters_for_sites = |sites: Vec<Option<Site>>| {
sites
.into_iter()
.filter_map(|id| try { self.ensure_phys_counter(id?) })
.collect::<Vec<_>>()
};
let mut pos = new_counters_for_sites(pos);
let mut neg = new_counters_for_sites(neg);

pos.sort();
neg.sort();

let pos_counter = self.new.make_sum(&pos).expect("`pos` should not be empty");
let new_counter = self.new.make_subtracted_sum(pos_counter, &neg);
self.new.set_node_counter(bcb, new_counter);
}

self.new
}

fn ensure_phys_counter(&mut self, site: Site) -> BcbCounter {
*self.phys_counter_for_site.entry(site).or_insert_with(|| self.new.make_phys_counter(site))
}

/// Flattens `counter` (as recorded in the old counters) into the sites it is
/// built from: sites that contribute positively are appended to `pos`, and
/// sites that are subtracted are appended to `neg`.
fn push_resolved_sites(&self, counter: BcbCounter, pos: &mut Vec<Site>, neg: &mut Vec<Site>) {
    let old_counters = &self.old.counters;

    // A physical counter resolves directly to its own site.
    let expression_id = match counter {
        BcbCounter::Counter { id } => {
            pos.push(old_counters.counter_increment_sites[id]);
            return;
        }
        BcbCounter::Expression { id } => id,
    };

    let BcbExpression { lhs, op, rhs } = old_counters.expressions[expression_id];

    // The left operand always keeps the current sign.
    self.push_resolved_sites(lhs, pos, neg);

    // The right operand keeps the sign for addition; for subtraction the
    // two lists trade places so that its sites end up subtracted.
    let (rhs_pos, rhs_neg) = match op {
        Op::Add => (pos, neg),
        Op::Subtract => (neg, pos),
    };
    self.push_resolved_sites(rhs, rhs_pos, rhs_neg);
}
}

/// Sorts both lists, wraps every element in `Some`, and then cancels values
/// that occur in both lists.
///
/// Cancellation is pairwise: each occurrence in `pos` can cancel at most one
/// equal occurrence in `neg` (and vice versa). Cancelled slots are replaced
/// with `None`, so the returned vectors have the same lengths as the inputs.
fn sort_and_cancel<T: Ord>(mut pos: Vec<T>, mut neg: Vec<T>) -> (Vec<Option<T>>, Vec<Option<T>>) {
    pos.sort();
    neg.sort();

    // Wrap each value so that individual slots can be blanked out in place.
    let mut pos: Vec<Option<T>> = pos.into_iter().map(Some).collect();
    let mut neg: Vec<Option<T>> = neg.into_iter().map(Some).collect();

    // Walk both sorted lists in lockstep. Once either cursor runs off the end
    // of its list, no further matches are possible.
    let mut pos_idx = 0;
    let mut neg_idx = 0;
    while pos_idx < pos.len() && neg_idx < neg.len() {
        let ordering = pos[pos_idx].cmp(&neg[neg_idx]);

        // Equal values cancel each other out.
        if ordering == Ordering::Equal {
            pos[pos_idx] = None;
            neg[neg_idx] = None;
        }

        // Step past the lesser value, or past both when they were equal.
        // (This is only valid because both lists are sorted.)
        if ordering != Ordering::Greater {
            pos_idx += 1;
        }
        if ordering != Ordering::Less {
            neg_idx += 1;
        }
    }

    (pos, neg)
}
41 changes: 41 additions & 0 deletions compiler/rustc_mir_transform/src/coverage/counters/tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use std::fmt::Debug;

use super::sort_and_cancel;

/// Discards every `None` slot and returns the remaining values in their
/// original order.
fn flatten<T>(input: Vec<Option<T>>) -> Vec<T> {
    let mut kept = Vec::new();
    for slot in input {
        if let Some(value) = slot {
            kept.push(value);
        }
    }
    kept
}

/// Runs `sort_and_cancel` on the two lists and then strips the cancelled
/// (`None`) slots from each result, leaving only the surviving values.
fn sort_and_cancel_and_flatten<T: Clone + Ord>(pos: Vec<T>, neg: Vec<T>) -> (Vec<T>, Vec<T>) {
    let (pos_cancelled, neg_cancelled) = sort_and_cancel(pos, neg);
    let pos_remaining = flatten(pos_cancelled);
    let neg_remaining = flatten(neg_cancelled);
    (pos_remaining, neg_remaining)
}

/// Logs the inputs, runs them through `sort_and_cancel_and_flatten`, and
/// asserts that the surviving values match the expected lists.
#[track_caller]
fn check_test_case<T: Clone + Debug + Ord>(
    pos: Vec<T>,
    neg: Vec<T>,
    pos_expected: Vec<T>,
    neg_expected: Vec<T>,
) {
    // Print the inputs first so a failing case can be identified from the log.
    eprintln!("pos = {pos:?}; neg = {neg:?}");
    let (pos_actual, neg_actual) = sort_and_cancel_and_flatten(pos, neg);
    assert_eq!((pos_actual, neg_actual), (pos_expected, neg_expected));
}

/// Table-driven check of `sort_and_cancel`: each case is run as written and
/// again with `pos`/`neg` (and their expectations) swapped.
#[test]
fn cancellation() {
    // (pos, neg, pos_expected, neg_expected)
    type Case = (Vec<u32>, Vec<u32>, Vec<u32>, Vec<u32>);

    let cases: [Case; 5] = [
        (vec![], vec![], vec![], vec![]),
        (vec![4, 2, 1, 5, 3], vec![], vec![1, 2, 3, 4, 5], vec![]),
        (vec![5, 5, 5, 5, 5], vec![5], vec![5, 5, 5, 5], vec![]),
        (vec![1, 1, 2, 2, 3, 3], vec![1, 2, 3], vec![1, 2, 3], vec![]),
        (vec![1, 1, 2, 2, 3, 3], vec![2, 4, 2], vec![1, 1, 3, 3], vec![4]),
    ];

    for (pos, neg, pos_expected, neg_expected) in cases {
        check_test_case(pos.clone(), neg.clone(), pos_expected.clone(), neg_expected.clone());
        // Same test case, but with its inputs flipped and its outputs flipped.
        check_test_case(neg, pos, neg_expected, pos_expected);
    }
}
10 changes: 5 additions & 5 deletions compiler/rustc_mir_transform/src/coverage/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use rustc_span::source_map::SourceMap;
use rustc_span::{BytePos, Pos, SourceFile, Span};
use tracing::{debug, debug_span, trace};

use crate::coverage::counters::{CounterIncrementSite, CoverageCounters};
use crate::coverage::counters::{CoverageCounters, Site};
use crate::coverage::graph::CoverageGraph;
use crate::coverage::mappings::ExtractedMappings;

Expand Down Expand Up @@ -265,13 +265,13 @@ fn inject_coverage_statements<'tcx>(
coverage_counters: &CoverageCounters,
) {
// Inject counter-increment statements into MIR.
for (id, counter_increment_site) in coverage_counters.counter_increment_sites() {
for (id, site) in coverage_counters.counter_increment_sites() {
// Determine the block to inject a counter-increment statement into.
// For BCB nodes this is just their first block, but for edges we need
// to create a new block between the two BCBs, and inject into that.
let target_bb = match *counter_increment_site {
CounterIncrementSite::Node { bcb } => basic_coverage_blocks[bcb].leader_bb(),
CounterIncrementSite::Edge { from_bcb, to_bcb } => {
let target_bb = match site {
Site::Node { bcb } => basic_coverage_blocks[bcb].leader_bb(),
Site::Edge { from_bcb, to_bcb } => {
// Create a new block between the last block of `from_bcb` and
// the first block of `to_bcb`.
let from_bb = basic_coverage_blocks[from_bcb].last_bb();
Expand Down
26 changes: 13 additions & 13 deletions tests/coverage/abort.cov-map
Original file line number Diff line number Diff line change
@@ -1,34 +1,34 @@
Function name: abort::main
Raw bytes (89): 0x[01, 01, 0a, 01, 27, 05, 09, 03, 0d, 22, 11, 03, 0d, 03, 0d, 22, 15, 03, 0d, 03, 0d, 05, 09, 0d, 01, 0d, 01, 01, 1b, 03, 02, 0b, 00, 18, 22, 01, 0c, 00, 19, 11, 00, 1a, 02, 0a, 0e, 02, 09, 00, 0a, 22, 02, 0c, 00, 19, 15, 00, 1a, 00, 31, 1a, 00, 30, 00, 31, 22, 04, 0c, 00, 19, 05, 00, 1a, 00, 31, 09, 00, 30, 00, 31, 27, 01, 09, 00, 17, 0d, 02, 05, 01, 02]
Raw bytes (89): 0x[01, 01, 0a, 07, 09, 01, 05, 03, 0d, 03, 13, 0d, 11, 03, 0d, 03, 1f, 0d, 15, 03, 0d, 05, 09, 0d, 01, 0d, 01, 01, 1b, 03, 02, 0b, 00, 18, 22, 01, 0c, 00, 19, 11, 00, 1a, 02, 0a, 0e, 02, 09, 00, 0a, 22, 02, 0c, 00, 19, 15, 00, 1a, 00, 31, 1a, 00, 30, 00, 31, 22, 04, 0c, 00, 19, 05, 00, 1a, 00, 31, 09, 00, 30, 00, 31, 27, 01, 09, 00, 17, 0d, 02, 05, 01, 02]
Number of files: 1
- file 0 => global file 1
Number of expressions: 10
- expression 0 operands: lhs = Counter(0), rhs = Expression(9, Add)
- expression 1 operands: lhs = Counter(1), rhs = Counter(2)
- expression 0 operands: lhs = Expression(1, Add), rhs = Counter(2)
- expression 1 operands: lhs = Counter(0), rhs = Counter(1)
- expression 2 operands: lhs = Expression(0, Add), rhs = Counter(3)
- expression 3 operands: lhs = Expression(8, Sub), rhs = Counter(4)
- expression 4 operands: lhs = Expression(0, Add), rhs = Counter(3)
- expression 3 operands: lhs = Expression(0, Add), rhs = Expression(4, Add)
- expression 4 operands: lhs = Counter(3), rhs = Counter(4)
- expression 5 operands: lhs = Expression(0, Add), rhs = Counter(3)
- expression 6 operands: lhs = Expression(8, Sub), rhs = Counter(5)
- expression 7 operands: lhs = Expression(0, Add), rhs = Counter(3)
- expression 6 operands: lhs = Expression(0, Add), rhs = Expression(7, Add)
- expression 7 operands: lhs = Counter(3), rhs = Counter(5)
- expression 8 operands: lhs = Expression(0, Add), rhs = Counter(3)
- expression 9 operands: lhs = Counter(1), rhs = Counter(2)
Number of file 0 mappings: 13
- Code(Counter(0)) at (prev + 13, 1) to (start + 1, 27)
- Code(Expression(0, Add)) at (prev + 2, 11) to (start + 0, 24)
= (c0 + (c1 + c2))
= ((c0 + c1) + c2)
- Code(Expression(8, Sub)) at (prev + 1, 12) to (start + 0, 25)
= ((c0 + (c1 + c2)) - c3)
= (((c0 + c1) + c2) - c3)
- Code(Counter(4)) at (prev + 0, 26) to (start + 2, 10)
- Code(Expression(3, Sub)) at (prev + 2, 9) to (start + 0, 10)
= (((c0 + (c1 + c2)) - c3) - c4)
= (((c0 + c1) + c2) - (c3 + c4))
- Code(Expression(8, Sub)) at (prev + 2, 12) to (start + 0, 25)
= ((c0 + (c1 + c2)) - c3)
= (((c0 + c1) + c2) - c3)
- Code(Counter(5)) at (prev + 0, 26) to (start + 0, 49)
- Code(Expression(6, Sub)) at (prev + 0, 48) to (start + 0, 49)
= (((c0 + (c1 + c2)) - c3) - c5)
= (((c0 + c1) + c2) - (c3 + c5))
- Code(Expression(8, Sub)) at (prev + 4, 12) to (start + 0, 25)
= ((c0 + (c1 + c2)) - c3)
= (((c0 + c1) + c2) - c3)
- Code(Counter(1)) at (prev + 0, 26) to (start + 0, 49)
- Code(Counter(2)) at (prev + 0, 48) to (start + 0, 49)
- Code(Expression(9, Add)) at (prev + 1, 9) to (start + 0, 23)
Expand Down
Loading