-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[OMPIRBuilder] Support runtime number of teams and threads, and SPMD mode #116051
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[OMPIRBuilder] Support runtime number of teams and threads, and SPMD mode #116051
Conversation
@llvm/pr-subscribers-clang-codegen @llvm/pr-subscribers-mlir-llvm Author: Sergio Afonso (skatrak) ChangesThis patch introduces a Additionally, Patch is 31.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116051.diff 4 Files Affected:
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index da450ef5adbc14..a85f41e586c514 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2237,6 +2237,26 @@ class OpenMPIRBuilder {
int32_t MinThreads = 1;
};
+ /// Container to pass LLVM IR runtime values or constants related to the
+ /// number of teams and threads with which the kernel must be launched, as
+ /// well as the trip count of the SPMD loop, if it is an SPMD kernel. These
+ /// must be defined in the host prior to the call to the kernel launch OpenMP
+ /// RTL function.
+ struct TargetKernelRuntimeAttrs {
+ SmallVector<Value *, 3> MaxTeams = {nullptr};
+ Value *MinTeams = nullptr;
+ SmallVector<Value *, 3> TargetThreadLimit = {nullptr};
+ SmallVector<Value *, 3> TeamsThreadLimit = {nullptr};
+
+ /// 'parallel' construct 'num_threads' clause value, if present and it is a
+ /// target SPMD kernel.
+ Value *MaxThreads = nullptr;
+
+ /// Total number of iterations of the target SPMD kernel or null if it is a
+ /// generic kernel.
+ Value *LoopTripCount = nullptr;
+ };
+
/// Data structure that contains the needed information to construct the
/// kernel args vector.
struct TargetKernelArgs {
@@ -2905,11 +2925,14 @@ class OpenMPIRBuilder {
///
/// \param Loc where the target data construct was encountered.
/// \param IsOffloadEntry whether it is an offload entry.
+ /// \param IsSPMD whether it is a target SPMD kernel.
/// \param CodeGenIP The insertion point where the call to the outlined
/// function should be emitted.
/// \param EntryInfo The entry information about the function.
/// \param DefaultAttrs Structure containing the default numbers of threads
/// and teams to launch the kernel with.
+ /// \param RuntimeAttrs Structure containing the runtime numbers of threads
+ /// and teams to launch the kernel with.
/// \param Inputs The input values to the region that will be passed.
/// as arguments to the outlined function.
/// \param BodyGenCB Callback that will generate the region code.
@@ -2919,11 +2942,12 @@ class OpenMPIRBuilder {
// dependency information as passed in the depend clause
// \param HasNowait Whether the target construct has a `nowait` clause or not.
InsertPointOrErrorTy createTarget(
- const LocationDescription &Loc, bool IsOffloadEntry,
+ const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP,
TargetRegionEntryInfo &EntryInfo,
const TargetKernelDefaultAttrs &DefaultAttrs,
+ const TargetKernelRuntimeAttrs &RuntimeAttrs,
SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
TargetBodyGenCallbackTy BodyGenCB,
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 302d363965c940..f847f60386df85 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -6727,8 +6727,43 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
return getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_dispatch_deinit);
}
+static void emitUsed(StringRef Name, std::vector<llvm::WeakTrackingVH> &List,
+ Module &M) {
+ if (List.empty())
+ return;
+
+ Type *PtrTy = PointerType::get(M.getContext(), /*AddressSpace=*/0);
+
+ // Convert List to what ConstantArray needs.
+ SmallVector<Constant *, 8> UsedArray;
+ UsedArray.reserve(List.size());
+ for (auto Item : List)
+ UsedArray.push_back(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
+ cast<Constant>(&*Item), PtrTy));
+
+ ArrayType *ArrTy = ArrayType::get(PtrTy, UsedArray.size());
+ auto *GV =
+ new GlobalVariable(M, ArrTy, false, llvm::GlobalValue::AppendingLinkage,
+ llvm::ConstantArray::get(ArrTy, UsedArray), Name);
+
+ GV->setSection("llvm.metadata");
+}
+
+static void
+emitExecutionMode(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+ StringRef FunctionName, OMPTgtExecModeFlags Mode,
+ std::vector<llvm::WeakTrackingVH> &LLVMCompilerUsed) {
+ auto *Int8Ty = Type::getInt8Ty(Builder.getContext());
+ auto *GVMode = new llvm::GlobalVariable(
+ OMPBuilder.M, Int8Ty, /*isConstant=*/true,
+ llvm::GlobalValue::WeakAnyLinkage, llvm::ConstantInt::get(Int8Ty, Mode),
+ Twine(FunctionName, "_exec_mode"));
+ GVMode->setVisibility(llvm::GlobalVariable::ProtectedVisibility);
+ LLVMCompilerUsed.emplace_back(GVMode);
+}
+
static Expected<Function *> createOutlinedFunction(
- OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+ OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsSPMD,
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
@@ -6758,6 +6793,27 @@ static Expected<Function *> createOutlinedFunction(
auto Func =
Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName, M);
+ // Forward target-cpu and target-features function attributes from the
+ // original function to the new outlined function.
+ Function *ParentFn = Builder.GetInsertBlock()->getParent();
+
+ auto TargetCpuAttr = ParentFn->getFnAttribute("target-cpu");
+ if (TargetCpuAttr.isStringAttribute())
+ Func->addFnAttr(TargetCpuAttr);
+
+ auto TargetFeaturesAttr = ParentFn->getFnAttribute("target-features");
+ if (TargetFeaturesAttr.isStringAttribute())
+ Func->addFnAttr(TargetFeaturesAttr);
+
+ if (OMPBuilder.Config.isTargetDevice()) {
+ std::vector<llvm::WeakTrackingVH> LLVMCompilerUsed;
+ emitExecutionMode(OMPBuilder, Builder, FuncName,
+ IsSPMD ? OMP_TGT_EXEC_MODE_SPMD
+ : OMP_TGT_EXEC_MODE_GENERIC,
+ LLVMCompilerUsed);
+ emitUsed("llvm.compiler.used", LLVMCompilerUsed, OMPBuilder.M);
+ }
+
// Save insert point.
IRBuilder<>::InsertPointGuard IPG(Builder);
// If there's a DISubprogram associated with current function, then
@@ -6798,7 +6854,7 @@ static Expected<Function *> createOutlinedFunction(
// Insert target init call in the device compilation pass.
if (OMPBuilder.Config.isTargetDevice())
Builder.restoreIP(
- OMPBuilder.createTargetInit(Builder, /*IsSPMD=*/false, DefaultAttrs));
+ OMPBuilder.createTargetInit(Builder, IsSPMD, DefaultAttrs));
BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
@@ -6995,7 +7051,7 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
static Error emitTargetOutlinedFunction(
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
- TargetRegionEntryInfo &EntryInfo,
+ bool IsSPMD, TargetRegionEntryInfo &EntryInfo,
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
Function *&OutlinedFn, Constant *&OutlinedFnID,
SmallVectorImpl<Value *> &Inputs,
@@ -7004,7 +7060,7 @@ static Error emitTargetOutlinedFunction(
OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
[&](StringRef EntryFnName) {
- return createOutlinedFunction(OMPBuilder, Builder, DefaultAttrs,
+ return createOutlinedFunction(OMPBuilder, Builder, IsSPMD, DefaultAttrs,
EntryFnName, Inputs, CBFunc,
ArgAccessorFuncCB);
};
@@ -7304,6 +7360,7 @@ static void
emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
OpenMPIRBuilder::InsertPointTy AllocaIP,
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+ const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
Function *OutlinedFn, Constant *OutlinedFnID,
SmallVectorImpl<Value *> &Args,
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
@@ -7385,11 +7442,43 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
/*ForEndCall=*/false);
SmallVector<Value *, 3> NumTeamsC;
+ for (auto [DefaultVal, RuntimeVal] :
+ zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams))
+ NumTeamsC.push_back(RuntimeVal ? RuntimeVal : Builder.getInt32(DefaultVal));
+
+ // Calculate number of threads: 0 if no clauses specified, otherwise it is the
+ // minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
+ auto InitMaxThreadsClause = [&Builder](Value *Clause) {
+ if (Clause)
+ Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
+ /*isSigned=*/false);
+ return Clause;
+ };
+ auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
+ if (Clause)
+ Result = Result
+ ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
+ Result, Clause)
+ : Clause;
+ };
+
+ // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
+ // the NUM_THREADS clause is overriden by THREAD_LIMIT.
SmallVector<Value *, 3> NumThreadsC;
- for (auto V : DefaultAttrs.MaxTeams)
- NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
- for (auto V : DefaultAttrs.MaxThreads)
- NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
+ Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit.size() == 1
+ ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
+ : nullptr;
+
+ for (auto [TeamsVal, TargetVal] : llvm::zip_equal(
+ RuntimeAttrs.TeamsThreadLimit, RuntimeAttrs.TargetThreadLimit)) {
+ Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
+ Value *NumThreads = InitMaxThreadsClause(TargetVal);
+
+ CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
+ CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
+
+ NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
+ }
unsigned NumTargetItems = Info.NumberOfPtrs;
// TODO: Use correct device ID
@@ -7398,14 +7487,19 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
llvm::omp::IdentFlag(0), 0);
- // TODO: Use correct NumIterations
- Value *NumIterations = Builder.getInt64(0);
+
+ Value *TripCount = RuntimeAttrs.LoopTripCount
+ ? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount,
+ Builder.getInt64Ty(),
+ /*isSigned=*/false)
+ : Builder.getInt64(0);
+
// TODO: Use correct DynCGGroupMem
Value *DynCGGroupMem = Builder.getInt32(0);
- KArgs = OpenMPIRBuilder::TargetKernelArgs(
- NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
- DynCGGroupMem, HasNoWait);
+ KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
+ NumTeamsC, NumThreadsC,
+ DynCGGroupMem, HasNoWait);
// The presence of certain clauses on the target directive require the
// explicit generation of the target task.
@@ -7427,13 +7521,17 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
}
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
- const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
- InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
+ const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
+ InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
+ TargetRegionEntryInfo &EntryInfo,
const TargetKernelDefaultAttrs &DefaultAttrs,
+ const TargetKernelRuntimeAttrs &RuntimeAttrs,
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
SmallVector<DependData> Dependencies, bool HasNowait) {
+ assert((!RuntimeAttrs.LoopTripCount || IsSPMD) &&
+ "trip count not expected if IsSPMD=false");
if (!updateToLocation(Loc))
return InsertPointTy();
@@ -7446,16 +7544,17 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
// the target region itself is generated using the callbacks CBFunc
// and ArgAccessorFuncCB
if (Error Err = emitTargetOutlinedFunction(
- *this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn,
- OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
+ *this, Builder, IsOffloadEntry, IsSPMD, EntryInfo, DefaultAttrs,
+ OutlinedFn, OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
return Err;
// If we are not on the target device, then we need to generate code
// to make a remote call (offload) to the previously outlined function
// that represents the target region. Do that now.
if (!Config.isTargetDevice())
- emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, OutlinedFn,
- OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
+ emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
+ OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
+ HasNowait);
return Builder.saveIP();
}
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index b0688d6215e42d..63be7e775b83c9 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6122,8 +6122,10 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
OpenMPIRBuilderConfig Config(false, false, false, false, false, false, false);
OMPBuilder.setConfig(Config);
F->setName("func");
+ F->addFnAttr("target-cpu", "x86-64");
+ F->addFnAttr("target-features", "+mmx,+sse");
IRBuilder<> Builder(BB);
- auto Int32Ty = Builder.getInt32Ty();
+ auto *Int32Ty = Builder.getInt32Ty();
AllocaInst *APtr = Builder.CreateAlloca(Int32Ty, nullptr, "a_ptr");
AllocaInst *BPtr = Builder.CreateAlloca(Int32Ty, nullptr, "b_ptr");
@@ -6183,11 +6185,15 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
- /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
- OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
- OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
- Builder.saveIP(), EntryInfo, DefaultAttrs, Inputs,
- GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+ /*MaxTeams=*/{10}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+ OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
+ RuntimeAttrs.TargetThreadLimit[0] = Builder.getInt32(20);
+ RuntimeAttrs.TeamsThreadLimit[0] = Builder.getInt32(30);
+ RuntimeAttrs.MaxThreads = Builder.getInt32(40);
+ OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
+ OmpLoc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, Builder.saveIP(),
+ Builder.saveIP(), EntryInfo, DefaultAttrs, RuntimeAttrs, Inputs,
+ GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
assert(AfterIP && "unexpected error");
Builder.restoreIP(*AfterIP);
OMPBuilder.finalize();
@@ -6207,6 +6213,43 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
StringRef FunctionName = KernelLaunchFunc->getName();
EXPECT_TRUE(FunctionName.starts_with("__tgt_target_kernel"));
+ // Check num_teams and num_threads in call arguments
+ EXPECT_TRUE(Call->arg_size() >= 4);
+ Value *NumTeamsArg = Call->getArgOperand(2);
+ EXPECT_TRUE(isa<ConstantInt>(NumTeamsArg));
+ EXPECT_EQ(10U, cast<ConstantInt>(NumTeamsArg)->getZExtValue());
+ Value *NumThreadsArg = Call->getArgOperand(3);
+ EXPECT_TRUE(isa<ConstantInt>(NumThreadsArg));
+ EXPECT_EQ(20U, cast<ConstantInt>(NumThreadsArg)->getZExtValue());
+
+ // Check num_teams and num_threads kernel arguments (use number 5 starting
+ // from the end and counting the call to __tgt_target_kernel as the first use)
+ Value *KernelArgs = Call->getArgOperand(Call->arg_size() - 1);
+ EXPECT_TRUE(KernelArgs->getNumUses() >= 4);
+ Value *NumTeamsGetElemPtr = *std::next(KernelArgs->user_begin(), 3);
+ EXPECT_TRUE(isa<GetElementPtrInst>(NumTeamsGetElemPtr));
+ Value *NumTeamsStore = NumTeamsGetElemPtr->getUniqueUndroppableUser();
+ EXPECT_TRUE(isa<StoreInst>(NumTeamsStore));
+ Value *NumTeamsStoreArg = cast<StoreInst>(NumTeamsStore)->getValueOperand();
+ EXPECT_TRUE(isa<ConstantDataSequential>(NumTeamsStoreArg));
+ auto *NumTeamsStoreValue = cast<ConstantDataSequential>(NumTeamsStoreArg);
+ EXPECT_EQ(3U, NumTeamsStoreValue->getNumElements());
+ EXPECT_EQ(10U, NumTeamsStoreValue->getElementAsInteger(0));
+ EXPECT_EQ(0U, NumTeamsStoreValue->getElementAsInteger(1));
+ EXPECT_EQ(0U, NumTeamsStoreValue->getElementAsInteger(2));
+ Value *NumThreadsGetElemPtr = *std::next(KernelArgs->user_begin(), 2);
+ EXPECT_TRUE(isa<GetElementPtrInst>(NumThreadsGetElemPtr));
+ Value *NumThreadsStore = NumThreadsGetElemPtr->getUniqueUndroppableUser();
+ EXPECT_TRUE(isa<StoreInst>(NumThreadsStore));
+ Value *NumThreadsStoreArg =
+ cast<StoreInst>(NumThreadsStore)->getValueOperand();
+ EXPECT_TRUE(isa<ConstantDataSequential>(NumThreadsStoreArg));
+ auto *NumThreadsStoreValue = cast<ConstantDataSequential>(NumThreadsStoreArg);
+ EXPECT_EQ(3U, NumThreadsStoreValue->getNumElements());
+ EXPECT_EQ(20U, NumThreadsStoreValue->getElementAsInteger(0));
+ EXPECT_EQ(0U, NumThreadsStoreValue->getElementAsInteger(1));
+ EXPECT_EQ(0U, NumThreadsStoreValue->getElementAsInteger(2));
+
// Check the fallback call
BasicBlock *FallbackBlock = Branch->getSuccessor(0);
Iter = FallbackBlock->rbegin();
@@ -6228,6 +6271,13 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
StringRef FunctionName2 = OutlinedFunc->getName();
EXPECT_TRUE(FunctionName2.starts_with("__omp_offloading"));
+ // Check that target-cpu and target-features were propagated to the outlined
+ // function
+ EXPECT_EQ(OutlinedFunc->getFnAttribute("target-cpu"),
+ F->getFnAttribute("target-cpu"));
+ EXPECT_EQ(OutlinedFunc->getFnAttribute("target-features"),
+ F->getFnAttribute("target-features"));
+
EXPECT_FALSE(verifyModule(*M, &errs()));
}
@@ -6238,6 +6288,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
OMPBuilder.initialize();
F->setName("func");
+ F->addFnAttr("target-cpu", "gfx90a");
+ F->addFnAttr("target-features", "+gfx9-insts,+wavefrontsize64");
IRBuilder<> Builder(BB);
OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
@@ -6297,9 +6349,11 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
/*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+ OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
- Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs,
- CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+ Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, EntryIP, EntryIP,
+ EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB,
+ BodyGenCB, SimpleArgAccessorCB);
assert(AfterIP && "unexpected error");
Builder.restoreIP(*AfterIP);
@@ -6312,6 +6366,13 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
Function *OutlinedFn = TargetStore->getFunction();
EXPECT_NE(F, OutlinedFn);
+ // Check that target-cpu and target-features were propagated to the outlined
+ // function
+ EXPECT_EQ(OutlinedFn->getFnAttribute("target-cpu"),
+ ...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Sergio Afonso (skatrak) ChangesThis patch introduces a Additionally, Patch is 31.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116051.diff 4 Files Affected:
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index da450ef5adbc14..a85f41e586c514 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2237,6 +2237,26 @@ class OpenMPIRBuilder {
int32_t MinThreads = 1;
};
+ /// Container to pass LLVM IR runtime values or constants related to the
+ /// number of teams and threads with which the kernel must be launched, as
+ /// well as the trip count of the SPMD loop, if it is an SPMD kernel. These
+ /// must be defined in the host prior to the call to the kernel launch OpenMP
+ /// RTL function.
+ struct TargetKernelRuntimeAttrs {
+ SmallVector<Value *, 3> MaxTeams = {nullptr};
+ Value *MinTeams = nullptr;
+ SmallVector<Value *, 3> TargetThreadLimit = {nullptr};
+ SmallVector<Value *, 3> TeamsThreadLimit = {nullptr};
+
+ /// 'parallel' construct 'num_threads' clause value, if present and it is a
+ /// target SPMD kernel.
+ Value *MaxThreads = nullptr;
+
+ /// Total number of iterations of the target SPMD kernel or null if it is a
+ /// generic kernel.
+ Value *LoopTripCount = nullptr;
+ };
+
/// Data structure that contains the needed information to construct the
/// kernel args vector.
struct TargetKernelArgs {
@@ -2905,11 +2925,14 @@ class OpenMPIRBuilder {
///
/// \param Loc where the target data construct was encountered.
/// \param IsOffloadEntry whether it is an offload entry.
+ /// \param IsSPMD whether it is a target SPMD kernel.
/// \param CodeGenIP The insertion point where the call to the outlined
/// function should be emitted.
/// \param EntryInfo The entry information about the function.
/// \param DefaultAttrs Structure containing the default numbers of threads
/// and teams to launch the kernel with.
+ /// \param RuntimeAttrs Structure containing the runtime numbers of threads
+ /// and teams to launch the kernel with.
/// \param Inputs The input values to the region that will be passed.
/// as arguments to the outlined function.
/// \param BodyGenCB Callback that will generate the region code.
@@ -2919,11 +2942,12 @@ class OpenMPIRBuilder {
// dependency information as passed in the depend clause
// \param HasNowait Whether the target construct has a `nowait` clause or not.
InsertPointOrErrorTy createTarget(
- const LocationDescription &Loc, bool IsOffloadEntry,
+ const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP,
TargetRegionEntryInfo &EntryInfo,
const TargetKernelDefaultAttrs &DefaultAttrs,
+ const TargetKernelRuntimeAttrs &RuntimeAttrs,
SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
TargetBodyGenCallbackTy BodyGenCB,
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 302d363965c940..f847f60386df85 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -6727,8 +6727,43 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
return getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_dispatch_deinit);
}
+static void emitUsed(StringRef Name, std::vector<llvm::WeakTrackingVH> &List,
+ Module &M) {
+ if (List.empty())
+ return;
+
+ Type *PtrTy = PointerType::get(M.getContext(), /*AddressSpace=*/0);
+
+ // Convert List to what ConstantArray needs.
+ SmallVector<Constant *, 8> UsedArray;
+ UsedArray.reserve(List.size());
+ for (auto Item : List)
+ UsedArray.push_back(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
+ cast<Constant>(&*Item), PtrTy));
+
+ ArrayType *ArrTy = ArrayType::get(PtrTy, UsedArray.size());
+ auto *GV =
+ new GlobalVariable(M, ArrTy, false, llvm::GlobalValue::AppendingLinkage,
+ llvm::ConstantArray::get(ArrTy, UsedArray), Name);
+
+ GV->setSection("llvm.metadata");
+}
+
+static void
+emitExecutionMode(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+ StringRef FunctionName, OMPTgtExecModeFlags Mode,
+ std::vector<llvm::WeakTrackingVH> &LLVMCompilerUsed) {
+ auto *Int8Ty = Type::getInt8Ty(Builder.getContext());
+ auto *GVMode = new llvm::GlobalVariable(
+ OMPBuilder.M, Int8Ty, /*isConstant=*/true,
+ llvm::GlobalValue::WeakAnyLinkage, llvm::ConstantInt::get(Int8Ty, Mode),
+ Twine(FunctionName, "_exec_mode"));
+ GVMode->setVisibility(llvm::GlobalVariable::ProtectedVisibility);
+ LLVMCompilerUsed.emplace_back(GVMode);
+}
+
static Expected<Function *> createOutlinedFunction(
- OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+ OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsSPMD,
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
@@ -6758,6 +6793,27 @@ static Expected<Function *> createOutlinedFunction(
auto Func =
Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName, M);
+ // Forward target-cpu and target-features function attributes from the
+ // original function to the new outlined function.
+ Function *ParentFn = Builder.GetInsertBlock()->getParent();
+
+ auto TargetCpuAttr = ParentFn->getFnAttribute("target-cpu");
+ if (TargetCpuAttr.isStringAttribute())
+ Func->addFnAttr(TargetCpuAttr);
+
+ auto TargetFeaturesAttr = ParentFn->getFnAttribute("target-features");
+ if (TargetFeaturesAttr.isStringAttribute())
+ Func->addFnAttr(TargetFeaturesAttr);
+
+ if (OMPBuilder.Config.isTargetDevice()) {
+ std::vector<llvm::WeakTrackingVH> LLVMCompilerUsed;
+ emitExecutionMode(OMPBuilder, Builder, FuncName,
+ IsSPMD ? OMP_TGT_EXEC_MODE_SPMD
+ : OMP_TGT_EXEC_MODE_GENERIC,
+ LLVMCompilerUsed);
+ emitUsed("llvm.compiler.used", LLVMCompilerUsed, OMPBuilder.M);
+ }
+
// Save insert point.
IRBuilder<>::InsertPointGuard IPG(Builder);
// If there's a DISubprogram associated with current function, then
@@ -6798,7 +6854,7 @@ static Expected<Function *> createOutlinedFunction(
// Insert target init call in the device compilation pass.
if (OMPBuilder.Config.isTargetDevice())
Builder.restoreIP(
- OMPBuilder.createTargetInit(Builder, /*IsSPMD=*/false, DefaultAttrs));
+ OMPBuilder.createTargetInit(Builder, IsSPMD, DefaultAttrs));
BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
@@ -6995,7 +7051,7 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
static Error emitTargetOutlinedFunction(
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
- TargetRegionEntryInfo &EntryInfo,
+ bool IsSPMD, TargetRegionEntryInfo &EntryInfo,
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
Function *&OutlinedFn, Constant *&OutlinedFnID,
SmallVectorImpl<Value *> &Inputs,
@@ -7004,7 +7060,7 @@ static Error emitTargetOutlinedFunction(
OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
[&](StringRef EntryFnName) {
- return createOutlinedFunction(OMPBuilder, Builder, DefaultAttrs,
+ return createOutlinedFunction(OMPBuilder, Builder, IsSPMD, DefaultAttrs,
EntryFnName, Inputs, CBFunc,
ArgAccessorFuncCB);
};
@@ -7304,6 +7360,7 @@ static void
emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
OpenMPIRBuilder::InsertPointTy AllocaIP,
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+ const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
Function *OutlinedFn, Constant *OutlinedFnID,
SmallVectorImpl<Value *> &Args,
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
@@ -7385,11 +7442,43 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
/*ForEndCall=*/false);
SmallVector<Value *, 3> NumTeamsC;
+ for (auto [DefaultVal, RuntimeVal] :
+ zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams))
+ NumTeamsC.push_back(RuntimeVal ? RuntimeVal : Builder.getInt32(DefaultVal));
+
+ // Calculate number of threads: 0 if no clauses specified, otherwise it is the
+ // minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
+ auto InitMaxThreadsClause = [&Builder](Value *Clause) {
+ if (Clause)
+ Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
+ /*isSigned=*/false);
+ return Clause;
+ };
+ auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
+ if (Clause)
+ Result = Result
+ ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
+ Result, Clause)
+ : Clause;
+ };
+
+ // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
+ // the NUM_THREADS clause is overriden by THREAD_LIMIT.
SmallVector<Value *, 3> NumThreadsC;
- for (auto V : DefaultAttrs.MaxTeams)
- NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
- for (auto V : DefaultAttrs.MaxThreads)
- NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
+ Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit.size() == 1
+ ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
+ : nullptr;
+
+ for (auto [TeamsVal, TargetVal] : llvm::zip_equal(
+ RuntimeAttrs.TeamsThreadLimit, RuntimeAttrs.TargetThreadLimit)) {
+ Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
+ Value *NumThreads = InitMaxThreadsClause(TargetVal);
+
+ CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
+ CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
+
+ NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
+ }
unsigned NumTargetItems = Info.NumberOfPtrs;
// TODO: Use correct device ID
@@ -7398,14 +7487,19 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
llvm::omp::IdentFlag(0), 0);
- // TODO: Use correct NumIterations
- Value *NumIterations = Builder.getInt64(0);
+
+ Value *TripCount = RuntimeAttrs.LoopTripCount
+ ? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount,
+ Builder.getInt64Ty(),
+ /*isSigned=*/false)
+ : Builder.getInt64(0);
+
// TODO: Use correct DynCGGroupMem
Value *DynCGGroupMem = Builder.getInt32(0);
- KArgs = OpenMPIRBuilder::TargetKernelArgs(
- NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
- DynCGGroupMem, HasNoWait);
+ KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
+ NumTeamsC, NumThreadsC,
+ DynCGGroupMem, HasNoWait);
// The presence of certain clauses on the target directive require the
// explicit generation of the target task.
@@ -7427,13 +7521,17 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
}
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
- const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
- InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
+ const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
+ InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
+ TargetRegionEntryInfo &EntryInfo,
const TargetKernelDefaultAttrs &DefaultAttrs,
+ const TargetKernelRuntimeAttrs &RuntimeAttrs,
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
SmallVector<DependData> Dependencies, bool HasNowait) {
+ assert((!RuntimeAttrs.LoopTripCount || IsSPMD) &&
+ "trip count not expected if IsSPMD=false");
if (!updateToLocation(Loc))
return InsertPointTy();
@@ -7446,16 +7544,17 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
// the target region itself is generated using the callbacks CBFunc
// and ArgAccessorFuncCB
if (Error Err = emitTargetOutlinedFunction(
- *this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn,
- OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
+ *this, Builder, IsOffloadEntry, IsSPMD, EntryInfo, DefaultAttrs,
+ OutlinedFn, OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
return Err;
// If we are not on the target device, then we need to generate code
// to make a remote call (offload) to the previously outlined function
// that represents the target region. Do that now.
if (!Config.isTargetDevice())
- emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, OutlinedFn,
- OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
+ emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
+ OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
+ HasNowait);
return Builder.saveIP();
}
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index b0688d6215e42d..63be7e775b83c9 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6122,8 +6122,10 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
OpenMPIRBuilderConfig Config(false, false, false, false, false, false, false);
OMPBuilder.setConfig(Config);
F->setName("func");
+ F->addFnAttr("target-cpu", "x86-64");
+ F->addFnAttr("target-features", "+mmx,+sse");
IRBuilder<> Builder(BB);
- auto Int32Ty = Builder.getInt32Ty();
+ auto *Int32Ty = Builder.getInt32Ty();
AllocaInst *APtr = Builder.CreateAlloca(Int32Ty, nullptr, "a_ptr");
AllocaInst *BPtr = Builder.CreateAlloca(Int32Ty, nullptr, "b_ptr");
@@ -6183,11 +6185,15 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
- /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
- OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
- OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
- Builder.saveIP(), EntryInfo, DefaultAttrs, Inputs,
- GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+ /*MaxTeams=*/{10}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+ OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
+ RuntimeAttrs.TargetThreadLimit[0] = Builder.getInt32(20);
+ RuntimeAttrs.TeamsThreadLimit[0] = Builder.getInt32(30);
+ RuntimeAttrs.MaxThreads = Builder.getInt32(40);
+ OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
+ OmpLoc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, Builder.saveIP(),
+ Builder.saveIP(), EntryInfo, DefaultAttrs, RuntimeAttrs, Inputs,
+ GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
assert(AfterIP && "unexpected error");
Builder.restoreIP(*AfterIP);
OMPBuilder.finalize();
@@ -6207,6 +6213,43 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
StringRef FunctionName = KernelLaunchFunc->getName();
EXPECT_TRUE(FunctionName.starts_with("__tgt_target_kernel"));
+ // Check num_teams and num_threads in call arguments
+ EXPECT_TRUE(Call->arg_size() >= 4);
+ Value *NumTeamsArg = Call->getArgOperand(2);
+ EXPECT_TRUE(isa<ConstantInt>(NumTeamsArg));
+ EXPECT_EQ(10U, cast<ConstantInt>(NumTeamsArg)->getZExtValue());
+ Value *NumThreadsArg = Call->getArgOperand(3);
+ EXPECT_TRUE(isa<ConstantInt>(NumThreadsArg));
+ EXPECT_EQ(20U, cast<ConstantInt>(NumThreadsArg)->getZExtValue());
+
+ // Check num_teams and num_threads kernel arguments (use number 5 starting
+ // from the end and counting the call to __tgt_target_kernel as the first use)
+ Value *KernelArgs = Call->getArgOperand(Call->arg_size() - 1);
+ EXPECT_TRUE(KernelArgs->getNumUses() >= 4);
+ Value *NumTeamsGetElemPtr = *std::next(KernelArgs->user_begin(), 3);
+ EXPECT_TRUE(isa<GetElementPtrInst>(NumTeamsGetElemPtr));
+ Value *NumTeamsStore = NumTeamsGetElemPtr->getUniqueUndroppableUser();
+ EXPECT_TRUE(isa<StoreInst>(NumTeamsStore));
+ Value *NumTeamsStoreArg = cast<StoreInst>(NumTeamsStore)->getValueOperand();
+ EXPECT_TRUE(isa<ConstantDataSequential>(NumTeamsStoreArg));
+ auto *NumTeamsStoreValue = cast<ConstantDataSequential>(NumTeamsStoreArg);
+ EXPECT_EQ(3U, NumTeamsStoreValue->getNumElements());
+ EXPECT_EQ(10U, NumTeamsStoreValue->getElementAsInteger(0));
+ EXPECT_EQ(0U, NumTeamsStoreValue->getElementAsInteger(1));
+ EXPECT_EQ(0U, NumTeamsStoreValue->getElementAsInteger(2));
+ Value *NumThreadsGetElemPtr = *std::next(KernelArgs->user_begin(), 2);
+ EXPECT_TRUE(isa<GetElementPtrInst>(NumThreadsGetElemPtr));
+ Value *NumThreadsStore = NumThreadsGetElemPtr->getUniqueUndroppableUser();
+ EXPECT_TRUE(isa<StoreInst>(NumThreadsStore));
+ Value *NumThreadsStoreArg =
+ cast<StoreInst>(NumThreadsStore)->getValueOperand();
+ EXPECT_TRUE(isa<ConstantDataSequential>(NumThreadsStoreArg));
+ auto *NumThreadsStoreValue = cast<ConstantDataSequential>(NumThreadsStoreArg);
+ EXPECT_EQ(3U, NumThreadsStoreValue->getNumElements());
+ EXPECT_EQ(20U, NumThreadsStoreValue->getElementAsInteger(0));
+ EXPECT_EQ(0U, NumThreadsStoreValue->getElementAsInteger(1));
+ EXPECT_EQ(0U, NumThreadsStoreValue->getElementAsInteger(2));
+
// Check the fallback call
BasicBlock *FallbackBlock = Branch->getSuccessor(0);
Iter = FallbackBlock->rbegin();
@@ -6228,6 +6271,13 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
StringRef FunctionName2 = OutlinedFunc->getName();
EXPECT_TRUE(FunctionName2.starts_with("__omp_offloading"));
+ // Check that target-cpu and target-features were propagated to the outlined
+ // function
+ EXPECT_EQ(OutlinedFunc->getFnAttribute("target-cpu"),
+ F->getFnAttribute("target-cpu"));
+ EXPECT_EQ(OutlinedFunc->getFnAttribute("target-features"),
+ F->getFnAttribute("target-features"));
+
EXPECT_FALSE(verifyModule(*M, &errs()));
}
@@ -6238,6 +6288,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
OMPBuilder.initialize();
F->setName("func");
+ F->addFnAttr("target-cpu", "gfx90a");
+ F->addFnAttr("target-features", "+gfx9-insts,+wavefrontsize64");
IRBuilder<> Builder(BB);
OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
@@ -6297,9 +6349,11 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
/*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+ OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
- Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs,
- CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+ Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, EntryIP, EntryIP,
+ EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB,
+ BodyGenCB, SimpleArgAccessorCB);
assert(AfterIP && "unexpected error");
Builder.restoreIP(*AfterIP);
@@ -6312,6 +6366,13 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
Function *OutlinedFn = TargetStore->getFunction();
EXPECT_NE(F, OutlinedFn);
+ // Check that target-cpu and target-features were propagated to the outlined
+ // function
+ EXPECT_EQ(OutlinedFn->getFnAttribute("target-cpu"),
+ ...
[truncated]
|
Buildbot failure seems to be some temporary issue unrelated to the PR. |
cc5c5cc
to
e2b3ac4
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you @jdoerfert for the review. Your comments should be addressed now.
1fcfe48
to
e3cdc93
Compare
e2b3ac4
to
b1e4eb5
Compare
@jdoerfert, can you check whether your concerns have been addressed? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
e3cdc93
to
45c6667
Compare
b1e4eb5
to
76b2b9f
Compare
✅ With the latest revision this PR passed the C/C++ code formatter. |
76b2b9f
to
47a6495
Compare
The PR stack should be almost ready to be merged in the next few days (buildbot failures are unrelated). If there are any remaining blockers from your side, let me know @jdoerfert. |
45c6667
to
219d430
Compare
47a6495
to
1f5cd91
Compare
…mode This patch introduces a `TargetKernelRuntimeAttrs` structure to hold host-evaluated `num_teams`, `thread_limit`, `num_threads` and trip count values passed to the runtime kernel offloading call. Additionally, kernel type information is used to influence target device code generation and the `IsSPMD` flag is replaced by `ExecFlags`, which provide more granularity.
1f5cd91
to
0c19f71
Compare
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/16/builds/11964 Here is the relevant piece of the build log for the reference
|
This patch introduces a
TargetKernelRuntimeAttrs
structure to hold host-evaluatednum_teams
,thread_limit
,num_threads
and trip count values passed to the runtime kernel offloading call.Additionally, kernel type information is used to influence target device code generation and the
IsSPMD
flag is replaced byExecFlags
, which provide more granularity.