Skip to content

Commit 27bc6bd

Browse files
authored
[OMPIRBuilder] Introduce struct to hold default kernel teams/threads (#116050)
This patch introduces the `OpenMPIRBuilder::TargetKernelDefaultAttrs` structure used to simplify passing default and constant values for number of teams and threads, and possibly other target kernel-related information in the future. This is used to forward values passed to `createTarget` to `createTargetInit`, which previously used a default unrelated set of values.
1 parent 0bf1591 commit 27bc6bd

File tree

8 files changed

+106
-78
lines changed

8 files changed

+106
-78
lines changed

clang/lib/CodeGen/CGOpenMPRuntime.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5880,10 +5880,13 @@ void CGOpenMPRuntime::emitUsesAllocatorsFini(CodeGenFunction &CGF,
58805880

58815881
void CGOpenMPRuntime::computeMinAndMaxThreadsAndTeams(
58825882
const OMPExecutableDirective &D, CodeGenFunction &CGF,
5883-
int32_t &MinThreadsVal, int32_t &MaxThreadsVal, int32_t &MinTeamsVal,
5884-
int32_t &MaxTeamsVal) {
5883+
llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) {
5884+
assert(Attrs.MaxTeams.size() == 1 && Attrs.MaxThreads.size() == 1 &&
5885+
"invalid default attrs structure");
5886+
int32_t &MaxTeamsVal = Attrs.MaxTeams.front();
5887+
int32_t &MaxThreadsVal = Attrs.MaxThreads.front();
58855888

5886-
getNumTeamsExprForTargetDirective(CGF, D, MinTeamsVal, MaxTeamsVal);
5889+
getNumTeamsExprForTargetDirective(CGF, D, Attrs.MinTeams, MaxTeamsVal);
58875890
getNumThreadsExprForTargetDirective(CGF, D, MaxThreadsVal,
58885891
/*UpperBoundOnly=*/true);
58895892

@@ -5901,12 +5904,12 @@ void CGOpenMPRuntime::computeMinAndMaxThreadsAndTeams(
59015904
else
59025905
continue;
59035906

5904-
MinThreadsVal = std::max(MinThreadsVal, AttrMinThreadsVal);
5907+
Attrs.MinThreads = std::max(Attrs.MinThreads, AttrMinThreadsVal);
59055908
if (AttrMaxThreadsVal > 0)
59065909
MaxThreadsVal = MaxThreadsVal > 0
59075910
? std::min(MaxThreadsVal, AttrMaxThreadsVal)
59085911
: AttrMaxThreadsVal;
5909-
MinTeamsVal = std::max(MinTeamsVal, AttrMinBlocksVal);
5912+
Attrs.MinTeams = std::max(Attrs.MinTeams, AttrMinBlocksVal);
59105913
if (AttrMaxBlocksVal > 0)
59115914
MaxTeamsVal = MaxTeamsVal > 0 ? std::min(MaxTeamsVal, AttrMaxBlocksVal)
59125915
: AttrMaxBlocksVal;

clang/lib/CodeGen/CGOpenMPRuntime.h

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -313,12 +313,9 @@ class CGOpenMPRuntime {
313313
llvm::OpenMPIRBuilder OMPBuilder;
314314

315315
/// Helper to determine the min/max number of threads/teams for \p D.
316-
void computeMinAndMaxThreadsAndTeams(const OMPExecutableDirective &D,
317-
CodeGenFunction &CGF,
318-
int32_t &MinThreadsVal,
319-
int32_t &MaxThreadsVal,
320-
int32_t &MinTeamsVal,
321-
int32_t &MaxTeamsVal);
316+
void computeMinAndMaxThreadsAndTeams(
317+
const OMPExecutableDirective &D, CodeGenFunction &CGF,
318+
llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs);
322319

323320
/// Helper to emit outlined function for 'target' directive.
324321
/// \param D Directive to emit.

clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -744,14 +744,12 @@ void CGOpenMPRuntimeGPU::emitNonSPMDKernel(const OMPExecutableDirective &D,
744744
void CGOpenMPRuntimeGPU::emitKernelInit(const OMPExecutableDirective &D,
745745
CodeGenFunction &CGF,
746746
EntryFunctionState &EST, bool IsSPMD) {
747-
int32_t MinThreadsVal = 1, MaxThreadsVal = -1, MinTeamsVal = 1,
748-
MaxTeamsVal = -1;
749-
computeMinAndMaxThreadsAndTeams(D, CGF, MinThreadsVal, MaxThreadsVal,
750-
MinTeamsVal, MaxTeamsVal);
747+
llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs Attrs;
748+
Attrs.IsSPMD = IsSPMD;
749+
computeMinAndMaxThreadsAndTeams(D, CGF, Attrs);
751750

752751
CGBuilderTy &Bld = CGF.Builder;
753-
Bld.restoreIP(OMPBuilder.createTargetInit(
754-
Bld, IsSPMD, MinThreadsVal, MaxThreadsVal, MinTeamsVal, MaxTeamsVal));
752+
Bld.restoreIP(OMPBuilder.createTargetInit(Bld, Attrs));
755753
if (!IsSPMD)
756754
emitGenericVarsProlog(CGF, EST.Loc);
757755
}

llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2225,6 +2225,21 @@ class OpenMPIRBuilder {
22252225
MapNamesArray(MapNamesArray) {}
22262226
};
22272227

2228+
/// Container to pass the default attributes with which a kernel must be
2229+
/// launched, used to set kernel attributes and populate associated static
2230+
/// structures.
2231+
///
2232+
/// For max values, < 0 means unset, == 0 means set but unknown at compile
2233+
/// time. The number of max values will be 1 except for the case where
2234+
/// ompx_bare is set.
2235+
struct TargetKernelDefaultAttrs {
2236+
bool IsSPMD = false;
2237+
SmallVector<int32_t, 3> MaxTeams = {-1};
2238+
int32_t MinTeams = 1;
2239+
SmallVector<int32_t, 3> MaxThreads = {-1};
2240+
int32_t MinThreads = 1;
2241+
};
2242+
22282243
/// Data structure that contains the needed information to construct the
22292244
/// kernel args vector.
22302245
struct TargetKernelArgs {
@@ -2727,16 +2742,11 @@ class OpenMPIRBuilder {
27272742
/// Create a runtime call for kmpc_target_init
27282743
///
27292744
/// \param Loc The insert and source location description.
2730-
/// \param IsSPMD Flag to indicate if the kernel is an SPMD kernel or not.
2731-
/// \param MinThreads Minimal number of threads, or 0.
2732-
/// \param MaxThreads Maximal number of threads, or 0.
2733-
/// \param MinTeams Minimal number of teams, or 0.
2734-
/// \param MaxTeams Maximal number of teams, or 0.
2735-
InsertPointTy createTargetInit(const LocationDescription &Loc, bool IsSPMD,
2736-
int32_t MinThreadsVal = 0,
2737-
int32_t MaxThreadsVal = 0,
2738-
int32_t MinTeamsVal = 0,
2739-
int32_t MaxTeamsVal = 0);
2745+
/// \param Attrs Structure containing the default attributes, including
2746+
/// numbers of threads and teams to launch the kernel with.
2747+
InsertPointTy createTargetInit(
2748+
const LocationDescription &Loc,
2749+
const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs);
27402750

27412751
/// Create a runtime call for kmpc_target_deinit
27422752
///
@@ -2961,8 +2971,8 @@ class OpenMPIRBuilder {
29612971
/// \param CodeGenIP The insertion point where the call to the outlined
29622972
/// function should be emitted.
29632973
/// \param EntryInfo The entry information about the function.
2964-
/// \param NumTeams Number of teams specified in the num_teams clause.
2965-
/// \param NumThreads Number of teams specified in the thread_limit clause.
2974+
/// \param DefaultAttrs Structure containing the default numbers of threads
2975+
/// and teams to launch the kernel with.
29662976
/// \param Inputs The input values to the region that will be passed.
29672977
/// as arguments to the outlined function.
29682978
/// \param BodyGenCB Callback that will generate the region code.
@@ -2975,9 +2985,10 @@ class OpenMPIRBuilder {
29752985
const LocationDescription &Loc, bool IsOffloadEntry,
29762986
OpenMPIRBuilder::InsertPointTy AllocaIP,
29772987
OpenMPIRBuilder::InsertPointTy CodeGenIP,
2978-
TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams,
2979-
ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs,
2980-
GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy BodyGenCB,
2988+
TargetRegionEntryInfo &EntryInfo,
2989+
const TargetKernelDefaultAttrs &DefaultAttrs,
2990+
SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
2991+
TargetBodyGenCallbackTy BodyGenCB,
29812992
TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
29822993
SmallVector<DependData> Dependencies = {}, bool HasNowait = false);
29832994

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 42 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -6119,19 +6119,22 @@ CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
61196119
return Builder.CreateCall(Fn, Args);
61206120
}
61216121

6122-
OpenMPIRBuilder::InsertPointTy
6123-
OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
6124-
int32_t MinThreadsVal, int32_t MaxThreadsVal,
6125-
int32_t MinTeamsVal, int32_t MaxTeamsVal) {
6122+
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
6123+
const LocationDescription &Loc,
6124+
const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) {
6125+
assert(!Attrs.MaxThreads.empty() && !Attrs.MaxTeams.empty() &&
6126+
"expected num_threads and num_teams to be specified");
6127+
61266128
if (!updateToLocation(Loc))
61276129
return Loc.IP;
61286130

61296131
uint32_t SrcLocStrSize;
61306132
Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
61316133
Constant *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
61326134
Constant *IsSPMDVal = ConstantInt::getSigned(
6133-
Int8, IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
6134-
Constant *UseGenericStateMachineVal = ConstantInt::getSigned(Int8, !IsSPMD);
6135+
Int8, Attrs.IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
6136+
Constant *UseGenericStateMachineVal =
6137+
ConstantInt::getSigned(Int8, !Attrs.IsSPMD);
61356138
Constant *MayUseNestedParallelismVal = ConstantInt::getSigned(Int8, true);
61366139
Constant *DebugIndentionLevelVal = ConstantInt::getSigned(Int16, 0);
61376140

@@ -6149,21 +6152,23 @@ OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
61496152

61506153
// Manifest the launch configuration in the metadata matching the kernel
61516154
// environment.
6152-
if (MinTeamsVal > 1 || MaxTeamsVal > 0)
6153-
writeTeamsForKernel(T, *Kernel, MinTeamsVal, MaxTeamsVal);
6155+
if (Attrs.MinTeams > 1 || Attrs.MaxTeams.front() > 0)
6156+
writeTeamsForKernel(T, *Kernel, Attrs.MinTeams, Attrs.MaxTeams.front());
61546157

6155-
// For max values, < 0 means unset, == 0 means set but unknown.
6158+
// If MaxThreads not set, select the maximum between the default workgroup
6159+
// size and the MinThreads value.
6160+
int32_t MaxThreadsVal = Attrs.MaxThreads.front();
61566161
if (MaxThreadsVal < 0)
61576162
MaxThreadsVal = std::max(
6158-
int32_t(getGridValue(T, Kernel).GV_Default_WG_Size), MinThreadsVal);
6163+
int32_t(getGridValue(T, Kernel).GV_Default_WG_Size), Attrs.MinThreads);
61596164

61606165
if (MaxThreadsVal > 0)
6161-
writeThreadBoundsForKernel(T, *Kernel, MinThreadsVal, MaxThreadsVal);
6166+
writeThreadBoundsForKernel(T, *Kernel, Attrs.MinThreads, MaxThreadsVal);
61626167

6163-
Constant *MinThreads = ConstantInt::getSigned(Int32, MinThreadsVal);
6168+
Constant *MinThreads = ConstantInt::getSigned(Int32, Attrs.MinThreads);
61646169
Constant *MaxThreads = ConstantInt::getSigned(Int32, MaxThreadsVal);
6165-
Constant *MinTeams = ConstantInt::getSigned(Int32, MinTeamsVal);
6166-
Constant *MaxTeams = ConstantInt::getSigned(Int32, MaxTeamsVal);
6170+
Constant *MinTeams = ConstantInt::getSigned(Int32, Attrs.MinTeams);
6171+
Constant *MaxTeams = ConstantInt::getSigned(Int32, Attrs.MaxTeams.front());
61676172
Constant *ReductionDataSize = ConstantInt::getSigned(Int32, 0);
61686173
Constant *ReductionBufferLength = ConstantInt::getSigned(Int32, 0);
61696174

@@ -6730,8 +6735,9 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
67306735
}
67316736

67326737
static Expected<Function *> createOutlinedFunction(
6733-
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, StringRef FuncName,
6734-
SmallVectorImpl<Value *> &Inputs,
6738+
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
6739+
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
6740+
StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
67356741
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
67366742
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
67376743
SmallVector<Type *> ParameterTypes;
@@ -6798,7 +6804,7 @@ static Expected<Function *> createOutlinedFunction(
67986804

67996805
// Insert target init call in the device compilation pass.
68006806
if (OMPBuilder.Config.isTargetDevice())
6801-
Builder.restoreIP(OMPBuilder.createTargetInit(Builder, /*IsSPMD*/ false));
6807+
Builder.restoreIP(OMPBuilder.createTargetInit(Builder, DefaultAttrs));
68026808

68036809
BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
68046810

@@ -6997,16 +7003,18 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
69977003

69987004
static Error emitTargetOutlinedFunction(
69997005
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
7000-
TargetRegionEntryInfo &EntryInfo, Function *&OutlinedFn,
7001-
Constant *&OutlinedFnID, SmallVectorImpl<Value *> &Inputs,
7006+
TargetRegionEntryInfo &EntryInfo,
7007+
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7008+
Function *&OutlinedFn, Constant *&OutlinedFnID,
7009+
SmallVectorImpl<Value *> &Inputs,
70027010
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
70037011
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
70047012

70057013
OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
7006-
[&OMPBuilder, &Builder, &Inputs, &CBFunc,
7007-
&ArgAccessorFuncCB](StringRef EntryFnName) {
7008-
return createOutlinedFunction(OMPBuilder, Builder, EntryFnName, Inputs,
7009-
CBFunc, ArgAccessorFuncCB);
7014+
[&](StringRef EntryFnName) {
7015+
return createOutlinedFunction(OMPBuilder, Builder, DefaultAttrs,
7016+
EntryFnName, Inputs, CBFunc,
7017+
ArgAccessorFuncCB);
70107018
};
70117019

70127020
return OMPBuilder.emitTargetRegionFunction(
@@ -7302,9 +7310,10 @@ void OpenMPIRBuilder::emitOffloadingArraysAndArgs(
73027310

73037311
static void
73047312
emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7305-
OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
7306-
Constant *OutlinedFnID, ArrayRef<int32_t> NumTeams,
7307-
ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Args,
7313+
OpenMPIRBuilder::InsertPointTy AllocaIP,
7314+
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7315+
Function *OutlinedFn, Constant *OutlinedFnID,
7316+
SmallVectorImpl<Value *> &Args,
73087317
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
73097318
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
73107319
bool HasNoWait = false) {
@@ -7385,9 +7394,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
73857394

73867395
SmallVector<Value *, 3> NumTeamsC;
73877396
SmallVector<Value *, 3> NumThreadsC;
7388-
for (auto V : NumTeams)
7397+
for (auto V : DefaultAttrs.MaxTeams)
73897398
NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
7390-
for (auto V : NumThreads)
7399+
for (auto V : DefaultAttrs.MaxThreads)
73917400
NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
73927401

73937402
unsigned NumTargetItems = Info.NumberOfPtrs;
@@ -7428,7 +7437,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
74287437
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
74297438
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
74307439
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
7431-
ArrayRef<int32_t> NumTeams, ArrayRef<int32_t> NumThreads,
7440+
const TargetKernelDefaultAttrs &DefaultAttrs,
74327441
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
74337442
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
74347443
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
@@ -7445,16 +7454,16 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
74457454
// the target region itself is generated using the callbacks CBFunc
74467455
// and ArgAccessorFuncCB
74477456
if (Error Err = emitTargetOutlinedFunction(
7448-
*this, Builder, IsOffloadEntry, EntryInfo, OutlinedFn, OutlinedFnID,
7449-
Args, CBFunc, ArgAccessorFuncCB))
7457+
*this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn,
7458+
OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
74507459
return Err;
74517460

74527461
// If we are not on the target device, then we need to generate code
74537462
// to make a remote call (offload) to the previously outlined function
74547463
// that represents the target region. Do that now.
74557464
if (!Config.isTargetDevice())
7456-
emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
7457-
NumThreads, Args, GenMapInfoCB, Dependencies, HasNowait);
7465+
emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, OutlinedFn,
7466+
OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
74587467
return Builder.saveIP();
74597468
}
74607469

llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6229,10 +6229,14 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
62296229

62306230
TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
62316231
OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
6232+
OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
6233+
/*IsSPMD=*/false, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0},
6234+
/*MinThreads=*/0};
6235+
62326236
ASSERT_EXPECTED_INIT(
62336237
OpenMPIRBuilder::InsertPointTy, AfterIP,
62346238
OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
6235-
Builder.saveIP(), EntryInfo, -1, 0, Inputs,
6239+
Builder.saveIP(), EntryInfo, DefaultAttrs, Inputs,
62366240
GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
62376241
Builder.restoreIP(AfterIP);
62386242
OMPBuilder.finalize();
@@ -6339,13 +6343,15 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
63396343
F->getEntryBlock().getFirstInsertionPt());
63406344
TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
63416345
/*Line=*/3, /*Count=*/0);
6346+
OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
6347+
/*IsSPMD=*/false, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0},
6348+
/*MinThreads=*/0};
63426349

63436350
ASSERT_EXPECTED_INIT(
63446351
OpenMPIRBuilder::InsertPointTy, AfterIP,
63456352
OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
6346-
EntryInfo, /*NumTeams=*/-1,
6347-
/*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
6348-
BodyGenCB, SimpleArgAccessorCB));
6353+
EntryInfo, DefaultAttrs, CapturedArgs,
6354+
GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
63496355
Builder.restoreIP(AfterIP);
63506356

63516357
Builder.CreateRetVoid();
@@ -6496,13 +6502,15 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
64966502
F->getEntryBlock().getFirstInsertionPt());
64976503
TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
64986504
/*Line=*/3, /*Count=*/0);
6505+
OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
6506+
/*IsSPMD=*/false, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0},
6507+
/*MinThreads=*/0};
64996508

65006509
ASSERT_EXPECTED_INIT(
65016510
OpenMPIRBuilder::InsertPointTy, AfterIP,
65026511
OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
6503-
EntryInfo, /*NumTeams=*/-1,
6504-
/*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
6505-
BodyGenCB, SimpleArgAccessorCB));
6512+
EntryInfo, DefaultAttrs, CapturedArgs,
6513+
GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
65066514
Builder.restoreIP(AfterIP);
65076515

65086516
Builder.CreateRetVoid();

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4084,9 +4084,6 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
40844084
if (!getTargetEntryUniqueInfo(entryInfo, targetOp, parentName))
40854085
return failure();
40864086

4087-
int32_t defaultValTeams = -1;
4088-
int32_t defaultValThreads = 0;
4089-
40904087
MapInfoData mapData;
40914088
collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
40924089
builder);
@@ -4118,6 +4115,11 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
41184115
allocaIP, codeGenIP);
41194116
};
41204117

4118+
// TODO: Populate default attributes based on the construct and clauses.
4119+
llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs = {
4120+
/*IsSPMD=*/false, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0},
4121+
/*MinThreads=*/0};
4122+
41214123
llvm::SmallVector<llvm::Value *, 4> kernelInput;
41224124
for (size_t i = 0; i < mapVars.size(); ++i) {
41234125
// declare target arguments are not passed to kernels as arguments
@@ -4141,8 +4143,8 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
41414143
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
41424144
moduleTranslation.getOpenMPBuilder()->createTarget(
41434145
ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo,
4144-
defaultValTeams, defaultValThreads, kernelInput, genMapInfoCB, bodyCB,
4145-
argAccessorCB, dds, targetOp.getNowait());
4146+
defaultAttrs, kernelInput, genMapInfoCB, bodyCB, argAccessorCB, dds,
4147+
targetOp.getNowait());
41464148

41474149
if (failed(handleError(afterIP, opInst)))
41484150
return failure();

0 commit comments

Comments
 (0)