Skip to content

Commit cc5c5cc

Browse files
committed
[OMPIRBuilder] Support runtime number of teams and threads, and SPMD mode
This patch introduces a `TargetKernelRuntimeAttrs` structure to hold host-evaluated `num_teams`, `thread_limit`, `num_threads` and trip count values passed to the runtime kernel offloading call. Additionally, `createTarget` is extended to take an `IsSPMD` flag, used to influence target device code generation.
1 parent 1fcfe48 commit cc5c5cc

File tree

4 files changed

+420
-34
lines changed

4 files changed

+420
-34
lines changed

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2237,6 +2237,26 @@ class OpenMPIRBuilder {
22372237
int32_t MinThreads = 1;
22382238
};
22392239

2240+
/// Container to pass LLVM IR runtime values or constants related to the
2241+
/// number of teams and threads with which the kernel must be launched, as
2242+
/// well as the trip count of the SPMD loop, if it is an SPMD kernel. These
2243+
/// must be defined in the host prior to the call to the kernel launch OpenMP
2244+
/// RTL function.
2245+
struct TargetKernelRuntimeAttrs {
2246+
/// Runtime values for the maximum number of teams. When the kernel launch
/// is emitted these are iterated together with
/// TargetKernelDefaultAttrs::MaxTeams using zip_equal, so both must have the
/// same number of entries; a null entry selects the corresponding default
/// constant instead.
SmallVector<Value *, 3> MaxTeams = {nullptr};
2247+
/// Presumably the runtime lower bound on the number of teams, null when
/// absent. NOTE(review): not read by any code visible in this diff; confirm
/// against its users.
Value *MinTeams = nullptr;
2248+
/// Runtime thread limit values, presumably from the 'thread_limit' clause of
/// the 'target' construct (TODO confirm). Iterated together with
/// TeamsThreadLimit using zip_equal, so both must have the same number of
/// entries. Null entries are ignored.
SmallVector<Value *, 3> TargetThreadLimit = {nullptr};
2249+
/// Runtime thread limit values, presumably from the 'thread_limit' clause of
/// the 'teams' construct (TODO confirm). For each entry, the launch uses the
/// unsigned minimum of the non-null values among this, TargetThreadLimit and
/// (single-entry case only) MaxThreads, or 0 if all of them are null.
SmallVector<Value *, 3> TeamsThreadLimit = {nullptr};
2250+
2251+
/// 'parallel' construct 'num_threads' clause value, if present and it is a
2252+
/// target SPMD kernel.
/// Only taken into account when TeamsThreadLimit holds exactly one entry.
2253+
Value *MaxThreads = nullptr;
2254+
2255+
/// Total number of iterations of the target SPMD kernel or null if it is a
2256+
/// generic kernel.
/// When non-null it is integer-cast (unsigned) to i64 before being passed to
/// the kernel arguments; when null, a constant 0 is passed instead.
2257+
Value *LoopTripCount = nullptr;
2258+
};
2259+
22402260
/// Data structure that contains the needed information to construct the
22412261
/// kernel args vector.
22422262
struct TargetKernelArgs {
@@ -2905,11 +2925,14 @@ class OpenMPIRBuilder {
29052925
///
29062926
/// \param Loc where the target data construct was encountered.
29072927
/// \param IsOffloadEntry whether it is an offload entry.
2928+
/// \param IsSPMD whether it is a target SPMD kernel.
29082929
/// \param CodeGenIP The insertion point where the call to the outlined
29092930
/// function should be emitted.
29102931
/// \param EntryInfo The entry information about the function.
29112932
/// \param DefaultAttrs Structure containing the default numbers of threads
29122933
/// and teams to launch the kernel with.
2934+
/// \param RuntimeAttrs Structure containing the runtime numbers of threads
2935+
/// and teams to launch the kernel with.
29132936
/// \param Inputs The input values to the region that will be passed
29142937
/// as arguments to the outlined function.
29152938
/// \param BodyGenCB Callback that will generate the region code.
@@ -2919,11 +2942,12 @@ class OpenMPIRBuilder {
29192942
// dependency information as passed in the depend clause
29202943
// \param HasNowait Whether the target construct has a `nowait` clause or not.
29212944
InsertPointOrErrorTy createTarget(
2922-
const LocationDescription &Loc, bool IsOffloadEntry,
2945+
const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
29232946
OpenMPIRBuilder::InsertPointTy AllocaIP,
29242947
OpenMPIRBuilder::InsertPointTy CodeGenIP,
29252948
TargetRegionEntryInfo &EntryInfo,
29262949
const TargetKernelDefaultAttrs &DefaultAttrs,
2950+
const TargetKernelRuntimeAttrs &RuntimeAttrs,
29272951
SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
29282952
TargetBodyGenCallbackTy BodyGenCB,
29292953
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 118 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6727,8 +6727,43 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
67276727
return getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_dispatch_deinit);
67286728
}
67296729

6730+
/// Emit a global array named \p Name into \p M that references every value
/// in \p List.
///
/// Does nothing when \p List is empty. Otherwise each entry is cast to an
/// address-space-0 pointer, the pointers are gathered into a constant array,
/// and a global variable with appending linkage holding that array is created
/// and placed in the "llvm.metadata" section.
///
/// \param Name Name of the global to create (e.g. "llvm.compiler.used").
/// \param List Values to reference; every entry must be a Constant, since it
///             is converted with cast<Constant>.
/// \param M    Module that receives the new global.
///
/// NOTE(review): every call constructs a brand-new GlobalVariable; confirm
/// what happens when a global called \p Name already exists in \p M (e.g.
/// when this is reached once per outlined kernel).
static void emitUsed(StringRef Name, std::vector<llvm::WeakTrackingVH> &List,
6731+
Module &M) {
6732+
// Nothing to record: do not create an empty array global.
if (List.empty())
6733+
return;
6734+
6735+
// All elements of the array are plain pointers in address space 0.
Type *PtrTy = PointerType::get(M.getContext(), /*AddressSpace=*/0);
6736+
6737+
// Convert List to what ConstantArray needs.
6738+
SmallVector<Constant *, 8> UsedArray;
6739+
UsedArray.reserve(List.size());
6740+
// NOTE(review): Item is taken by value, so each WeakTrackingVH handle is
// copied per iteration.
for (auto Item : List)
6741+
UsedArray.push_back(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
6742+
cast<Constant>(&*Item), PtrTy));
6743+
6744+
ArrayType *ArrTy = ArrayType::get(PtrTy, UsedArray.size());
6745+
// Non-constant (third argument is false) global with appending linkage,
// initialized with the collected pointers.
auto *GV =
6746+
new GlobalVariable(M, ArrTy, false, llvm::GlobalValue::AppendingLinkage,
6747+
llvm::ConstantArray::get(ArrTy, UsedArray), Name);
6748+
6749+
GV->setSection("llvm.metadata");
6750+
}
6751+
6752+
/// Create the "<FunctionName>_exec_mode" global that records the execution
/// mode of a target kernel.
///
/// The global is a constant i8 initialized to \p Mode, with weak-any linkage
/// and protected visibility, created in the module of \p OMPBuilder. It is
/// also appended to \p LLVMCompilerUsed so that the caller can reference it
/// from a "used" list.
///
/// \param OMPBuilder       Builder whose module (OMPBuilder.M) receives the
///                         global.
/// \param Builder          Only used to obtain the LLVMContext.
/// \param FunctionName     Kernel name; "_exec_mode" is appended to form the
///                         name of the global.
/// \param Mode             Execution mode flag stored as the initializer.
/// \param LLVMCompilerUsed Output list that receives the new global.
static void
6753+
emitExecutionMode(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
6754+
StringRef FunctionName, OMPTgtExecModeFlags Mode,
6755+
std::vector<llvm::WeakTrackingVH> &LLVMCompilerUsed) {
6756+
auto *Int8Ty = Type::getInt8Ty(Builder.getContext());
6757+
auto *GVMode = new llvm::GlobalVariable(
6758+
OMPBuilder.M, Int8Ty, /*isConstant=*/true,
6759+
llvm::GlobalValue::WeakAnyLinkage, llvm::ConstantInt::get(Int8Ty, Mode),
6760+
Twine(FunctionName, "_exec_mode"));
6761+
GVMode->setVisibility(llvm::GlobalVariable::ProtectedVisibility);
6762+
// Hand the global back to the caller through the output list.
LLVMCompilerUsed.emplace_back(GVMode);
6763+
}
6764+
67306765
static Expected<Function *> createOutlinedFunction(
6731-
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
6766+
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsSPMD,
67326767
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
67336768
StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
67346769
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
@@ -6758,6 +6793,27 @@ static Expected<Function *> createOutlinedFunction(
67586793
auto Func =
67596794
Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName, M);
67606795

6796+
// Forward target-cpu and target-features function attributes from the
6797+
// original function to the new outlined function.
6798+
Function *ParentFn = Builder.GetInsertBlock()->getParent();
6799+
6800+
auto TargetCpuAttr = ParentFn->getFnAttribute("target-cpu");
6801+
if (TargetCpuAttr.isStringAttribute())
6802+
Func->addFnAttr(TargetCpuAttr);
6803+
6804+
auto TargetFeaturesAttr = ParentFn->getFnAttribute("target-features");
6805+
if (TargetFeaturesAttr.isStringAttribute())
6806+
Func->addFnAttr(TargetFeaturesAttr);
6807+
6808+
if (OMPBuilder.Config.isTargetDevice()) {
6809+
std::vector<llvm::WeakTrackingVH> LLVMCompilerUsed;
6810+
emitExecutionMode(OMPBuilder, Builder, FuncName,
6811+
IsSPMD ? OMP_TGT_EXEC_MODE_SPMD
6812+
: OMP_TGT_EXEC_MODE_GENERIC,
6813+
LLVMCompilerUsed);
6814+
emitUsed("llvm.compiler.used", LLVMCompilerUsed, OMPBuilder.M);
6815+
}
6816+
67616817
// Save insert point.
67626818
IRBuilder<>::InsertPointGuard IPG(Builder);
67636819
// If there's a DISubprogram associated with current function, then
@@ -6798,7 +6854,7 @@ static Expected<Function *> createOutlinedFunction(
67986854
// Insert target init call in the device compilation pass.
67996855
if (OMPBuilder.Config.isTargetDevice())
68006856
Builder.restoreIP(
6801-
OMPBuilder.createTargetInit(Builder, /*IsSPMD=*/false, DefaultAttrs));
6857+
OMPBuilder.createTargetInit(Builder, IsSPMD, DefaultAttrs));
68026858

68036859
BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
68046860

@@ -6995,7 +7051,7 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
69957051

69967052
static Error emitTargetOutlinedFunction(
69977053
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
6998-
TargetRegionEntryInfo &EntryInfo,
7054+
bool IsSPMD, TargetRegionEntryInfo &EntryInfo,
69997055
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
70007056
Function *&OutlinedFn, Constant *&OutlinedFnID,
70017057
SmallVectorImpl<Value *> &Inputs,
@@ -7004,7 +7060,7 @@ static Error emitTargetOutlinedFunction(
70047060

70057061
OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
70067062
[&](StringRef EntryFnName) {
7007-
return createOutlinedFunction(OMPBuilder, Builder, DefaultAttrs,
7063+
return createOutlinedFunction(OMPBuilder, Builder, IsSPMD, DefaultAttrs,
70087064
EntryFnName, Inputs, CBFunc,
70097065
ArgAccessorFuncCB);
70107066
};
@@ -7304,6 +7360,7 @@ static void
73047360
emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
73057361
OpenMPIRBuilder::InsertPointTy AllocaIP,
73067362
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7363+
const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
73077364
Function *OutlinedFn, Constant *OutlinedFnID,
73087365
SmallVectorImpl<Value *> &Args,
73097366
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
@@ -7385,11 +7442,43 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
73857442
/*ForEndCall=*/false);
73867443

73877444
SmallVector<Value *, 3> NumTeamsC;
7445+
for (auto [DefaultVal, RuntimeVal] :
7446+
zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams))
7447+
NumTeamsC.push_back(RuntimeVal ? RuntimeVal : Builder.getInt32(DefaultVal));
7448+
7449+
// Calculate number of threads: 0 if no clauses specified, otherwise it is the
7450+
// minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
7451+
auto InitMaxThreadsClause = [&Builder](Value *Clause) {
7452+
if (Clause)
7453+
Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
7454+
/*isSigned=*/false);
7455+
return Clause;
7456+
};
7457+
auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
7458+
if (Clause)
7459+
Result = Result
7460+
? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
7461+
Result, Clause)
7462+
: Clause;
7463+
};
7464+
7465+
// If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
7466+
// the NUM_THREADS clause is overridden by THREAD_LIMIT.
73887467
SmallVector<Value *, 3> NumThreadsC;
7389-
for (auto V : DefaultAttrs.MaxTeams)
7390-
NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
7391-
for (auto V : DefaultAttrs.MaxThreads)
7392-
NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
7468+
Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit.size() == 1
7469+
? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
7470+
: nullptr;
7471+
7472+
for (auto [TeamsVal, TargetVal] : llvm::zip_equal(
7473+
RuntimeAttrs.TeamsThreadLimit, RuntimeAttrs.TargetThreadLimit)) {
7474+
Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
7475+
Value *NumThreads = InitMaxThreadsClause(TargetVal);
7476+
7477+
CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
7478+
CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
7479+
7480+
NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
7481+
}
73937482

