@@ -5,15 +5,15 @@ use rustc_data_structures::stable_hasher::{HashStable, StableHasher};
55use rustc_hir:: lang_items;
66use rustc_middle:: hir;
77use rustc_middle:: ich:: StableHashingContext ;
8- use rustc_middle:: mir:: interpret:: Scalar ;
8+ use rustc_middle:: mir:: interpret:: { ConstValue , Scalar } ;
99use rustc_middle:: mir:: {
1010 self , traversal, BasicBlock , BasicBlockData , CoverageData , Operand , Place , SourceInfo ,
1111 StatementKind , Terminator , TerminatorKind , START_BLOCK ,
1212} ;
1313use rustc_middle:: ty;
1414use rustc_middle:: ty:: query:: Providers ;
15- use rustc_middle:: ty:: FnDef ;
1615use rustc_middle:: ty:: TyCtxt ;
16+ use rustc_middle:: ty:: { ConstKind , FnDef } ;
1717use rustc_span:: def_id:: DefId ;
1818use rustc_span:: Span ;
1919
@@ -26,16 +26,36 @@ pub struct InstrumentCoverage;
2626pub ( crate ) fn provide ( providers : & mut Providers < ' _ > ) {
2727 providers. coverage_data = |tcx, def_id| {
2828 let mir_body = tcx. optimized_mir ( def_id) ;
29+ // FIXME(richkadel): The current implementation assumes the MIR for the given DefId
30+ // represents a single function. Validate and/or correct if inlining and/or monomorphization
31+ // invalidates these assumptions.
2932 let count_code_region_fn =
3033 tcx. require_lang_item ( lang_items:: CountCodeRegionFnLangItem , None ) ;
3134 let mut num_counters: u32 = 0 ;
35+ // The `num_counters` argument to `llvm.instrprof.increment` is the number of injected
36+ // counters, with each counter having an index from `0..num_counters-1`. MIR optimization
37+ // may split and duplicate some BasicBlock sequences. Simply counting the calls may not
38+ // not work; but computing the num_counters by adding `1` to the highest index (for a given
39+ // instrumented function) is valid.
3240 for ( _, data) in traversal:: preorder ( mir_body) {
3341 if let Some ( terminator) = & data. terminator {
34- if let TerminatorKind :: Call { func : Operand :: Constant ( func) , .. } = & terminator. kind
42+ if let TerminatorKind :: Call { func : Operand :: Constant ( func) , args, .. } =
43+ & terminator. kind
3544 {
3645 if let FnDef ( called_fn_def_id, _) = func. literal . ty . kind {
3746 if called_fn_def_id == count_code_region_fn {
38- num_counters += 1 ;
47+ if let Operand :: Constant ( constant) =
48+ args. get ( 0 ) . expect ( "count_code_region has at least one arg" )
49+ {
50+ if let ConstKind :: Value ( ConstValue :: Scalar ( value) ) =
51+ constant. literal . val
52+ {
53+ let index = value
54+ . to_u32 ( )
55+ . expect ( "count_code_region index at arg0 is u32" ) ;
56+ num_counters = std:: cmp:: max ( num_counters, index + 1 ) ;
57+ }
58+ }
3959 }
4060 }
4161 }
0 commit comments