Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PGO] Sampled instrumentation in PGO to speed up instrumentation binary #69535

Merged
merged 6 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm/include/llvm/ProfileData/InstrProfData.inc
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,7 @@ serializeValueProfDataFrom(ValueProfRecordClosure *Closure,
#define INSTR_PROF_PROFILE_RUNTIME_VAR __llvm_profile_runtime
#define INSTR_PROF_PROFILE_COUNTER_BIAS_VAR __llvm_profile_counter_bias
#define INSTR_PROF_PROFILE_SET_TIMESTAMP __llvm_profile_set_timestamp
#define INSTR_PROF_PROFILE_SAMPLING_VAR __llvm_profile_sampling

/* The variable that holds the name of the profile data
* specified via command line. */
Expand Down
6 changes: 6 additions & 0 deletions llvm/include/llvm/Transforms/Instrumentation.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,18 @@ struct InstrProfOptions {
// Use BFI to guide register promotion
bool UseBFIInPromotion = false;

// Use sampling to reduce the profile instrumentation runtime overhead.
bool Sampling = false;

// Name of the profile file to use as output
std::string InstrProfileOutput;

InstrProfOptions() = default;
};

// Create the variable for profile sampling.
void createProfileSamplingVar(Module &M);

// Options for sanitizer coverage instrumentation.
struct SanitizerCoverageOptions {
enum Type {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,14 @@ class FileSystem;
/// Module pass that creates the profile variables for context-sensitive IR
/// PGO instrumentation. Its run() method is defined out of line; when
/// \c ProfileSampling is set it additionally creates the profile sampling
/// variable.
class PGOInstrumentationGenCreateVar
    : public PassInfoMixin<PGOInstrumentationGenCreateVar> {
public:
  /// \param CSInstrName name of the context-sensitive profile output file.
  /// \param Sampling    whether the profile sampling variable must be created.
  ///
  /// This is the only constructor: keeping the former one-argument overload
  /// next to it would make zero- and one-argument calls ambiguous and would
  /// leave \c ProfileSampling uninitialized.
  PGOInstrumentationGenCreateVar(std::string CSInstrName = "",
                                 bool Sampling = false)
      : CSInstrName(CSInstrName), ProfileSampling(Sampling) {}
  PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);

private:
  std::string CSInstrName;
  // Defaulted so the flag can never be read uninitialized.
  bool ProfileSampling = false;
};

/// The instrumentation (profile-instr-gen) pass for IR based PGO.
Expand Down
12 changes: 11 additions & 1 deletion llvm/lib/Passes/PassBuilderPipelines.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,9 @@ static cl::opt<AttributorRunOption> AttributorRun(
clEnumValN(AttributorRunOption::NONE, "none",
"disable attributor runs")));

// Pipeline-level switch for sampled PGO instrumentation. When set,
// PassBuilder::addPGOInstrPasses turns on InstrProfOptions::Sampling and
// disables counter promotion, and the context-sensitive pre-link pipeline
// forwards the value to PGOInstrumentationGenCreateVar.
static cl::opt<bool> EnableSampledInstr(
"enable-sampled-instr", cl::init(false), cl::Hidden,
cl::desc("Enable profile instrumentation sampling (default = off)"));
static cl::opt<bool> UseLoopVersioningLICM(
"enable-loop-versioning-licm", cl::init(false), cl::Hidden,
cl::desc("Enable the experimental Loop Versioning LICM pass"));
Expand Down Expand Up @@ -847,6 +850,12 @@ void PassBuilder::addPGOInstrPasses(ModulePassManager &MPM,
// Do counter promotion at Level greater than O0.
Options.DoCounterPromotion = true;
Options.UseBFIInPromotion = IsCS;
if (EnableSampledInstr) {
Options.Sampling = true;
// With sampling, there is little benefit in enabling counter promotion.
// But note that sampling does work with counter promotion.
Options.DoCounterPromotion = false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a comment on why counter promotion is turned off?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason is mentioned in InstrProfiling.cpp:400.

}
Options.Atomic = AtomicCounterUpdate;
MPM.addPass(InstrProfilingLoweringPass(Options, IsCS));
}
Expand Down Expand Up @@ -1185,7 +1194,8 @@ PassBuilder::buildModuleSimplificationPipeline(OptimizationLevel Level,
MPM.addPass(PGOIndirectCallPromotion(false, false));

if (IsPGOPreLink && PGOOpt->CSAction == PGOOptions::CSIRInstr)
MPM.addPass(PGOInstrumentationGenCreateVar(PGOOpt->CSProfileGenFile));
MPM.addPass(PGOInstrumentationGenCreateVar(PGOOpt->CSProfileGenFile,
EnableSampledInstr));

if (IsMemprofUse)
MPM.addPass(MemProfUsePass(PGOOpt->MemoryProfile, PGOOpt->FS));
Expand Down
236 changes: 214 additions & 22 deletions llvm/lib/Transforms/Instrumentation/InstrProfiling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,29 @@ cl::opt<bool> SkipRetExitBlock(
"skip-ret-exit-block", cl::init(true),
cl::desc("Suppress counter promotion if exit blocks contain ret."));

static cl::opt<bool> SampledInstr("sampled-instr", cl::ZeroOrMore,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

'instr' can be confused with 'instruction'. We should just spell it out as 'sampled-instrumentation'.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack.

cl::init(false),
cl::desc("Do PGO instrumentation sampling"));

static cl::opt<unsigned> SampledInstrPeriod(
"sampled-instr-period",
cl::desc("Set the profile instrumentation sample period. For each sample "
"period, the 'sampled-instr-burst-duration' number of consecutive "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A fixed number of consecutive samples will be recorded. The number is controlled by the 'sampled-instr-burst-duration' flag.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack.

"samples will be recorded. The default sample period of 65535 is "
"optimized for generating efficient code that leverages unsigned "
"integer wrapping in overflow."),
cl::init(65535));

// Number of consecutive counter updates recorded at the start of each sample
// period. InstrLowerer::doSampling reads this value when it builds the
// sampling predicate; a value of 1 selects the simple-sampling code shape.
static cl::opt<unsigned> SampledInstrBurstDuration(
"sampled-instr-burst-duration",
cl::desc("Set the profile instrumentation burst duration, which can range "
"from 0 to one less than the value of 'sampled-instr-period'. "
"This number of samples will be recorded for each "
"'sampled-instr-period' count update. Setting to 1 enables "
"simple sampling, in which case it is recommended to set "
"'sampled-instr-period' to a prime number."),
cl::init(200));

using LoadStorePair = std::pair<Instruction *, Instruction *>;

static uint64_t getIntModuleFlagOrZero(const Module &M, StringRef Flag) {
Expand Down Expand Up @@ -260,6 +283,9 @@ class InstrLowerer final {
/// Returns true if profile counter update register promotion is enabled.
bool isCounterPromotionEnabled() const;

/// Return true if profile sampling is enabled.
bool isSamplingEnabled() const;

/// Count the number of instrumented value sites for the function.
void computeNumValueSiteCounts(InstrProfValueProfileInst *Ins);

Expand Down Expand Up @@ -291,6 +317,9 @@ class InstrLowerer final {
/// acts on.
Value *getCounterAddress(InstrProfCntrInstBase *I);

/// Lower the incremental instructions under profile sampling predicates.
void doSampling(Instruction *I);

/// Get the region counters for an increment, creating them if necessary.
///
/// If the counter array doesn't yet exist, the profile data variables
Expand Down Expand Up @@ -635,33 +664,161 @@ PreservedAnalyses InstrProfilingLoweringPass::run(Module &M,
return PreservedAnalyses::none();
}

//
// Perform instrumentation sampling.
//
// There are 3 flavors of sampling:
// (1) Full burst sampling: We transform:
// Increment_Instruction;
// to:
// if (__llvm_profile_sampling__ < SampledInstrBurstDuration) {
// Increment_Instruction;
// }
// __llvm_profile_sampling__ += 1;
// if (__llvm_profile_sampling__ >= SampledInstrPeriod) {
// __llvm_profile_sampling__ = 0;
// }
//
// "__llvm_profile_sampling__" is a thread-local global shared by all PGO
// counters (value-instrumentation and edge instrumentation).
//
// (2) Fast burst sampling:
// The __llvm_profile_sampling__ variable is an unsigned type, meaning it will wrap around to zero when
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

llvm_profile_sampling variable is ..

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack.

// it overflows. In this case, a second check (check2) is unnecessary, so we
// won't generate check2 when the SampledInstrPeriod is set to 65535 (64K - 1).
// The code after the transformation:
// if (__llvm_profile_sampling__ < SampledInstrBurstDuration) {
// Increment_Instruction;
// }
// __llvm_profile_sampling__ += 1;
//
// (3) Simple sampling:
// When SampledInstrBurstDuration is set to 1, we do a simple sampling:
// __llvm_profile_sampling__ += 1;
// if (__llvm_profile_sampling__ >= SampledInstrPeriod) {
// __llvm_profile_sampling__ = 0;
// Increment_Instruction;
// }
//
// Note that, the code snippet after the transformation can still be counter
// promoted. However, with sampling enabled, counter updates are expected to
// be infrequent, making the benefits of counter promotion negligible.
// Moreover, counter promotion can potentially cause issues in server
// applications, particularly when the counters are dumped without a clean
// exit. To mitigate this risk, counter promotion is disabled by default when
// sampling is enabled. This behavior can be overridden using the internal
// option.
/// Guards the profiling intrinsic \p I with the sampling predicate(s) so that
/// it only executes for the sampled portion of each period.
///
/// Does nothing when sampling is disabled. Otherwise, at the position of
/// \p I, this emits a load of the sampling variable, an increment/store of
/// it and the conditional branches of the selected mode, and moves \p I into
/// the guarded block. \p I itself is left un-lowered.
///
/// Mode selection (from 'sampled-instr-period' and
/// 'sampled-instr-burst-duration'):
///  - simple sampling: burst duration == 1. \p I runs in the block that
///    resets the sampling variable once it reaches the period.
///  - fast burst sampling: period == 65535 and not simple. Only the burst
///    check is emitted; no period check/reset is generated.
///  - full burst sampling: everything else. Burst check plus period reset.
///
/// When burst duration == 1 and period == 65535 both hold, simple sampling
/// takes precedence. This is purely an implementation choice: either shape
/// needs exactly one condition guarding \p I.
void InstrLowerer::doSampling(Instruction *I) {
  if (!isSamplingEnabled())
    return;

  const unsigned SampledBurstDuration = SampledInstrBurstDuration.getValue();
  const unsigned SampledPeriod = SampledInstrPeriod.getValue();
  // Both values come straight from command-line flags, so a bad combination
  // is a user error. Diagnose it instead of asserting: an assert is compiled
  // out of release builds and would let invalid code be generated silently.
  if (SampledBurstDuration >= SampledPeriod) {
    M.getContext().emitError(
        "'sampled-instr-burst-duration' must be less than "
        "'sampled-instr-period'");
    return;
  }

  const bool UseShort = (SampledPeriod <= USHRT_MAX);
  const bool IsSimpleSampling = (SampledBurstDuration == 1);
  const bool IsFastSampling = (!IsSimpleSampling && SampledPeriod == 65535);

  // Materializes C with the same width as the sampling variable.
  auto GetConstant = [UseShort](IRBuilder<> &Builder, uint32_t C) {
    return UseShort ? Builder.getInt16(C) : Builder.getInt32(C);
  };

  IntegerType *SamplingVarTy = UseShort ? Type::getInt16Ty(M.getContext())
                                        : Type::getInt32Ty(M.getContext());
  auto *SamplingVar =
      M.getGlobalVariable(INSTR_PROF_QUOTE(INSTR_PROF_PROFILE_SAMPLING_VAR));
  assert(SamplingVar && "SamplingVar not set properly");

  MDBuilder MDB(I->getContext());
  MDNode *BranchWeight = nullptr;
  IRBuilder<> CondBuilder(I);
  auto *LoadSamplingVar = CondBuilder.CreateLoad(SamplingVarTy, SamplingVar);

  // For burst sampling, create the conditional update: I only runs while the
  // sampling variable is strictly below the burst duration, i.e. for exactly
  // SampledBurstDuration values per period (and never when it is 0).
  Instruction *BurstThenTerm = nullptr;
  if (!IsSimpleSampling) {
    auto *DurationCond = CondBuilder.CreateICmpULT(
        LoadSamplingVar, GetConstant(CondBuilder, SampledBurstDuration));
    BranchWeight = MDB.createBranchWeights(
        SampledBurstDuration, SampledPeriod + 1 - SampledBurstDuration);
    BurstThenTerm = SplitBlockAndInsertIfThen(
        DurationCond, I, /* Unreachable */ false, BranchWeight);
  }

  // Increment the sampling variable. This is emitted right before I, which
  // at this point is still outside the guarded block, so the increment stays
  // unconditional once I is moved away below.
  IRBuilder<> IncBuilder(I);
  Value *NewSamplingVarVal =
      IncBuilder.CreateAdd(LoadSamplingVar, GetConstant(IncBuilder, 1));
  Instruction *SamplingVarIncr =
      IncBuilder.CreateStore(NewSamplingVarVal, SamplingVar);
  if (BurstThenTerm)
    I->moveBefore(BurstThenTerm);

  // Fast sampling needs no period check.
  if (IsFastSampling)
    return;

  // Create the condition for checking the period.
  Instruction *ThenTerm = nullptr;
  Instruction *ElseTerm = nullptr;
  IRBuilder<> PeriodCondBuilder(SamplingVarIncr);
  auto *PeriodCond = PeriodCondBuilder.CreateICmpUGE(
      NewSamplingVarVal, GetConstant(PeriodCondBuilder, SampledPeriod));
  BranchWeight = MDB.createBranchWeights(1, SampledPeriod);
  SplitBlockAndInsertIfThenElse(PeriodCond, SamplingVarIncr, &ThenTerm,
                                &ElseTerm, BranchWeight);

  // For simple sampling, the counter update happens where the sampling
  // variable is reset.
  if (IsSimpleSampling)
    I->moveBefore(ThenTerm);

  // Then-side: the period is reached, so reset the sampling variable to 0.
  // Else-side: keep the incremented value.
  IRBuilder<> ResetBuilder(ThenTerm);
  ResetBuilder.CreateStore(GetConstant(ResetBuilder, 0), SamplingVar);
  SamplingVarIncr->moveBefore(ElseTerm);
}

bool InstrLowerer::lowerIntrinsics(Function *F) {
bool MadeChange = false;
PromotionCandidates.clear();
SmallVector<InstrProfInstBase *, 8> InstrProfInsts;

for (BasicBlock &BB : *F) {
for (Instruction &Instr : llvm::make_early_inc_range(BB)) {
if (auto *IPIS = dyn_cast<InstrProfIncrementInstStep>(&Instr)) {
lowerIncrement(IPIS);
MadeChange = true;
} else if (auto *IPI = dyn_cast<InstrProfIncrementInst>(&Instr)) {
lowerIncrement(IPI);
MadeChange = true;
} else if (auto *IPC = dyn_cast<InstrProfTimestampInst>(&Instr)) {
lowerTimestamp(IPC);
MadeChange = true;
} else if (auto *IPC = dyn_cast<InstrProfCoverInst>(&Instr)) {
lowerCover(IPC);
MadeChange = true;
} else if (auto *IPVP = dyn_cast<InstrProfValueProfileInst>(&Instr)) {
lowerValueProfileInst(IPVP);
MadeChange = true;
} else if (auto *IPMP = dyn_cast<InstrProfMCDCBitmapParameters>(&Instr)) {
IPMP->eraseFromParent();
MadeChange = true;
} else if (auto *IPBU = dyn_cast<InstrProfMCDCTVBitmapUpdate>(&Instr)) {
lowerMCDCTestVectorBitmapUpdate(IPBU);
MadeChange = true;
}
if (auto *IP = dyn_cast<InstrProfInstBase>(&Instr))
InstrProfInsts.push_back(IP);
xur-llvm marked this conversation as resolved.
Show resolved Hide resolved
}
}

for (auto *Instr : InstrProfInsts) {
doSampling(Instr);
if (auto *IPIS = dyn_cast<InstrProfIncrementInstStep>(Instr)) {
lowerIncrement(IPIS);
MadeChange = true;
} else if (auto *IPI = dyn_cast<InstrProfIncrementInst>(Instr)) {
lowerIncrement(IPI);
MadeChange = true;
} else if (auto *IPC = dyn_cast<InstrProfTimestampInst>(Instr)) {
lowerTimestamp(IPC);
MadeChange = true;
} else if (auto *IPC = dyn_cast<InstrProfCoverInst>(Instr)) {
lowerCover(IPC);
MadeChange = true;
} else if (auto *IPVP = dyn_cast<InstrProfValueProfileInst>(Instr)) {
lowerValueProfileInst(IPVP);
MadeChange = true;
} else if (auto *IPMP = dyn_cast<InstrProfMCDCBitmapParameters>(Instr)) {
IPMP->eraseFromParent();
MadeChange = true;
} else if (auto *IPBU = dyn_cast<InstrProfMCDCTVBitmapUpdate>(Instr)) {
lowerMCDCTestVectorBitmapUpdate(IPBU);
MadeChange = true;
}
}

Expand All @@ -684,6 +841,12 @@ bool InstrLowerer::isRuntimeCounterRelocationEnabled() const {
return TT.isOSFuchsia();
}

/// Tells whether profile sampling is on. An explicit 'sampled-instr' flag on
/// the command line overrides the value carried in the pass options.
bool InstrLowerer::isSamplingEnabled() const {
  const bool GivenOnCommandLine = SampledInstr.getNumOccurrences() > 0;
  return GivenOnCommandLine ? SampledInstr.getValue() : Options.Sampling;
}

bool InstrLowerer::isCounterPromotionEnabled() const {
if (DoCounterPromotion.getNumOccurrences() > 0)
return DoCounterPromotion;
Expand Down Expand Up @@ -754,6 +917,9 @@ bool InstrLowerer::lower() {
if (NeedsRuntimeHook)
MadeChange = emitRuntimeHook();

if (!IsCS && isSamplingEnabled())
createProfileSamplingVar(M);

bool ContainsProfiling = containsProfilingIntrinsics(M);
GlobalVariable *CoverageNamesVar =
M.getNamedGlobal(getCoverageUnusedNamesVarName());
Expand Down Expand Up @@ -1952,3 +2118,29 @@ void InstrLowerer::emitInitialization() {

appendToGlobalCtors(M, F, 0);
}

namespace llvm {
/// Adds the profile sampling variable to \p M.
///
/// The variable is a zero-initialized, thread-local integer with default
/// visibility: 16 bits wide when 'sampled-instr-period' fits in an unsigned
/// short, 32 bits otherwise. It starts out with weak linkage; on targets
/// that support COMDAT it is switched to external linkage and placed in a
/// COMDAT named after the variable. It is finally appended to the module's
/// compiler-used list.
void createProfileSamplingVar(Module &M) {
  const StringRef VarName(INSTR_PROF_QUOTE(INSTR_PROF_PROFILE_SAMPLING_VAR));
  auto &Ctx = M.getContext();

  // Pick the narrowest counter type that can hold the configured period.
  const bool FitsInShort = SampledInstrPeriod.getValue() <= USHRT_MAX;
  IntegerType *const VarTy =
      FitsInShort ? Type::getInt16Ty(Ctx) : Type::getInt32Ty(Ctx);
  Constant *const InitialValue =
      Constant::getIntegerValue(VarTy, APInt(FitsInShort ? 16 : 32, 0));

  auto *Var = new GlobalVariable(M, VarTy, /*isConstant=*/false,
                                 GlobalValue::WeakAnyLinkage, InitialValue,
                                 VarName);
  Var->setVisibility(GlobalValue::DefaultVisibility);
  Var->setThreadLocal(true);

  if (Triple(M.getTargetTriple()).supportsCOMDAT()) {
    Var->setLinkage(GlobalValue::ExternalLinkage);
    Var->setComdat(M.getOrInsertComdat(VarName));
  }
  appendToCompilerUsed(M, Var);
}
} // namespace llvm
2 changes: 2 additions & 0 deletions llvm/lib/Transforms/Instrumentation/PGOInstrumentation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1875,6 +1875,8 @@ PGOInstrumentationGenCreateVar::run(Module &M, ModuleAnalysisManager &MAM) {
// The variable in a comdat may be discarded by LTO. Ensure the declaration
// will be retained.
appendToCompilerUsed(M, createIRLevelProfileFlagVar(M, /*IsCS=*/true));
if (ProfileSampling)
createProfileSamplingVar(M);
PreservedAnalyses PA;
PA.preserve<FunctionAnalysisManagerModuleProxy>();
PA.preserveSet<AllAnalysesOn<Function>>();
Expand Down
Loading
Loading