@@ -6119,19 +6119,22 @@ CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
6119
6119
return Builder.CreateCall (Fn, Args);
6120
6120
}
6121
6121
6122
- OpenMPIRBuilder::InsertPointTy
6123
- OpenMPIRBuilder::createTargetInit (const LocationDescription &Loc, bool IsSPMD,
6124
- int32_t MinThreadsVal, int32_t MaxThreadsVal,
6125
- int32_t MinTeamsVal, int32_t MaxTeamsVal) {
6122
+ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit (
6123
+ const LocationDescription &Loc,
6124
+ const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) {
6125
+ assert (!Attrs.MaxThreads .empty () && !Attrs.MaxTeams .empty () &&
6126
+ " expected num_threads and num_teams to be specified" );
6127
+
6126
6128
if (!updateToLocation (Loc))
6127
6129
return Loc.IP ;
6128
6130
6129
6131
uint32_t SrcLocStrSize;
6130
6132
Constant *SrcLocStr = getOrCreateSrcLocStr (Loc, SrcLocStrSize);
6131
6133
Constant *Ident = getOrCreateIdent (SrcLocStr, SrcLocStrSize);
6132
6134
Constant *IsSPMDVal = ConstantInt::getSigned (
6133
- Int8, IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
6134
- Constant *UseGenericStateMachineVal = ConstantInt::getSigned (Int8, !IsSPMD);
6135
+ Int8, Attrs.IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
6136
+ Constant *UseGenericStateMachineVal =
6137
+ ConstantInt::getSigned (Int8, !Attrs.IsSPMD );
6135
6138
Constant *MayUseNestedParallelismVal = ConstantInt::getSigned (Int8, true );
6136
6139
Constant *DebugIndentionLevelVal = ConstantInt::getSigned (Int16, 0 );
6137
6140
@@ -6149,21 +6152,23 @@ OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
6149
6152
6150
6153
// Manifest the launch configuration in the metadata matching the kernel
6151
6154
// environment.
6152
- if (MinTeamsVal > 1 || MaxTeamsVal > 0 )
6153
- writeTeamsForKernel (T, *Kernel, MinTeamsVal, MaxTeamsVal );
6155
+ if (Attrs. MinTeams > 1 || Attrs. MaxTeams . front () > 0 )
6156
+ writeTeamsForKernel (T, *Kernel, Attrs. MinTeams , Attrs. MaxTeams . front () );
6154
6157
6155
- // For max values, < 0 means unset, == 0 means set but unknown.
6158
+ // If MaxThreads is not set (i.e. its first value is negative), select the
6159
+ // maximum between the default workgroup size and the MinThreads value.
6160
+ int32_t MaxThreadsVal = Attrs.MaxThreads .front ();
6156
6161
if (MaxThreadsVal < 0 )
6157
6162
MaxThreadsVal = std::max (
6158
- int32_t (getGridValue (T, Kernel).GV_Default_WG_Size ), MinThreadsVal );
6163
+ int32_t (getGridValue (T, Kernel).GV_Default_WG_Size ), Attrs. MinThreads );
6159
6164
6160
6165
if (MaxThreadsVal > 0 )
6161
- writeThreadBoundsForKernel (T, *Kernel, MinThreadsVal , MaxThreadsVal);
6166
+ writeThreadBoundsForKernel (T, *Kernel, Attrs. MinThreads , MaxThreadsVal);
6162
6167
6163
- Constant *MinThreads = ConstantInt::getSigned (Int32, MinThreadsVal );
6168
+ Constant *MinThreads = ConstantInt::getSigned (Int32, Attrs. MinThreads );
6164
6169
Constant *MaxThreads = ConstantInt::getSigned (Int32, MaxThreadsVal);
6165
- Constant *MinTeams = ConstantInt::getSigned (Int32, MinTeamsVal );
6166
- Constant *MaxTeams = ConstantInt::getSigned (Int32, MaxTeamsVal );
6170
+ Constant *MinTeams = ConstantInt::getSigned (Int32, Attrs. MinTeams );
6171
+ Constant *MaxTeams = ConstantInt::getSigned (Int32, Attrs. MaxTeams . front () );
6167
6172
Constant *ReductionDataSize = ConstantInt::getSigned (Int32, 0 );
6168
6173
Constant *ReductionBufferLength = ConstantInt::getSigned (Int32, 0 );
6169
6174
@@ -6730,8 +6735,9 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
6730
6735
}
6731
6736
6732
6737
static Expected<Function *> createOutlinedFunction (
6733
- OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, StringRef FuncName,
6734
- SmallVectorImpl<Value *> &Inputs,
6738
+ OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
6739
+ const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
6740
+ StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
6735
6741
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
6736
6742
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
6737
6743
SmallVector<Type *> ParameterTypes;
@@ -6798,7 +6804,7 @@ static Expected<Function *> createOutlinedFunction(
6798
6804
6799
6805
// Insert target init call in the device compilation pass.
6800
6806
if (OMPBuilder.Config .isTargetDevice ())
6801
- Builder.restoreIP (OMPBuilder.createTargetInit (Builder, /* IsSPMD */ false ));
6807
+ Builder.restoreIP (OMPBuilder.createTargetInit (Builder, DefaultAttrs ));
6802
6808
6803
6809
BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock ();
6804
6810
@@ -6997,16 +7003,18 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
6997
7003
6998
7004
static Error emitTargetOutlinedFunction (
6999
7005
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
7000
- TargetRegionEntryInfo &EntryInfo, Function *&OutlinedFn,
7001
- Constant *&OutlinedFnID, SmallVectorImpl<Value *> &Inputs,
7006
+ TargetRegionEntryInfo &EntryInfo,
7007
+ const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7008
+ Function *&OutlinedFn, Constant *&OutlinedFnID,
7009
+ SmallVectorImpl<Value *> &Inputs,
7002
7010
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
7003
7011
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
7004
7012
7005
7013
OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
7006
- [&OMPBuilder, &Builder, &Inputs, &CBFunc,
7007
- &ArgAccessorFuncCB](StringRef EntryFnName) {
7008
- return createOutlinedFunction (OMPBuilder, Builder, EntryFnName, Inputs,
7009
- CBFunc, ArgAccessorFuncCB);
7014
+ [&](StringRef EntryFnName) {
7015
+ return createOutlinedFunction (OMPBuilder, Builder, DefaultAttrs,
7016
+ EntryFnName, Inputs, CBFunc ,
7017
+ ArgAccessorFuncCB);
7010
7018
};
7011
7019
7012
7020
return OMPBuilder.emitTargetRegionFunction (
@@ -7302,9 +7310,10 @@ void OpenMPIRBuilder::emitOffloadingArraysAndArgs(
7302
7310
7303
7311
static void
7304
7312
emitTargetCall (OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7305
- OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
7306
- Constant *OutlinedFnID, ArrayRef<int32_t > NumTeams,
7307
- ArrayRef<int32_t > NumThreads, SmallVectorImpl<Value *> &Args,
7313
+ OpenMPIRBuilder::InsertPointTy AllocaIP,
7314
+ const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7315
+ Function *OutlinedFn, Constant *OutlinedFnID,
7316
+ SmallVectorImpl<Value *> &Args,
7308
7317
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
7309
7318
SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
7310
7319
bool HasNoWait = false ) {
@@ -7385,9 +7394,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7385
7394
7386
7395
SmallVector<Value *, 3 > NumTeamsC;
7387
7396
SmallVector<Value *, 3 > NumThreadsC;
7388
- for (auto V : NumTeams )
7397
+ for (auto V : DefaultAttrs. MaxTeams )
7389
7398
NumTeamsC.push_back (llvm::ConstantInt::get (Builder.getInt32Ty (), V));
7390
- for (auto V : NumThreads )
7399
+ for (auto V : DefaultAttrs. MaxThreads )
7391
7400
NumThreadsC.push_back (llvm::ConstantInt::get (Builder.getInt32Ty (), V));
7392
7401
7393
7402
unsigned NumTargetItems = Info.NumberOfPtrs ;
@@ -7428,7 +7437,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7428
7437
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget (
7429
7438
const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
7430
7439
InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
7431
- ArrayRef< int32_t > NumTeams, ArrayRef< int32_t > NumThreads ,
7440
+ const TargetKernelDefaultAttrs &DefaultAttrs ,
7432
7441
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
7433
7442
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
7434
7443
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
@@ -7445,16 +7454,16 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
7445
7454
// the target region itself is generated using the callbacks CBFunc
7446
7455
// and ArgAccessorFuncCB
7447
7456
if (Error Err = emitTargetOutlinedFunction (
7448
- *this , Builder, IsOffloadEntry, EntryInfo, OutlinedFn, OutlinedFnID ,
7449
- Args, CBFunc, ArgAccessorFuncCB))
7457
+ *this , Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn ,
7458
+ OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
7450
7459
return Err;
7451
7460
7452
7461
// If we are not on the target device, then we need to generate code
7453
7462
// to make a remote call (offload) to the previously outlined function
7454
7463
// that represents the target region. Do that now.
7455
7464
if (!Config.isTargetDevice ())
7456
- emitTargetCall (*this , Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams ,
7457
- NumThreads , Args, GenMapInfoCB, Dependencies, HasNowait);
7465
+ emitTargetCall (*this , Builder, AllocaIP, DefaultAttrs, OutlinedFn ,
7466
+ OutlinedFnID , Args, GenMapInfoCB, Dependencies, HasNowait);
7458
7467
return Builder.saveIP ();
7459
7468
}
7460
7469
0 commit comments