-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[OMPIRBuilder] Introduce struct to hold default kernel teams/threads #116050
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-clang-codegen @llvm/pr-subscribers-clang Author: Sergio Afonso (skatrak) ChangesThis patch introduces the This is used to forward values passed to Patch is 21.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116050.diff 8 Files Affected:
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index d714af035d21a2..0f7a1166227476 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -5880,10 +5880,13 @@ void CGOpenMPRuntime::emitUsesAllocatorsFini(CodeGenFunction &CGF,
void CGOpenMPRuntime::computeMinAndMaxThreadsAndTeams(
const OMPExecutableDirective &D, CodeGenFunction &CGF,
- int32_t &MinThreadsVal, int32_t &MaxThreadsVal, int32_t &MinTeamsVal,
- int32_t &MaxTeamsVal) {
+ llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) {
+ assert(Attrs.MaxTeams.size() == 1 && Attrs.MaxThreads.size() == 1 &&
+ "invalid default attrs structure");
+ int32_t &MaxTeamsVal = Attrs.MaxTeams.front();
+ int32_t &MaxThreadsVal = Attrs.MaxThreads.front();
- getNumTeamsExprForTargetDirective(CGF, D, MinTeamsVal, MaxTeamsVal);
+ getNumTeamsExprForTargetDirective(CGF, D, Attrs.MinTeams, MaxTeamsVal);
getNumThreadsExprForTargetDirective(CGF, D, MaxThreadsVal,
/*UpperBoundOnly=*/true);
@@ -5901,12 +5904,12 @@ void CGOpenMPRuntime::computeMinAndMaxThreadsAndTeams(
else
continue;
- MinThreadsVal = std::max(MinThreadsVal, AttrMinThreadsVal);
+ Attrs.MinThreads = std::max(Attrs.MinThreads, AttrMinThreadsVal);
if (AttrMaxThreadsVal > 0)
MaxThreadsVal = MaxThreadsVal > 0
? std::min(MaxThreadsVal, AttrMaxThreadsVal)
: AttrMaxThreadsVal;
- MinTeamsVal = std::max(MinTeamsVal, AttrMinBlocksVal);
+ Attrs.MinTeams = std::max(Attrs.MinTeams, AttrMinBlocksVal);
if (AttrMaxBlocksVal > 0)
MaxTeamsVal = MaxTeamsVal > 0 ? std::min(MaxTeamsVal, AttrMaxBlocksVal)
: AttrMaxBlocksVal;
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.h b/clang/lib/CodeGen/CGOpenMPRuntime.h
index 5e7715743afb58..003395e7f17ded 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.h
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.h
@@ -312,12 +312,9 @@ class CGOpenMPRuntime {
llvm::OpenMPIRBuilder OMPBuilder;
/// Helper to determine the min/max number of threads/teams for \p D.
- void computeMinAndMaxThreadsAndTeams(const OMPExecutableDirective &D,
- CodeGenFunction &CGF,
- int32_t &MinThreadsVal,
- int32_t &MaxThreadsVal,
- int32_t &MinTeamsVal,
- int32_t &MaxTeamsVal);
+ void computeMinAndMaxThreadsAndTeams(
+ const OMPExecutableDirective &D, CodeGenFunction &CGF,
+ llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs);
/// Helper to emit outlined function for 'target' directive.
/// \param D Directive to emit.
diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
index 43dc0e62284602..96f8d6c5c08e56 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
@@ -745,14 +745,11 @@ void CGOpenMPRuntimeGPU::emitNonSPMDKernel(const OMPExecutableDirective &D,
void CGOpenMPRuntimeGPU::emitKernelInit(const OMPExecutableDirective &D,
CodeGenFunction &CGF,
EntryFunctionState &EST, bool IsSPMD) {
- int32_t MinThreadsVal = 1, MaxThreadsVal = -1, MinTeamsVal = 1,
- MaxTeamsVal = -1;
- computeMinAndMaxThreadsAndTeams(D, CGF, MinThreadsVal, MaxThreadsVal,
- MinTeamsVal, MaxTeamsVal);
+ llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs Attrs;
+ computeMinAndMaxThreadsAndTeams(D, CGF, Attrs);
CGBuilderTy &Bld = CGF.Builder;
- Bld.restoreIP(OMPBuilder.createTargetInit(
- Bld, IsSPMD, MinThreadsVal, MaxThreadsVal, MinTeamsVal, MaxTeamsVal));
+ Bld.restoreIP(OMPBuilder.createTargetInit(Bld, IsSPMD, Attrs));
if (!IsSPMD)
emitGenericVarsProlog(CGF, EST.Loc);
}
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 3afb9d84278e81..da450ef5adbc14 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2223,6 +2223,20 @@ class OpenMPIRBuilder {
MapNamesArray(MapNamesArray) {}
};
+ /// Container to pass the default attributes with which a kernel must be
+ /// launched, used to set kernel attributes and populate associated static
+ /// structures.
+ ///
+ /// For max values, < 0 means unset, == 0 means set but unknown at compile
+ /// time. The number of max values will be 1 except for the case where
+ /// ompx_bare is set.
+ struct TargetKernelDefaultAttrs {
+ SmallVector<int32_t, 3> MaxTeams = {-1};
+ int32_t MinTeams = 1;
+ SmallVector<int32_t, 3> MaxThreads = {-1};
+ int32_t MinThreads = 1;
+ };
+
/// Data structure that contains the needed information to construct the
/// kernel args vector.
struct TargetKernelArgs {
@@ -2726,15 +2740,11 @@ class OpenMPIRBuilder {
///
/// \param Loc The insert and source location description.
/// \param IsSPMD Flag to indicate if the kernel is an SPMD kernel or not.
- /// \param MinThreads Minimal number of threads, or 0.
- /// \param MaxThreads Maximal number of threads, or 0.
- /// \param MinTeams Minimal number of teams, or 0.
- /// \param MaxTeams Maximal number of teams, or 0.
- InsertPointTy createTargetInit(const LocationDescription &Loc, bool IsSPMD,
- int32_t MinThreadsVal = 0,
- int32_t MaxThreadsVal = 0,
- int32_t MinTeamsVal = 0,
- int32_t MaxTeamsVal = 0);
+ /// \param Attrs Structure containing the default numbers of threads and teams
+ /// to launch the kernel with.
+ InsertPointTy createTargetInit(
+ const LocationDescription &Loc, bool IsSPMD,
+ const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs);
/// Create a runtime call for kmpc_target_deinit
///
@@ -2898,8 +2908,8 @@ class OpenMPIRBuilder {
/// \param CodeGenIP The insertion point where the call to the outlined
/// function should be emitted.
/// \param EntryInfo The entry information about the function.
- /// \param NumTeams Number of teams specified in the num_teams clause.
- /// \param NumThreads Number of teams specified in the thread_limit clause.
+ /// \param DefaultAttrs Structure containing the default numbers of threads
+ /// and teams to launch the kernel with.
/// \param Inputs The input values to the region that will be passed.
/// as arguments to the outlined function.
/// \param BodyGenCB Callback that will generate the region code.
@@ -2912,9 +2922,10 @@ class OpenMPIRBuilder {
const LocationDescription &Loc, bool IsOffloadEntry,
OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP,
- TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams,
- ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs,
- GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy BodyGenCB,
+ TargetRegionEntryInfo &EntryInfo,
+ const TargetKernelDefaultAttrs &DefaultAttrs,
+ SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
+ TargetBodyGenCallbackTy BodyGenCB,
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
SmallVector<DependData> Dependencies = {}, bool HasNowait = false);
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index d2e4dc1c85dfd2..302d363965c940 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -6109,10 +6109,12 @@ CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
return Builder.CreateCall(Fn, Args);
}
-OpenMPIRBuilder::InsertPointTy
-OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
- int32_t MinThreadsVal, int32_t MaxThreadsVal,
- int32_t MinTeamsVal, int32_t MaxTeamsVal) {
+OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
+ const LocationDescription &Loc, bool IsSPMD,
+ const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) {
+ assert(!Attrs.MaxThreads.empty() && !Attrs.MaxTeams.empty() &&
+ "expected num_threads and num_teams to be specified");
+
if (!updateToLocation(Loc))
return Loc.IP;
@@ -6139,21 +6141,23 @@ OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
// Manifest the launch configuration in the metadata matching the kernel
// environment.
- if (MinTeamsVal > 1 || MaxTeamsVal > 0)
- writeTeamsForKernel(T, *Kernel, MinTeamsVal, MaxTeamsVal);
+ if (Attrs.MinTeams > 1 || Attrs.MaxTeams.front() > 0)
+ writeTeamsForKernel(T, *Kernel, Attrs.MinTeams, Attrs.MaxTeams.front());
- // For max values, < 0 means unset, == 0 means set but unknown.
+ // If MaxThreads not set, select the maximum between the default workgroup
+ // size and the MinThreads value.
+ int32_t MaxThreadsVal = Attrs.MaxThreads.front();
if (MaxThreadsVal < 0)
MaxThreadsVal = std::max(
- int32_t(getGridValue(T, Kernel).GV_Default_WG_Size), MinThreadsVal);
+ int32_t(getGridValue(T, Kernel).GV_Default_WG_Size), Attrs.MinThreads);
if (MaxThreadsVal > 0)
- writeThreadBoundsForKernel(T, *Kernel, MinThreadsVal, MaxThreadsVal);
+ writeThreadBoundsForKernel(T, *Kernel, Attrs.MinThreads, MaxThreadsVal);
- Constant *MinThreads = ConstantInt::getSigned(Int32, MinThreadsVal);
+ Constant *MinThreads = ConstantInt::getSigned(Int32, Attrs.MinThreads);
Constant *MaxThreads = ConstantInt::getSigned(Int32, MaxThreadsVal);
- Constant *MinTeams = ConstantInt::getSigned(Int32, MinTeamsVal);
- Constant *MaxTeams = ConstantInt::getSigned(Int32, MaxTeamsVal);
+ Constant *MinTeams = ConstantInt::getSigned(Int32, Attrs.MinTeams);
+ Constant *MaxTeams = ConstantInt::getSigned(Int32, Attrs.MaxTeams.front());
Constant *ReductionDataSize = ConstantInt::getSigned(Int32, 0);
Constant *ReductionBufferLength = ConstantInt::getSigned(Int32, 0);
@@ -6724,8 +6728,9 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
}
static Expected<Function *> createOutlinedFunction(
- OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, StringRef FuncName,
- SmallVectorImpl<Value *> &Inputs,
+ OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+ const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+ StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
SmallVector<Type *> ParameterTypes;
@@ -6792,7 +6797,8 @@ static Expected<Function *> createOutlinedFunction(
// Insert target init call in the device compilation pass.
if (OMPBuilder.Config.isTargetDevice())
- Builder.restoreIP(OMPBuilder.createTargetInit(Builder, /*IsSPMD*/ false));
+ Builder.restoreIP(
+ OMPBuilder.createTargetInit(Builder, /*IsSPMD=*/false, DefaultAttrs));
BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
@@ -6989,16 +6995,18 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
static Error emitTargetOutlinedFunction(
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
- TargetRegionEntryInfo &EntryInfo, Function *&OutlinedFn,
- Constant *&OutlinedFnID, SmallVectorImpl<Value *> &Inputs,
+ TargetRegionEntryInfo &EntryInfo,
+ const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+ Function *&OutlinedFn, Constant *&OutlinedFnID,
+ SmallVectorImpl<Value *> &Inputs,
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
- [&OMPBuilder, &Builder, &Inputs, &CBFunc,
- &ArgAccessorFuncCB](StringRef EntryFnName) {
- return createOutlinedFunction(OMPBuilder, Builder, EntryFnName, Inputs,
- CBFunc, ArgAccessorFuncCB);
+ [&](StringRef EntryFnName) {
+ return createOutlinedFunction(OMPBuilder, Builder, DefaultAttrs,
+ EntryFnName, Inputs, CBFunc,
+ ArgAccessorFuncCB);
};
return OMPBuilder.emitTargetRegionFunction(
@@ -7294,9 +7302,10 @@ void OpenMPIRBuilder::emitOffloadingArraysAndArgs(
static void
emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
- OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
- Constant *OutlinedFnID, ArrayRef<int32_t> NumTeams,
- ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Args,
+ OpenMPIRBuilder::InsertPointTy AllocaIP,
+ const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+ Function *OutlinedFn, Constant *OutlinedFnID,
+ SmallVectorImpl<Value *> &Args,
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
bool HasNoWait = false) {
@@ -7377,9 +7386,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
SmallVector<Value *, 3> NumTeamsC;
SmallVector<Value *, 3> NumThreadsC;
- for (auto V : NumTeams)
+ for (auto V : DefaultAttrs.MaxTeams)
NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
- for (auto V : NumThreads)
+ for (auto V : DefaultAttrs.MaxThreads)
NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
unsigned NumTargetItems = Info.NumberOfPtrs;
@@ -7420,7 +7429,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
- ArrayRef<int32_t> NumTeams, ArrayRef<int32_t> NumThreads,
+ const TargetKernelDefaultAttrs &DefaultAttrs,
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
@@ -7437,16 +7446,16 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
// the target region itself is generated using the callbacks CBFunc
// and ArgAccessorFuncCB
if (Error Err = emitTargetOutlinedFunction(
- *this, Builder, IsOffloadEntry, EntryInfo, OutlinedFn, OutlinedFnID,
- Args, CBFunc, ArgAccessorFuncCB))
+ *this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn,
+ OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
return Err;
// If we are not on the target device, then we need to generate code
// to make a remote call (offload) to the previously outlined function
// that represents the target region. Do that now.
if (!Config.isTargetDevice())
- emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
- NumThreads, Args, GenMapInfoCB, Dependencies, HasNowait);
+ emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, OutlinedFn,
+ OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
return Builder.saveIP();
}
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 630cd03c688012..b0688d6215e42d 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6182,9 +6182,12 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
- OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
- OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(), Builder.saveIP(),
- EntryInfo, -1, 0, Inputs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+ OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+ /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+ OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
+ OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
+ Builder.saveIP(), EntryInfo, DefaultAttrs, Inputs,
+ GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
assert(AfterIP && "unexpected error");
Builder.restoreIP(*AfterIP);
OMPBuilder.finalize();
@@ -6292,11 +6295,11 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
/*Line=*/3, /*Count=*/0);
- OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
- OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
- EntryInfo, /*NumTeams=*/-1,
- /*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
- BodyGenCB, SimpleArgAccessorCB);
+ OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+ /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+ OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
+ Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs,
+ CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
assert(AfterIP && "unexpected error");
Builder.restoreIP(*AfterIP);
@@ -6443,11 +6446,11 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
/*Line=*/3, /*Count=*/0);
- OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
- OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
- EntryInfo, /*NumTeams=*/-1,
- /*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
- BodyGenCB, SimpleArgAccessorCB);
+ OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+ /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+ OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
+ Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs,
+ CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
assert(AfterIP && "unexpected error");
Builder.restoreIP(*AfterIP);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index cbcbeea4ab9225..d3c3839accb7e7 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3902,9 +3902,6 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
if (!getTargetEntryUniqueInfo(entryInfo, targetOp, parentName))
return failure();
- int32_t defaultValTeams = -1;
- int32_t defaultValThreads = 0;
-
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
findAllocaInsertPoint(builder, moduleTranslation);
@@ -3939,6 +3936,10 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
allocaIP, codeGenIP);
};
+ // TODO: Populate default attributes based on the construct and clauses.
+ llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs = {
+ /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+
llvm::SmallVector<llvm::Value *, 4> kernelInput;
for (size_t i = 0; i < mapVars.size(); ++i) {
// declare target arguments are not passed to kernels as arguments
@@ -3957,8 +3958,8 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
moduleTranslation.getOpenMPBuilder()->createTarget(
ompLoc, isOffloadEntry, alloca...
[truncated]
|
@llvm/pr-subscribers-flang-openmp Author: Sergio Afonso (skatrak) ChangesThis patch introduces the This is used to forward values passed to Patch is 21.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116050.diff 8 Files Affected:
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index d714af035d21a2..0f7a1166227476 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -5880,10 +5880,13 @@ void CGOpenMPRuntime::emitUsesAllocatorsFini(CodeGenFunction &CGF,
void CGOpenMPRuntime::computeMinAndMaxThreadsAndTeams(
const OMPExecutableDirective &D, CodeGenFunction &CGF,
- int32_t &MinThreadsVal, int32_t &MaxThreadsVal, int32_t &MinTeamsVal,
- int32_t &MaxTeamsVal) {
+ llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) {
+ assert(Attrs.MaxTeams.size() == 1 && Attrs.MaxThreads.size() == 1 &&
+ "invalid default attrs structure");
+ int32_t &MaxTeamsVal = Attrs.MaxTeams.front();
+ int32_t &MaxThreadsVal = Attrs.MaxThreads.front();
- getNumTeamsExprForTargetDirective(CGF, D, MinTeamsVal, MaxTeamsVal);
+ getNumTeamsExprForTargetDirective(CGF, D, Attrs.MinTeams, MaxTeamsVal);
getNumThreadsExprForTargetDirective(CGF, D, MaxThreadsVal,
/*UpperBoundOnly=*/true);
@@ -5901,12 +5904,12 @@ void CGOpenMPRuntime::computeMinAndMaxThreadsAndTeams(
else
continue;
- MinThreadsVal = std::max(MinThreadsVal, AttrMinThreadsVal);
+ Attrs.MinThreads = std::max(Attrs.MinThreads, AttrMinThreadsVal);
if (AttrMaxThreadsVal > 0)
MaxThreadsVal = MaxThreadsVal > 0
? std::min(MaxThreadsVal, AttrMaxThreadsVal)
: AttrMaxThreadsVal;
- MinTeamsVal = std::max(MinTeamsVal, AttrMinBlocksVal);
+ Attrs.MinTeams = std::max(Attrs.MinTeams, AttrMinBlocksVal);
if (AttrMaxBlocksVal > 0)
MaxTeamsVal = MaxTeamsVal > 0 ? std::min(MaxTeamsVal, AttrMaxBlocksVal)
: AttrMaxBlocksVal;
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.h b/clang/lib/CodeGen/CGOpenMPRuntime.h
index 5e7715743afb58..003395e7f17ded 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.h
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.h
@@ -312,12 +312,9 @@ class CGOpenMPRuntime {
llvm::OpenMPIRBuilder OMPBuilder;
/// Helper to determine the min/max number of threads/teams for \p D.
- void computeMinAndMaxThreadsAndTeams(const OMPExecutableDirective &D,
- CodeGenFunction &CGF,
- int32_t &MinThreadsVal,
- int32_t &MaxThreadsVal,
- int32_t &MinTeamsVal,
- int32_t &MaxTeamsVal);
+ void computeMinAndMaxThreadsAndTeams(
+ const OMPExecutableDirective &D, CodeGenFunction &CGF,
+ llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs);
/// Helper to emit outlined function for 'target' directive.
/// \param D Directive to emit.
diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
index 43dc0e62284602..96f8d6c5c08e56 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
@@ -745,14 +745,11 @@ void CGOpenMPRuntimeGPU::emitNonSPMDKernel(const OMPExecutableDirective &D,
void CGOpenMPRuntimeGPU::emitKernelInit(const OMPExecutableDirective &D,
CodeGenFunction &CGF,
EntryFunctionState &EST, bool IsSPMD) {
- int32_t MinThreadsVal = 1, MaxThreadsVal = -1, MinTeamsVal = 1,
- MaxTeamsVal = -1;
- computeMinAndMaxThreadsAndTeams(D, CGF, MinThreadsVal, MaxThreadsVal,
- MinTeamsVal, MaxTeamsVal);
+ llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs Attrs;
+ computeMinAndMaxThreadsAndTeams(D, CGF, Attrs);
CGBuilderTy &Bld = CGF.Builder;
- Bld.restoreIP(OMPBuilder.createTargetInit(
- Bld, IsSPMD, MinThreadsVal, MaxThreadsVal, MinTeamsVal, MaxTeamsVal));
+ Bld.restoreIP(OMPBuilder.createTargetInit(Bld, IsSPMD, Attrs));
if (!IsSPMD)
emitGenericVarsProlog(CGF, EST.Loc);
}
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 3afb9d84278e81..da450ef5adbc14 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2223,6 +2223,20 @@ class OpenMPIRBuilder {
MapNamesArray(MapNamesArray) {}
};
+ /// Container to pass the default attributes with which a kernel must be
+ /// launched, used to set kernel attributes and populate associated static
+ /// structures.
+ ///
+ /// For max values, < 0 means unset, == 0 means set but unknown at compile
+ /// time. The number of max values will be 1 except for the case where
+ /// ompx_bare is set.
+ struct TargetKernelDefaultAttrs {
+ SmallVector<int32_t, 3> MaxTeams = {-1};
+ int32_t MinTeams = 1;
+ SmallVector<int32_t, 3> MaxThreads = {-1};
+ int32_t MinThreads = 1;
+ };
+
/// Data structure that contains the needed information to construct the
/// kernel args vector.
struct TargetKernelArgs {
@@ -2726,15 +2740,11 @@ class OpenMPIRBuilder {
///
/// \param Loc The insert and source location description.
/// \param IsSPMD Flag to indicate if the kernel is an SPMD kernel or not.
- /// \param MinThreads Minimal number of threads, or 0.
- /// \param MaxThreads Maximal number of threads, or 0.
- /// \param MinTeams Minimal number of teams, or 0.
- /// \param MaxTeams Maximal number of teams, or 0.
- InsertPointTy createTargetInit(const LocationDescription &Loc, bool IsSPMD,
- int32_t MinThreadsVal = 0,
- int32_t MaxThreadsVal = 0,
- int32_t MinTeamsVal = 0,
- int32_t MaxTeamsVal = 0);
+ /// \param Attrs Structure containing the default numbers of threads and teams
+ /// to launch the kernel with.
+ InsertPointTy createTargetInit(
+ const LocationDescription &Loc, bool IsSPMD,
+ const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs);
/// Create a runtime call for kmpc_target_deinit
///
@@ -2898,8 +2908,8 @@ class OpenMPIRBuilder {
/// \param CodeGenIP The insertion point where the call to the outlined
/// function should be emitted.
/// \param EntryInfo The entry information about the function.
- /// \param NumTeams Number of teams specified in the num_teams clause.
- /// \param NumThreads Number of teams specified in the thread_limit clause.
+ /// \param DefaultAttrs Structure containing the default numbers of threads
+ /// and teams to launch the kernel with.
/// \param Inputs The input values to the region that will be passed.
/// as arguments to the outlined function.
/// \param BodyGenCB Callback that will generate the region code.
@@ -2912,9 +2922,10 @@ class OpenMPIRBuilder {
const LocationDescription &Loc, bool IsOffloadEntry,
OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP,
- TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams,
- ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs,
- GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy BodyGenCB,
+ TargetRegionEntryInfo &EntryInfo,
+ const TargetKernelDefaultAttrs &DefaultAttrs,
+ SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
+ TargetBodyGenCallbackTy BodyGenCB,
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
SmallVector<DependData> Dependencies = {}, bool HasNowait = false);
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index d2e4dc1c85dfd2..302d363965c940 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -6109,10 +6109,12 @@ CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
return Builder.CreateCall(Fn, Args);
}
-OpenMPIRBuilder::InsertPointTy
-OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
- int32_t MinThreadsVal, int32_t MaxThreadsVal,
- int32_t MinTeamsVal, int32_t MaxTeamsVal) {
+OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
+ const LocationDescription &Loc, bool IsSPMD,
+ const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) {
+ assert(!Attrs.MaxThreads.empty() && !Attrs.MaxTeams.empty() &&
+ "expected num_threads and num_teams to be specified");
+
if (!updateToLocation(Loc))
return Loc.IP;
@@ -6139,21 +6141,23 @@ OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
// Manifest the launch configuration in the metadata matching the kernel
// environment.
- if (MinTeamsVal > 1 || MaxTeamsVal > 0)
- writeTeamsForKernel(T, *Kernel, MinTeamsVal, MaxTeamsVal);
+ if (Attrs.MinTeams > 1 || Attrs.MaxTeams.front() > 0)
+ writeTeamsForKernel(T, *Kernel, Attrs.MinTeams, Attrs.MaxTeams.front());
- // For max values, < 0 means unset, == 0 means set but unknown.
+ // If MaxThreads not set, select the maximum between the default workgroup
+ // size and the MinThreads value.
+ int32_t MaxThreadsVal = Attrs.MaxThreads.front();
if (MaxThreadsVal < 0)
MaxThreadsVal = std::max(
- int32_t(getGridValue(T, Kernel).GV_Default_WG_Size), MinThreadsVal);
+ int32_t(getGridValue(T, Kernel).GV_Default_WG_Size), Attrs.MinThreads);
if (MaxThreadsVal > 0)
- writeThreadBoundsForKernel(T, *Kernel, MinThreadsVal, MaxThreadsVal);
+ writeThreadBoundsForKernel(T, *Kernel, Attrs.MinThreads, MaxThreadsVal);
- Constant *MinThreads = ConstantInt::getSigned(Int32, MinThreadsVal);
+ Constant *MinThreads = ConstantInt::getSigned(Int32, Attrs.MinThreads);
Constant *MaxThreads = ConstantInt::getSigned(Int32, MaxThreadsVal);
- Constant *MinTeams = ConstantInt::getSigned(Int32, MinTeamsVal);
- Constant *MaxTeams = ConstantInt::getSigned(Int32, MaxTeamsVal);
+ Constant *MinTeams = ConstantInt::getSigned(Int32, Attrs.MinTeams);
+ Constant *MaxTeams = ConstantInt::getSigned(Int32, Attrs.MaxTeams.front());
Constant *ReductionDataSize = ConstantInt::getSigned(Int32, 0);
Constant *ReductionBufferLength = ConstantInt::getSigned(Int32, 0);
@@ -6724,8 +6728,9 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
}
static Expected<Function *> createOutlinedFunction(
- OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, StringRef FuncName,
- SmallVectorImpl<Value *> &Inputs,
+ OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+ const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+ StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
SmallVector<Type *> ParameterTypes;
@@ -6792,7 +6797,8 @@ static Expected<Function *> createOutlinedFunction(
// Insert target init call in the device compilation pass.
if (OMPBuilder.Config.isTargetDevice())
- Builder.restoreIP(OMPBuilder.createTargetInit(Builder, /*IsSPMD*/ false));
+ Builder.restoreIP(
+ OMPBuilder.createTargetInit(Builder, /*IsSPMD=*/false, DefaultAttrs));
BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
@@ -6989,16 +6995,18 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
static Error emitTargetOutlinedFunction(
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
- TargetRegionEntryInfo &EntryInfo, Function *&OutlinedFn,
- Constant *&OutlinedFnID, SmallVectorImpl<Value *> &Inputs,
+ TargetRegionEntryInfo &EntryInfo,
+ const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+ Function *&OutlinedFn, Constant *&OutlinedFnID,
+ SmallVectorImpl<Value *> &Inputs,
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
- [&OMPBuilder, &Builder, &Inputs, &CBFunc,
- &ArgAccessorFuncCB](StringRef EntryFnName) {
- return createOutlinedFunction(OMPBuilder, Builder, EntryFnName, Inputs,
- CBFunc, ArgAccessorFuncCB);
+ [&](StringRef EntryFnName) {
+ return createOutlinedFunction(OMPBuilder, Builder, DefaultAttrs,
+ EntryFnName, Inputs, CBFunc,
+ ArgAccessorFuncCB);
};
return OMPBuilder.emitTargetRegionFunction(
@@ -7294,9 +7302,10 @@ void OpenMPIRBuilder::emitOffloadingArraysAndArgs(
static void
emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
- OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
- Constant *OutlinedFnID, ArrayRef<int32_t> NumTeams,
- ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Args,
+ OpenMPIRBuilder::InsertPointTy AllocaIP,
+ const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+ Function *OutlinedFn, Constant *OutlinedFnID,
+ SmallVectorImpl<Value *> &Args,
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
bool HasNoWait = false) {
@@ -7377,9 +7386,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
SmallVector<Value *, 3> NumTeamsC;
SmallVector<Value *, 3> NumThreadsC;
- for (auto V : NumTeams)
+ for (auto V : DefaultAttrs.MaxTeams)
NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
- for (auto V : NumThreads)
+ for (auto V : DefaultAttrs.MaxThreads)
NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
unsigned NumTargetItems = Info.NumberOfPtrs;
@@ -7420,7 +7429,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
- ArrayRef<int32_t> NumTeams, ArrayRef<int32_t> NumThreads,
+ const TargetKernelDefaultAttrs &DefaultAttrs,
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
@@ -7437,16 +7446,16 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
// the target region itself is generated using the callbacks CBFunc
// and ArgAccessorFuncCB
if (Error Err = emitTargetOutlinedFunction(
- *this, Builder, IsOffloadEntry, EntryInfo, OutlinedFn, OutlinedFnID,
- Args, CBFunc, ArgAccessorFuncCB))
+ *this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn,
+ OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
return Err;
// If we are not on the target device, then we need to generate code
// to make a remote call (offload) to the previously outlined function
// that represents the target region. Do that now.
if (!Config.isTargetDevice())
- emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
- NumThreads, Args, GenMapInfoCB, Dependencies, HasNowait);
+ emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, OutlinedFn,
+ OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
return Builder.saveIP();
}
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 630cd03c688012..b0688d6215e42d 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6182,9 +6182,12 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
- OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
- OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(), Builder.saveIP(),
- EntryInfo, -1, 0, Inputs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+ OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+ /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+ OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
+ OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
+ Builder.saveIP(), EntryInfo, DefaultAttrs, Inputs,
+ GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
assert(AfterIP && "unexpected error");
Builder.restoreIP(*AfterIP);
OMPBuilder.finalize();
@@ -6292,11 +6295,11 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
/*Line=*/3, /*Count=*/0);
- OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
- OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
- EntryInfo, /*NumTeams=*/-1,
- /*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
- BodyGenCB, SimpleArgAccessorCB);
+ OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+ /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+ OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
+ Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs,
+ CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
assert(AfterIP && "unexpected error");
Builder.restoreIP(*AfterIP);
@@ -6443,11 +6446,11 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
/*Line=*/3, /*Count=*/0);
- OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
- OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
- EntryInfo, /*NumTeams=*/-1,
- /*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
- BodyGenCB, SimpleArgAccessorCB);
+ OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+ /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+ OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
+ Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs,
+ CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
assert(AfterIP && "unexpected error");
Builder.restoreIP(*AfterIP);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index cbcbeea4ab9225..d3c3839accb7e7 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3902,9 +3902,6 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
if (!getTargetEntryUniqueInfo(entryInfo, targetOp, parentName))
return failure();
- int32_t defaultValTeams = -1;
- int32_t defaultValThreads = 0;
-
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
findAllocaInsertPoint(builder, moduleTranslation);
@@ -3939,6 +3936,10 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
allocaIP, codeGenIP);
};
+ // TODO: Populate default attributes based on the construct and clauses.
+ llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs = {
+ /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+
llvm::SmallVector<llvm::Value *, 4> kernelInput;
for (size_t i = 0; i < mapVars.size(); ++i) {
// declare target arguments are not passed to kernels as arguments
@@ -3957,8 +3958,8 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
moduleTranslation.getOpenMPBuilder()->createTarget(
ompLoc, isOffloadEntry, alloca...
[truncated]
|
@llvm/pr-subscribers-mlir Author: Sergio Afonso (skatrak) ChangesThis patch introduces the This is used to forward values passed to Patch is 21.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116050.diff 8 Files Affected:
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index d714af035d21a2..0f7a1166227476 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -5880,10 +5880,13 @@ void CGOpenMPRuntime::emitUsesAllocatorsFini(CodeGenFunction &CGF,
void CGOpenMPRuntime::computeMinAndMaxThreadsAndTeams(
const OMPExecutableDirective &D, CodeGenFunction &CGF,
- int32_t &MinThreadsVal, int32_t &MaxThreadsVal, int32_t &MinTeamsVal,
- int32_t &MaxTeamsVal) {
+ llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) {
+ assert(Attrs.MaxTeams.size() == 1 && Attrs.MaxThreads.size() == 1 &&
+ "invalid default attrs structure");
+ int32_t &MaxTeamsVal = Attrs.MaxTeams.front();
+ int32_t &MaxThreadsVal = Attrs.MaxThreads.front();
- getNumTeamsExprForTargetDirective(CGF, D, MinTeamsVal, MaxTeamsVal);
+ getNumTeamsExprForTargetDirective(CGF, D, Attrs.MinTeams, MaxTeamsVal);
getNumThreadsExprForTargetDirective(CGF, D, MaxThreadsVal,
/*UpperBoundOnly=*/true);
@@ -5901,12 +5904,12 @@ void CGOpenMPRuntime::computeMinAndMaxThreadsAndTeams(
else
continue;
- MinThreadsVal = std::max(MinThreadsVal, AttrMinThreadsVal);
+ Attrs.MinThreads = std::max(Attrs.MinThreads, AttrMinThreadsVal);
if (AttrMaxThreadsVal > 0)
MaxThreadsVal = MaxThreadsVal > 0
? std::min(MaxThreadsVal, AttrMaxThreadsVal)
: AttrMaxThreadsVal;
- MinTeamsVal = std::max(MinTeamsVal, AttrMinBlocksVal);
+ Attrs.MinTeams = std::max(Attrs.MinTeams, AttrMinBlocksVal);
if (AttrMaxBlocksVal > 0)
MaxTeamsVal = MaxTeamsVal > 0 ? std::min(MaxTeamsVal, AttrMaxBlocksVal)
: AttrMaxBlocksVal;
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.h b/clang/lib/CodeGen/CGOpenMPRuntime.h
index 5e7715743afb58..003395e7f17ded 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.h
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.h
@@ -312,12 +312,9 @@ class CGOpenMPRuntime {
llvm::OpenMPIRBuilder OMPBuilder;
/// Helper to determine the min/max number of threads/teams for \p D.
- void computeMinAndMaxThreadsAndTeams(const OMPExecutableDirective &D,
- CodeGenFunction &CGF,
- int32_t &MinThreadsVal,
- int32_t &MaxThreadsVal,
- int32_t &MinTeamsVal,
- int32_t &MaxTeamsVal);
+ void computeMinAndMaxThreadsAndTeams(
+ const OMPExecutableDirective &D, CodeGenFunction &CGF,
+ llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs);
/// Helper to emit outlined function for 'target' directive.
/// \param D Directive to emit.
diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
index 43dc0e62284602..96f8d6c5c08e56 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
@@ -745,14 +745,11 @@ void CGOpenMPRuntimeGPU::emitNonSPMDKernel(const OMPExecutableDirective &D,
void CGOpenMPRuntimeGPU::emitKernelInit(const OMPExecutableDirective &D,
CodeGenFunction &CGF,
EntryFunctionState &EST, bool IsSPMD) {
- int32_t MinThreadsVal = 1, MaxThreadsVal = -1, MinTeamsVal = 1,
- MaxTeamsVal = -1;
- computeMinAndMaxThreadsAndTeams(D, CGF, MinThreadsVal, MaxThreadsVal,
- MinTeamsVal, MaxTeamsVal);
+ llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs Attrs;
+ computeMinAndMaxThreadsAndTeams(D, CGF, Attrs);
CGBuilderTy &Bld = CGF.Builder;
- Bld.restoreIP(OMPBuilder.createTargetInit(
- Bld, IsSPMD, MinThreadsVal, MaxThreadsVal, MinTeamsVal, MaxTeamsVal));
+ Bld.restoreIP(OMPBuilder.createTargetInit(Bld, IsSPMD, Attrs));
if (!IsSPMD)
emitGenericVarsProlog(CGF, EST.Loc);
}
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 3afb9d84278e81..da450ef5adbc14 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2223,6 +2223,20 @@ class OpenMPIRBuilder {
MapNamesArray(MapNamesArray) {}
};
+ /// Container to pass the default attributes with which a kernel must be
+ /// launched, used to set kernel attributes and populate associated static
+ /// structures.
+ ///
+ /// For max values, < 0 means unset, == 0 means set but unknown at compile
+ /// time. The number of max values will be 1 except for the case where
+ /// ompx_bare is set.
+ struct TargetKernelDefaultAttrs {
+ SmallVector<int32_t, 3> MaxTeams = {-1};
+ int32_t MinTeams = 1;
+ SmallVector<int32_t, 3> MaxThreads = {-1};
+ int32_t MinThreads = 1;
+ };
+
/// Data structure that contains the needed information to construct the
/// kernel args vector.
struct TargetKernelArgs {
@@ -2726,15 +2740,11 @@ class OpenMPIRBuilder {
///
/// \param Loc The insert and source location description.
/// \param IsSPMD Flag to indicate if the kernel is an SPMD kernel or not.
- /// \param MinThreads Minimal number of threads, or 0.
- /// \param MaxThreads Maximal number of threads, or 0.
- /// \param MinTeams Minimal number of teams, or 0.
- /// \param MaxTeams Maximal number of teams, or 0.
- InsertPointTy createTargetInit(const LocationDescription &Loc, bool IsSPMD,
- int32_t MinThreadsVal = 0,
- int32_t MaxThreadsVal = 0,
- int32_t MinTeamsVal = 0,
- int32_t MaxTeamsVal = 0);
+ /// \param Attrs Structure containing the default numbers of threads and teams
+ /// to launch the kernel with.
+ InsertPointTy createTargetInit(
+ const LocationDescription &Loc, bool IsSPMD,
+ const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs);
/// Create a runtime call for kmpc_target_deinit
///
@@ -2898,8 +2908,8 @@ class OpenMPIRBuilder {
/// \param CodeGenIP The insertion point where the call to the outlined
/// function should be emitted.
/// \param EntryInfo The entry information about the function.
- /// \param NumTeams Number of teams specified in the num_teams clause.
- /// \param NumThreads Number of teams specified in the thread_limit clause.
+ /// \param DefaultAttrs Structure containing the default numbers of threads
+ /// and teams to launch the kernel with.
/// \param Inputs The input values to the region that will be passed.
/// as arguments to the outlined function.
/// \param BodyGenCB Callback that will generate the region code.
@@ -2912,9 +2922,10 @@ class OpenMPIRBuilder {
const LocationDescription &Loc, bool IsOffloadEntry,
OpenMPIRBuilder::InsertPointTy AllocaIP,
OpenMPIRBuilder::InsertPointTy CodeGenIP,
- TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams,
- ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs,
- GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy BodyGenCB,
+ TargetRegionEntryInfo &EntryInfo,
+ const TargetKernelDefaultAttrs &DefaultAttrs,
+ SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
+ TargetBodyGenCallbackTy BodyGenCB,
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
SmallVector<DependData> Dependencies = {}, bool HasNowait = false);
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index d2e4dc1c85dfd2..302d363965c940 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -6109,10 +6109,12 @@ CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
return Builder.CreateCall(Fn, Args);
}
-OpenMPIRBuilder::InsertPointTy
-OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
- int32_t MinThreadsVal, int32_t MaxThreadsVal,
- int32_t MinTeamsVal, int32_t MaxTeamsVal) {
+OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
+ const LocationDescription &Loc, bool IsSPMD,
+ const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) {
+ assert(!Attrs.MaxThreads.empty() && !Attrs.MaxTeams.empty() &&
+ "expected num_threads and num_teams to be specified");
+
if (!updateToLocation(Loc))
return Loc.IP;
@@ -6139,21 +6141,23 @@ OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
// Manifest the launch configuration in the metadata matching the kernel
// environment.
- if (MinTeamsVal > 1 || MaxTeamsVal > 0)
- writeTeamsForKernel(T, *Kernel, MinTeamsVal, MaxTeamsVal);
+ if (Attrs.MinTeams > 1 || Attrs.MaxTeams.front() > 0)
+ writeTeamsForKernel(T, *Kernel, Attrs.MinTeams, Attrs.MaxTeams.front());
- // For max values, < 0 means unset, == 0 means set but unknown.
+ // If MaxThreads not set, select the maximum between the default workgroup
+ // size and the MinThreads value.
+ int32_t MaxThreadsVal = Attrs.MaxThreads.front();
if (MaxThreadsVal < 0)
MaxThreadsVal = std::max(
- int32_t(getGridValue(T, Kernel).GV_Default_WG_Size), MinThreadsVal);
+ int32_t(getGridValue(T, Kernel).GV_Default_WG_Size), Attrs.MinThreads);
if (MaxThreadsVal > 0)
- writeThreadBoundsForKernel(T, *Kernel, MinThreadsVal, MaxThreadsVal);
+ writeThreadBoundsForKernel(T, *Kernel, Attrs.MinThreads, MaxThreadsVal);
- Constant *MinThreads = ConstantInt::getSigned(Int32, MinThreadsVal);
+ Constant *MinThreads = ConstantInt::getSigned(Int32, Attrs.MinThreads);
Constant *MaxThreads = ConstantInt::getSigned(Int32, MaxThreadsVal);
- Constant *MinTeams = ConstantInt::getSigned(Int32, MinTeamsVal);
- Constant *MaxTeams = ConstantInt::getSigned(Int32, MaxTeamsVal);
+ Constant *MinTeams = ConstantInt::getSigned(Int32, Attrs.MinTeams);
+ Constant *MaxTeams = ConstantInt::getSigned(Int32, Attrs.MaxTeams.front());
Constant *ReductionDataSize = ConstantInt::getSigned(Int32, 0);
Constant *ReductionBufferLength = ConstantInt::getSigned(Int32, 0);
@@ -6724,8 +6728,9 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
}
static Expected<Function *> createOutlinedFunction(
- OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, StringRef FuncName,
- SmallVectorImpl<Value *> &Inputs,
+ OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+ const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+ StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
SmallVector<Type *> ParameterTypes;
@@ -6792,7 +6797,8 @@ static Expected<Function *> createOutlinedFunction(
// Insert target init call in the device compilation pass.
if (OMPBuilder.Config.isTargetDevice())
- Builder.restoreIP(OMPBuilder.createTargetInit(Builder, /*IsSPMD*/ false));
+ Builder.restoreIP(
+ OMPBuilder.createTargetInit(Builder, /*IsSPMD=*/false, DefaultAttrs));
BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
@@ -6989,16 +6995,18 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
static Error emitTargetOutlinedFunction(
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
- TargetRegionEntryInfo &EntryInfo, Function *&OutlinedFn,
- Constant *&OutlinedFnID, SmallVectorImpl<Value *> &Inputs,
+ TargetRegionEntryInfo &EntryInfo,
+ const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+ Function *&OutlinedFn, Constant *&OutlinedFnID,
+ SmallVectorImpl<Value *> &Inputs,
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
- [&OMPBuilder, &Builder, &Inputs, &CBFunc,
- &ArgAccessorFuncCB](StringRef EntryFnName) {
- return createOutlinedFunction(OMPBuilder, Builder, EntryFnName, Inputs,
- CBFunc, ArgAccessorFuncCB);
+ [&](StringRef EntryFnName) {
+ return createOutlinedFunction(OMPBuilder, Builder, DefaultAttrs,
+ EntryFnName, Inputs, CBFunc,
+ ArgAccessorFuncCB);
};
return OMPBuilder.emitTargetRegionFunction(
@@ -7294,9 +7302,10 @@ void OpenMPIRBuilder::emitOffloadingArraysAndArgs(
static void
emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
- OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
- Constant *OutlinedFnID, ArrayRef<int32_t> NumTeams,
- ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Args,
+ OpenMPIRBuilder::InsertPointTy AllocaIP,
+ const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+ Function *OutlinedFn, Constant *OutlinedFnID,
+ SmallVectorImpl<Value *> &Args,
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
bool HasNoWait = false) {
@@ -7377,9 +7386,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
SmallVector<Value *, 3> NumTeamsC;
SmallVector<Value *, 3> NumThreadsC;
- for (auto V : NumTeams)
+ for (auto V : DefaultAttrs.MaxTeams)
NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
- for (auto V : NumThreads)
+ for (auto V : DefaultAttrs.MaxThreads)
NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
unsigned NumTargetItems = Info.NumberOfPtrs;
@@ -7420,7 +7429,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
- ArrayRef<int32_t> NumTeams, ArrayRef<int32_t> NumThreads,
+ const TargetKernelDefaultAttrs &DefaultAttrs,
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
@@ -7437,16 +7446,16 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
// the target region itself is generated using the callbacks CBFunc
// and ArgAccessorFuncCB
if (Error Err = emitTargetOutlinedFunction(
- *this, Builder, IsOffloadEntry, EntryInfo, OutlinedFn, OutlinedFnID,
- Args, CBFunc, ArgAccessorFuncCB))
+ *this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn,
+ OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
return Err;
// If we are not on the target device, then we need to generate code
// to make a remote call (offload) to the previously outlined function
// that represents the target region. Do that now.
if (!Config.isTargetDevice())
- emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
- NumThreads, Args, GenMapInfoCB, Dependencies, HasNowait);
+ emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, OutlinedFn,
+ OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
return Builder.saveIP();
}
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 630cd03c688012..b0688d6215e42d 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6182,9 +6182,12 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
- OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
- OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(), Builder.saveIP(),
- EntryInfo, -1, 0, Inputs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+ OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+ /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+ OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
+ OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
+ Builder.saveIP(), EntryInfo, DefaultAttrs, Inputs,
+ GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
assert(AfterIP && "unexpected error");
Builder.restoreIP(*AfterIP);
OMPBuilder.finalize();
@@ -6292,11 +6295,11 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
/*Line=*/3, /*Count=*/0);
- OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
- OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
- EntryInfo, /*NumTeams=*/-1,
- /*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
- BodyGenCB, SimpleArgAccessorCB);
+ OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+ /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+ OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
+ Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs,
+ CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
assert(AfterIP && "unexpected error");
Builder.restoreIP(*AfterIP);
@@ -6443,11 +6446,11 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
/*Line=*/3, /*Count=*/0);
- OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
- OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
- EntryInfo, /*NumTeams=*/-1,
- /*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
- BodyGenCB, SimpleArgAccessorCB);
+ OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+ /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+ OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
+ Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs,
+ CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
assert(AfterIP && "unexpected error");
Builder.restoreIP(*AfterIP);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index cbcbeea4ab9225..d3c3839accb7e7 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3902,9 +3902,6 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
if (!getTargetEntryUniqueInfo(entryInfo, targetOp, parentName))
return failure();
- int32_t defaultValTeams = -1;
- int32_t defaultValThreads = 0;
-
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
findAllocaInsertPoint(builder, moduleTranslation);
@@ -3939,6 +3936,10 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
allocaIP, codeGenIP);
};
+ // TODO: Populate default attributes based on the construct and clauses.
+ llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs = {
+ /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+
llvm::SmallVector<llvm::Value *, 4> kernelInput;
for (size_t i = 0; i < mapVars.size(); ++i) {
// declare target arguments are not passed to kernels as arguments
@@ -3957,8 +3958,8 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
moduleTranslation.getOpenMPBuilder()->createTarget(
ompLoc, isOffloadEntry, alloca...
[truncated]
|
26fbb25
to
27ffa9f
Compare
1fcfe48
to
e3cdc93
Compare
Ping for reviews! |
OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(), Builder.saveIP(), | ||
EntryInfo, -1, 0, Inputs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB); | ||
OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = { | ||
/*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This set of values is used in multiple locations to "default" construct TargetKernelDefaultAttrs
, would it make sense to have this set of values as default values in the struct? I might be missing why we need the current default struct values.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Defaults in the new struct represent basically what you would expect: max values representing "unset" (since these can be either unset (<0), runtime-evaluated (0) or constant (>0)) and min values set to 1. I believe that set of defaults makes sense, and it matches what clang set the corresponding attributes initially too.
As for not overriding the defaults in these tests, MaxThreads < 0
causes the OMPIRBuilder to query the default grid size based on the target triple, whereas 0 won't. Querying that triggers an assert if the triple is not one of the supported offloading targets, so at least that one attribute can't be left unchanged unless we change the target triple of the OMPIRBuilder too. But, more generally, I think there is nothing in this PR that causes a need to update these tests, so I just set all of the values to what they already were before the struct was introduced rather than adapting them to its defaults.
I hope that makes sense to you, but let me know if you don't agree.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the clarification.
27ffa9f
to
bd7fa37
Compare
e3cdc93
to
45c6667
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
bd7fa37
to
5f57b94
Compare
This patch introduces the `OpenMPIRBuilder::TargetKernelDefaultAttrs` structure used to simplify passing default and constant values for number of teams and threads, and possibly other target kernel-related information in the future. This is used to forward values passed to `createTarget` to `createTargetInit`, which previously used a default unrelated set of values.
45c6667
to
219d430
Compare
This patch introduces the
OpenMPIRBuilder::TargetKernelDefaultAttrs
structure used to simplify passing default and constant values for number of teams and threads, and possibly other target kernel-related information in the future.This is used to forward values passed to
createTarget
tocreateTargetInit
, which previously used a default unrelated set of values.