@@ -6109,10 +6109,12 @@ CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
6109
6109
return Builder.CreateCall (Fn, Args);
6110
6110
}
6111
6111
6112
- OpenMPIRBuilder::InsertPointTy
6113
- OpenMPIRBuilder::createTargetInit (const LocationDescription &Loc, bool IsSPMD,
6114
- int32_t MinThreadsVal, int32_t MaxThreadsVal,
6115
- int32_t MinTeamsVal, int32_t MaxTeamsVal) {
6112
+ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit (
6113
+ const LocationDescription &Loc, bool IsSPMD,
6114
+ const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) {
6115
+ assert (!Attrs.MaxThreads .empty () && !Attrs.MaxTeams .empty () &&
6116
+ " expected num_threads and num_teams to be specified" );
6117
+
6116
6118
if (!updateToLocation (Loc))
6117
6119
return Loc.IP ;
6118
6120
@@ -6139,21 +6141,23 @@ OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
6139
6141
6140
6142
// Manifest the launch configuration in the metadata matching the kernel
6141
6143
// environment.
6142
- if (MinTeamsVal > 1 || MaxTeamsVal > 0 )
6143
- writeTeamsForKernel (T, *Kernel, MinTeamsVal, MaxTeamsVal );
6144
+ if (Attrs. MinTeams > 1 || Attrs. MaxTeams . front () > 0 )
6145
+ writeTeamsForKernel (T, *Kernel, Attrs. MinTeams , Attrs. MaxTeams . front () );
6144
6146
6145
- // For max values, < 0 means unset, == 0 means set but unknown.
6147
+ // If MaxThreads not set, select the maximum between the default workgroup
6148
+ // size and the MinThreads value.
6149
+ int32_t MaxThreadsVal = Attrs.MaxThreads .front ();
6146
6150
if (MaxThreadsVal < 0 )
6147
6151
MaxThreadsVal = std::max (
6148
- int32_t (getGridValue (T, Kernel).GV_Default_WG_Size ), MinThreadsVal );
6152
+ int32_t (getGridValue (T, Kernel).GV_Default_WG_Size ), Attrs. MinThreads );
6149
6153
6150
6154
if (MaxThreadsVal > 0 )
6151
- writeThreadBoundsForKernel (T, *Kernel, MinThreadsVal , MaxThreadsVal);
6155
+ writeThreadBoundsForKernel (T, *Kernel, Attrs. MinThreads , MaxThreadsVal);
6152
6156
6153
- Constant *MinThreads = ConstantInt::getSigned (Int32, MinThreadsVal );
6157
+ Constant *MinThreads = ConstantInt::getSigned (Int32, Attrs. MinThreads );
6154
6158
Constant *MaxThreads = ConstantInt::getSigned (Int32, MaxThreadsVal);
6155
- Constant *MinTeams = ConstantInt::getSigned (Int32, MinTeamsVal );
6156
- Constant *MaxTeams = ConstantInt::getSigned (Int32, MaxTeamsVal );
6159
+ Constant *MinTeams = ConstantInt::getSigned (Int32, Attrs. MinTeams );
6160
+ Constant *MaxTeams = ConstantInt::getSigned (Int32, Attrs. MaxTeams . front () );
6157
6161
Constant *ReductionDataSize = ConstantInt::getSigned (Int32, 0 );
6158
6162
Constant *ReductionBufferLength = ConstantInt::getSigned (Int32, 0 );
6159
6163
@@ -6724,8 +6728,9 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
6724
6728
}
6725
6729
6726
6730
static Expected<Function *> createOutlinedFunction (
6727
- OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, StringRef FuncName,
6728
- SmallVectorImpl<Value *> &Inputs,
6731
+ OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
6732
+ const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
6733
+ StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
6729
6734
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
6730
6735
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
6731
6736
SmallVector<Type *> ParameterTypes;
@@ -6792,7 +6797,8 @@ static Expected<Function *> createOutlinedFunction(
6792
6797
6793
6798
// Insert target init call in the device compilation pass.
6794
6799
if (OMPBuilder.Config .isTargetDevice ())
6795
- Builder.restoreIP (OMPBuilder.createTargetInit (Builder, /* IsSPMD*/ false ));
6800
+ Builder.restoreIP (
6801
+ OMPBuilder.createTargetInit (Builder, /* IsSPMD=*/ false , DefaultAttrs));
6796
6802
6797
6803
BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock ();
6798
6804
@@ -6989,16 +6995,18 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
6989
6995
6990
6996
static Error emitTargetOutlinedFunction (
6991
6997
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
6992
- TargetRegionEntryInfo &EntryInfo, Function *&OutlinedFn,
6993
- Constant *&OutlinedFnID, SmallVectorImpl<Value *> &Inputs,
6998
+ TargetRegionEntryInfo &EntryInfo,
6999
+ const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7000
+ Function *&OutlinedFn, Constant *&OutlinedFnID,
7001
+ SmallVectorImpl<Value *> &Inputs,
6994
7002
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
6995
7003
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
6996
7004
6997
7005
OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
6998
- [&OMPBuilder, &Builder, &Inputs, &CBFunc,
6999
- &ArgAccessorFuncCB](StringRef EntryFnName) {
7000
- return createOutlinedFunction (OMPBuilder, Builder, EntryFnName, Inputs,
7001
- CBFunc, ArgAccessorFuncCB);
7006
+ [&](StringRef EntryFnName) {
7007
+ return createOutlinedFunction (OMPBuilder, Builder, DefaultAttrs,
7008
+ EntryFnName, Inputs, CBFunc ,
7009
+ ArgAccessorFuncCB);
7002
7010
};
7003
7011
7004
7012
return OMPBuilder.emitTargetRegionFunction (
@@ -7294,9 +7302,10 @@ void OpenMPIRBuilder::emitOffloadingArraysAndArgs(
7294
7302
7295
7303
static void
7296
7304
emitTargetCall (OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7297
- OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
7298
- Constant *OutlinedFnID, ArrayRef<int32_t > NumTeams,
7299
- ArrayRef<int32_t > NumThreads, SmallVectorImpl<Value *> &Args,
7305
+ OpenMPIRBuilder::InsertPointTy AllocaIP,
7306
+ const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7307
+ Function *OutlinedFn, Constant *OutlinedFnID,
7308
+ SmallVectorImpl<Value *> &Args,
7300
7309
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
7301
7310
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
7302
7311
bool HasNoWait = false ) {
@@ -7377,9 +7386,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7377
7386
7378
7387
SmallVector<Value *, 3 > NumTeamsC;
7379
7388
SmallVector<Value *, 3 > NumThreadsC;
7380
- for (auto V : NumTeams )
7389
+ for (auto V : DefaultAttrs. MaxTeams )
7381
7390
NumTeamsC.push_back (llvm::ConstantInt::get (Builder.getInt32Ty (), V));
7382
- for (auto V : NumThreads )
7391
+ for (auto V : DefaultAttrs. MaxThreads )
7383
7392
NumThreadsC.push_back (llvm::ConstantInt::get (Builder.getInt32Ty (), V));
7384
7393
7385
7394
unsigned NumTargetItems = Info.NumberOfPtrs ;
@@ -7420,7 +7429,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7420
7429
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget (
7421
7430
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
7422
7431
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
7423
- ArrayRef< int32_t > NumTeams, ArrayRef< int32_t > NumThreads ,
7432
+ const TargetKernelDefaultAttrs &DefaultAttrs ,
7424
7433
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
7425
7434
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
7426
7435
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
@@ -7437,16 +7446,16 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
7437
7446
// the target region itself is generated using the callbacks CBFunc
7438
7447
// and ArgAccessorFuncCB
7439
7448
if (Error Err = emitTargetOutlinedFunction (
7440
- *this , Builder, IsOffloadEntry, EntryInfo, OutlinedFn, OutlinedFnID ,
7441
- Args, CBFunc, ArgAccessorFuncCB))
7449
+ *this , Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn ,
7450
+ OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
7442
7451
return Err;
7443
7452
7444
7453
// If we are not on the target device, then we need to generate code
7445
7454
// to make a remote call (offload) to the previously outlined function
7446
7455
// that represents the target region. Do that now.
7447
7456
if (!Config.isTargetDevice ())
7448
- emitTargetCall (*this , Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams ,
7449
- NumThreads , Args, GenMapInfoCB, Dependencies, HasNowait);
7457
+ emitTargetCall (*this , Builder, AllocaIP, DefaultAttrs, OutlinedFn ,
7458
+ OutlinedFnID , Args, GenMapInfoCB, Dependencies, HasNowait);
7450
7459
return Builder.saveIP ();
7451
7460
}
7452
7461
0 commit comments