Skip to content

Commit cb4627d

Browse files
authored
Add setBranchWeigths convenience function. NFC (llvm#72446)
Add `setBranchWeights` convenience function to ProfDataUtils.h and use it where appropriate.
1 parent 186db1b commit cb4627d

11 files changed

+48
-58
lines changed

llvm/include/llvm/IR/ProfDataUtils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,5 +104,9 @@ bool extractProfTotalWeight(const MDNode *ProfileData, uint64_t &TotalWeights);
104104
/// metadata was found.
105105
bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalWeights);
106106

107+
/// Create a new `branch_weights` metadata node and add or overwrite
108+
/// a `prof` metadata reference to instruction `I`.
109+
void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights);
110+
107111
} // namespace llvm
108112
#endif

llvm/lib/IR/ProfDataUtils.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "llvm/IR/Function.h"
1818
#include "llvm/IR/Instructions.h"
1919
#include "llvm/IR/LLVMContext.h"
20+
#include "llvm/IR/MDBuilder.h"
2021
#include "llvm/IR/Metadata.h"
2122
#include "llvm/Support/BranchProbability.h"
2223
#include "llvm/Support/CommandLine.h"
@@ -183,4 +184,10 @@ bool extractProfTotalWeight(const Instruction &I, uint64_t &TotalVal) {
183184
return extractProfTotalWeight(I.getMetadata(LLVMContext::MD_prof), TotalVal);
184185
}
185186

187+
void setBranchWeights(Instruction &I, ArrayRef<uint32_t> Weights) {
188+
MDBuilder MDB(I.getContext());
189+
MDNode *BranchWeights = MDB.createBranchWeights(Weights);
190+
I.setMetadata(LLVMContext::MD_prof, BranchWeights);
191+
}
192+
186193
} // namespace llvm