73947483
unsigned NumTargetItems = Info.NumberOfPtrs;
73957484
// TODO: Use correct device ID
@@ -7398,14 +7487,19 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
73987487
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
73997488
Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
74007489
llvm::omp::IdentFlag(0), 0);
7401-
// TODO: Use correct NumIterations
7402-
Value *NumIterations = Builder.getInt64(0);
7490+
7491+
Value *TripCount = RuntimeAttrs.LoopTripCount
7492+
? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount,
7493+
Builder.getInt64Ty(),
7494+
/*isSigned=*/false)
7495+
: Builder.getInt64(0);
7496+
74037497
// TODO: Use correct DynCGGroupMem
74047498
Value *DynCGGroupMem = Builder.getInt32(0);
74057499

7406-
KArgs = OpenMPIRBuilder::TargetKernelArgs(
7407-
NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
7408-
DynCGGroupMem, HasNoWait);
7500+
KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
7501+
NumTeamsC, NumThreadsC,
7502+
DynCGGroupMem, HasNoWait);
74097503

74107504
// The presence of certain clauses on the target directive require the
74117505
// explicit generation of the target task.
@@ -7427,13 +7521,17 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
74277521
}
74287522

74297523
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
7430-
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
7431-
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
7524+
const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
7525+
InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
7526+
TargetRegionEntryInfo &EntryInfo,
74327527
const TargetKernelDefaultAttrs &DefaultAttrs,
7528+
const TargetKernelRuntimeAttrs &RuntimeAttrs,
74337529
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
74347530
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
74357531
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
74367532
SmallVector<DependData> Dependencies, bool HasNowait) {
7533+
assert((!RuntimeAttrs.LoopTripCount || IsSPMD) &&
7534+
"trip count not expected if IsSPMD=false");
74377535

74387536
if (!updateToLocation(Loc))
74397537
return InsertPointTy();
@@ -7446,16 +7544,17 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
74467544
// the target region itself is generated using the callbacks CBFunc
74477545
// and ArgAccessorFuncCB
74487546
if (Error Err = emitTargetOutlinedFunction(
7449-
*this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn,
7450-
OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
7547+
*this, Builder, IsOffloadEntry, IsSPMD, EntryInfo, DefaultAttrs,
7548+
OutlinedFn, OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
74517549
return Err;
74527550

74537551
// If we are not on the target device, then we need to generate code
74547552
// to make a remote call (offload) to the previously outlined function
74557553
// that represents the target region. Do that now.
74567554
if (!Config.isTargetDevice())
7457-
emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, OutlinedFn,
7458-
OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
7555+
emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
7556+
OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
7557+
HasNowait);
74597558
return Builder.saveIP();
74607559
}
74617560

0 commit comments

Comments
 (0)