llvm/lib/Transforms/IPO/SampleProfile.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
#include "llvm/IR/MDBuilder.h"
5757
#include "llvm/IR/Module.h"
5858
#include "llvm/IR/PassManager.h"
59+
#include "llvm/IR/ProfDataUtils.h"
5960
#include "llvm/IR/PseudoProbe.h"
6061
#include "llvm/IR/ValueSymbolTable.h"
6162
#include "llvm/ProfileData/InstrProf.h"
@@ -1710,20 +1711,19 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
17101711
else if (OverwriteExistingWeights)
17111712
I.setMetadata(LLVMContext::MD_prof, nullptr);
17121713
} else if (!isa<IntrinsicInst>(&I)) {
1713-
I.setMetadata(LLVMContext::MD_prof,
1714-
MDB.createBranchWeights(
1715-
{static_cast<uint32_t>(BlockWeights[BB])}));
1714+
setBranchWeights(I, {static_cast<uint32_t>(BlockWeights[BB])});
17161715
}
17171716
}
17181717
} else if (OverwriteExistingWeights || ProfileSampleBlockAccurate) {
17191718
// Set profile metadata (possibly annotated by LTO prelink) to zero or
17201719
// clear it for cold code.
17211720
for (auto &I : *BB) {
17221721
if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
1723-
if (cast<CallBase>(I).isIndirectCall())
1722+
if (cast<CallBase>(I).isIndirectCall()) {
17241723
I.setMetadata(LLVMContext::MD_prof, nullptr);
1725-
else
1726-
I.setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(0));
1724+
} else {
1725+
setBranchWeights(I, {uint32_t(0)});
1726+
}
17271727
}
17281728
}
17291729
}
@@ -1803,7 +1803,7 @@ void SampleProfileLoader::generateMDProfMetadata(Function &F) {
18031803
if (MaxWeight > 0 &&
18041804
(!TI->extractProfTotalWeight(TempWeight) || OverwriteExistingWeights)) {
18051805
LLVM_DEBUG(dbgs() << "SUCCESS. Found non-zero weights.\n");
1806-
TI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
1806+
setBranchWeights(*TI, Weights);
18071807
ORE->emit([&]() {
18081808
return OptimizationRemark(DEBUG_TYPE, "PopularDest", MaxDestInst)
18091809
<< "most popular destination for conditional branches at "

llvm/lib/Transforms/Instrumentation/ControlHeightReduction.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1878,8 +1878,7 @@ void CHR::fixupBranchesAndSelects(CHRScope *Scope,
18781878
static_cast<uint32_t>(CHRBranchBias.scale(1000)),
18791879
static_cast<uint32_t>(CHRBranchBias.getCompl().scale(1000)),
18801880
};
1881-
MDBuilder MDB(F.getContext());
1882-
MergedBR->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
1881+
setBranchWeights(*MergedBR, Weights);
18831882
CHR_DEBUG(dbgs() << "CHR branch bias " << Weights[0] << ":" << Weights[1]
18841883
<< "\n");
18851884
}

llvm/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "llvm/IR/LLVMContext.h"
2727
#include "llvm/IR/MDBuilder.h"
2828
#include "llvm/IR/PassManager.h"
29+
#include "llvm/IR/ProfDataUtils.h"
2930
#include "llvm/IR/Value.h"
3031
#include "llvm/ProfileData/InstrProf.h"
3132
#include "llvm/Support/Casting.h"
@@ -256,10 +257,7 @@ CallBase &llvm::pgo::promoteIndirectCall(CallBase &CB, Function *DirectCallee,
256257
promoteCallWithIfThenElse(CB, DirectCallee, BranchWeights);
257258

258259
if (AttachProfToDirectCall) {
259-
MDBuilder MDB(NewInst.getContext());
260-
NewInst.setMetadata(
261-
LLVMContext::MD_prof,
262-
MDB.createBranchWeights({static_cast<uint32_t>(Count)}));
260+
setBranchWeights(NewInst, {static_cast<uint32_t>(Count)});
263261
}
264262

265263
using namespace ore;

llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1437,12 +1437,11 @@ void PGOUseFunc::populateCoverage(IndexedInstrProfReader *PGOReader) {
14371437
// If A is uncovered, set weight=1.
14381438
// This setup will allow BFI to give nonzero profile counts to only covered
14391439
// blocks.
1440-
SmallVector<unsigned, 4> Weights;
1440+
SmallVector<uint32_t, 4> Weights;
14411441
for (auto *Succ : successors(&BB))
14421442
Weights.push_back((Coverage[Succ] || !Coverage[&BB]) ? 1 : 0);
14431443
if (Weights.size() >= 2)
1444-
BB.getTerminator()->setMetadata(LLVMContext::MD_prof,
1445-
MDB.createBranchWeights(Weights));
1444+
llvm::setBranchWeights(*BB.getTerminator(), Weights);
14461445
}
14471446

14481447
unsigned NumCorruptCoverage = 0;
@@ -2205,7 +2204,6 @@ static std::string getSimpleNodeName(const BasicBlock *Node) {
22052204

22062205
void llvm::setProfMetadata(Module *M, Instruction *TI,
22072206
ArrayRef<uint64_t> EdgeCounts, uint64_t MaxCount) {
2208-
MDBuilder MDB(M->getContext());
22092207
assert(MaxCount > 0 && "Bad max count");
22102208
uint64_t Scale = calculateCountScale(MaxCount);
22112209
SmallVector<unsigned, 4> Weights;
@@ -2219,7 +2217,7 @@ void llvm::setProfMetadata(Module *M, Instruction *TI,
22192217

22202218
misexpect::checkExpectAnnotations(*TI, Weights, /*IsFrontend=*/false);
22212219

2222-
TI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights));
2220+
setBranchWeights(*TI, Weights);
22232221
if (EmitBranchProbability) {
22242222
std::string BrCondStr = getBranchCondString(TI);
22252223
if (BrCondStr.empty())

llvm/lib/Transforms/Scalar/JumpThreading.cpp

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -228,17 +228,15 @@ static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) {
228228
if (BP >= BranchProbability(50, 100))
229229
continue;
230230

231-
SmallVector<uint32_t, 2> Weights;
231+
uint32_t Weights[2];
232232
if (PredBr->getSuccessor(0) == PredOutEdge.second) {
233-
Weights.push_back(BP.getNumerator());
234-
Weights.push_back(BP.getCompl().getNumerator());
233+
Weights[0] = BP.getNumerator();
234+
Weights[1] = BP.getCompl().getNumerator();
235235
} else {
236-
Weights.push_back(BP.getCompl().getNumerator());
237-
Weights.push_back(BP.getNumerator());
236+
Weights[0] = BP.getCompl().getNumerator();
237+
Weights[1] = BP.getNumerator();
238238
}
239-
PredBr->setMetadata(LLVMContext::MD_prof,
240-
MDBuilder(PredBr->getParent()->getContext())
241-
.createBranchWeights(Weights));
239+
setBranchWeights(*PredBr, Weights);
242240
}
243241
}
244242

@@ -2574,9 +2572,7 @@ void JumpThreadingPass::updateBlockFreqAndEdgeWeight(BasicBlock *PredBB,
25742572
Weights.push_back(Prob.getNumerator());
25752573

25762574
auto TI = BB->getTerminator();
2577-
TI->setMetadata(
2578-
LLVMContext::MD_prof,
2579-
MDBuilder(TI->getParent()->getContext()).createBranchWeights(Weights));
2575+
setBranchWeights(*TI, Weights);
25802576
}
25812577
}
25822578

llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "llvm/IR/Intrinsics.h"
2121
#include "llvm/IR/LLVMContext.h"
2222
#include "llvm/IR/MDBuilder.h"
23+
#include "llvm/IR/ProfDataUtils.h"
2324
#include "llvm/Support/CommandLine.h"
2425
#include "llvm/Transforms/Utils/MisExpect.h"
2526

@@ -101,10 +102,7 @@ static bool handleSwitchExpect(SwitchInst &SI) {
101102
misexpect::checkExpectAnnotations(SI, Weights, /*IsFrontend=*/true);
102103

103104
SI.setCondition(ArgValue);
104-
105-
SI.setMetadata(LLVMContext::MD_prof,
106-
MDBuilder(CI->getContext()).createBranchWeights(Weights));
107-
105+
setBranchWeights(SI, Weights);
108106
return true;
109107
}
110108

llvm/lib/Transforms/Utils/Local.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -227,9 +227,7 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions,
227227
// Remove weight for this case.
228228
std::swap(Weights[Idx + 1], Weights.back());
229229
Weights.pop_back();
230-
SI->setMetadata(LLVMContext::MD_prof,
231-
MDBuilder(BB->getContext()).
232-
createBranchWeights(Weights));
230+
setBranchWeights(*SI, Weights);
233231
}
234232
// Remove this entry.
235233
BasicBlock *ParentBB = SI->getParent();

llvm/lib/Transforms/Utils/LoopPeel.cpp

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -631,9 +631,7 @@ struct WeightInfo {
631631
/// To avoid dealing with division rounding we can just multiple both part
632632
/// of weights to E and use weight as (F - I * E, E).
633633
static void updateBranchWeights(Instruction *Term, WeightInfo &Info) {
634-
MDBuilder MDB(Term->getContext());
635-
Term->setMetadata(LLVMContext::MD_prof,
636-
MDB.createBranchWeights(Info.Weights));
634+
setBranchWeights(*Term, Info.Weights);
637635
for (auto [Idx, SubWeight] : enumerate(Info.SubWeights))
638636
if (SubWeight != 0)
639637
// Don't set the probability of taking the edge from latch to loop header
@@ -690,14 +688,6 @@ static void initBranchWeights(DenseMap<Instruction *, WeightInfo> &WeightInfos,
690688
}
691689
}
692690

693-
/// Update the weights of original exiting block after peeling off all
694-
/// iterations.
695-
static void fixupBranchWeights(Instruction *Term, const WeightInfo &Info) {
696-
MDBuilder MDB(Term->getContext());
697-
Term->setMetadata(LLVMContext::MD_prof,
698-
MDB.createBranchWeights(Info.Weights));
699-
}
700-
701691
/// Clones the body of the loop L, putting it between \p InsertTop and \p
702692
/// InsertBot.
703693
/// \param IterNumber The serial number of the iteration currently being
@@ -1033,8 +1023,9 @@ bool llvm::peelLoop(Loop *L, unsigned PeelCount, LoopInfo *LI,
10331023
PHI->setIncomingValueForBlock(NewPreHeader, NewVal);
10341024
}
10351025

1036-
for (const auto &[Term, Info] : Weights)
1037-
fixupBranchWeights(Term, Info);
1026+
for (const auto &[Term, Info] : Weights) {
1027+
setBranchWeights(*Term, Info.Weights);
1028+
}
10381029

10391030
// Update Metadata for count of peeled off iterations.
10401031
unsigned AlreadyPeeled = 0;

llvm/lib/Transforms/Utils/LoopRotationUtils.cpp

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -352,16 +352,17 @@ static void updateBranchWeights(BranchInst &PreHeaderBI, BranchInst &LoopBI,
352352
LoopBackWeight = 0;
353353
}
354354

355-
MDBuilder MDB(LoopBI.getContext());
356-
MDNode *LoopWeightMD =
357-
MDB.createBranchWeights(SuccsSwapped ? LoopBackWeight : ExitWeight1,
358-
SuccsSwapped ? ExitWeight1 : LoopBackWeight);
359-
LoopBI.setMetadata(LLVMContext::MD_prof, LoopWeightMD);
355+
const uint32_t LoopBIWeights[] = {
356+
SuccsSwapped ? LoopBackWeight : ExitWeight1,
357+
SuccsSwapped ? ExitWeight1 : LoopBackWeight,
358+
};
359+
setBranchWeights(LoopBI, LoopBIWeights);
360360
if (HasConditionalPreHeader) {
361-
MDNode *PreHeaderWeightMD =
362-
MDB.createBranchWeights(SuccsSwapped ? EnterWeight : ExitWeight0,
363-
SuccsSwapped ? ExitWeight0 : EnterWeight);
364-
PreHeaderBI.setMetadata(LLVMContext::MD_prof, PreHeaderWeightMD);
361+
const uint32_t PreHeaderBIWeights[] = {
362+
SuccsSwapped ? EnterWeight : ExitWeight0,
363+
SuccsSwapped ? ExitWeight0 : EnterWeight,
364+
};
365+
setBranchWeights(PreHeaderBI, PreHeaderBIWeights);
365366
}
366367
}
367368

0 commit comments

Comments
 (0)