@@ -6727,8 +6727,43 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
6727
6727
return getOrCreateRuntimeFunction (M, omp::OMPRTL___kmpc_dispatch_deinit);
6728
6728
}
6729
6729
6730
+ static void emitUsed (StringRef Name, std::vector<llvm::WeakTrackingVH> &List,
6731
+ Module &M) {
6732
+ if (List.empty ())
6733
+ return ;
6734
+
6735
+ Type *PtrTy = PointerType::get (M.getContext (), /* AddressSpace=*/ 0 );
6736
+
6737
+ // Convert List to what ConstantArray needs.
6738
+ SmallVector<Constant *, 8 > UsedArray;
6739
+ UsedArray.reserve (List.size ());
6740
+ for (auto Item : List)
6741
+ UsedArray.push_back (ConstantExpr::getPointerBitCastOrAddrSpaceCast (
6742
+ cast<Constant>(&*Item), PtrTy));
6743
+
6744
+ ArrayType *ArrTy = ArrayType::get (PtrTy, UsedArray.size ());
6745
+ auto *GV =
6746
+ new GlobalVariable (M, ArrTy, false , llvm::GlobalValue::AppendingLinkage,
6747
+ llvm::ConstantArray::get (ArrTy, UsedArray), Name);
6748
+
6749
+ GV->setSection (" llvm.metadata" );
6750
+ }
6751
+
6752
+ static void
6753
+ emitExecutionMode (OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
6754
+ StringRef FunctionName, OMPTgtExecModeFlags Mode,
6755
+ std::vector<llvm::WeakTrackingVH> &LLVMCompilerUsed) {
6756
+ auto *Int8Ty = Type::getInt8Ty (Builder.getContext ());
6757
+ auto *GVMode = new llvm::GlobalVariable (
6758
+ OMPBuilder.M , Int8Ty, /* isConstant=*/ true ,
6759
+ llvm::GlobalValue::WeakAnyLinkage, llvm::ConstantInt::get (Int8Ty, Mode),
6760
+ Twine (FunctionName, " _exec_mode" ));
6761
+ GVMode->setVisibility (llvm::GlobalVariable::ProtectedVisibility);
6762
+ LLVMCompilerUsed.emplace_back (GVMode);
6763
+ }
6764
+
6730
6765
static Expected<Function *> createOutlinedFunction (
6731
- OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
6766
+ OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsSPMD,
6732
6767
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
6733
6768
StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
6734
6769
OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
@@ -6758,6 +6793,27 @@ static Expected<Function *> createOutlinedFunction(
6758
6793
auto Func =
6759
6794
Function::Create (FuncType, GlobalValue::InternalLinkage, FuncName, M);
6760
6795
6796
+ // Forward target-cpu and target-features function attributes from the
6797
+ // original function to the new outlined function.
6798
+ Function *ParentFn = Builder.GetInsertBlock ()->getParent ();
6799
+
6800
+ auto TargetCpuAttr = ParentFn->getFnAttribute (" target-cpu" );
6801
+ if (TargetCpuAttr.isStringAttribute ())
6802
+ Func->addFnAttr (TargetCpuAttr);
6803
+
6804
+ auto TargetFeaturesAttr = ParentFn->getFnAttribute (" target-features" );
6805
+ if (TargetFeaturesAttr.isStringAttribute ())
6806
+ Func->addFnAttr (TargetFeaturesAttr);
6807
+
6808
+ if (OMPBuilder.Config .isTargetDevice ()) {
6809
+ std::vector<llvm::WeakTrackingVH> LLVMCompilerUsed;
6810
+ emitExecutionMode (OMPBuilder, Builder, FuncName,
6811
+ IsSPMD ? OMP_TGT_EXEC_MODE_SPMD
6812
+ : OMP_TGT_EXEC_MODE_GENERIC,
6813
+ LLVMCompilerUsed);
6814
+ emitUsed (" llvm.compiler.used" , LLVMCompilerUsed, OMPBuilder.M );
6815
+ }
6816
+
6761
6817
// Save insert point.
6762
6818
IRBuilder<>::InsertPointGuard IPG (Builder);
6763
6819
// If there's a DISubprogram associated with current function, then
@@ -6798,7 +6854,7 @@ static Expected<Function *> createOutlinedFunction(
6798
6854
// Insert target init call in the device compilation pass.
6799
6855
if (OMPBuilder.Config .isTargetDevice ())
6800
6856
Builder.restoreIP (
6801
- OMPBuilder.createTargetInit (Builder, /* IsSPMD= */ false , DefaultAttrs));
6857
+ OMPBuilder.createTargetInit (Builder, IsSPMD, DefaultAttrs));
6802
6858
6803
6859
BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock ();
6804
6860
@@ -6995,7 +7051,7 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
6995
7051
6996
7052
static Error emitTargetOutlinedFunction (
6997
7053
OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
6998
- TargetRegionEntryInfo &EntryInfo,
7054
+ bool IsSPMD, TargetRegionEntryInfo &EntryInfo,
6999
7055
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7000
7056
Function *&OutlinedFn, Constant *&OutlinedFnID,
7001
7057
SmallVectorImpl<Value *> &Inputs,
@@ -7004,7 +7060,7 @@ static Error emitTargetOutlinedFunction(
7004
7060
7005
7061
OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
7006
7062
[&](StringRef EntryFnName) {
7007
- return createOutlinedFunction (OMPBuilder, Builder, DefaultAttrs,
7063
+ return createOutlinedFunction (OMPBuilder, Builder, IsSPMD, DefaultAttrs,
7008
7064
EntryFnName, Inputs, CBFunc,
7009
7065
ArgAccessorFuncCB);
7010
7066
};
@@ -7304,6 +7360,7 @@ static void
7304
7360
emitTargetCall (OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7305
7361
OpenMPIRBuilder::InsertPointTy AllocaIP,
7306
7362
const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
7363
+ const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
7307
7364
Function *OutlinedFn, Constant *OutlinedFnID,
7308
7365
SmallVectorImpl<Value *> &Args,
7309
7366
OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
@@ -7385,11 +7442,43 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7385
7442
/* ForEndCall=*/ false );
7386
7443
7387
7444
SmallVector<Value *, 3 > NumTeamsC;
7445
+ for (auto [DefaultVal, RuntimeVal] :
7446
+ zip_equal (DefaultAttrs.MaxTeams , RuntimeAttrs.MaxTeams ))
7447
+ NumTeamsC.push_back (RuntimeVal ? RuntimeVal : Builder.getInt32 (DefaultVal));
7448
+
7449
+ // Calculate number of threads: 0 if no clauses specified, otherwise it is the
7450
+ // minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
7451
+ auto InitMaxThreadsClause = [&Builder](Value *Clause) {
7452
+ if (Clause)
7453
+ Clause = Builder.CreateIntCast (Clause, Builder.getInt32Ty (),
7454
+ /* isSigned=*/ false );
7455
+ return Clause;
7456
+ };
7457
+ auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
7458
+ if (Clause)
7459
+ Result = Result
7460
+ ? Builder.CreateSelect (Builder.CreateICmpULT (Result, Clause),
7461
+ Result, Clause)
7462
+ : Clause;
7463
+ };
7464
+
7465
+ // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
7466
+ // the NUM_THREADS clause is overridden by THREAD_LIMIT.
7388
7467
SmallVector<Value *, 3 > NumThreadsC;
7389
- for (auto V : DefaultAttrs.MaxTeams )
7390
- NumTeamsC.push_back (llvm::ConstantInt::get (Builder.getInt32Ty (), V));
7391
- for (auto V : DefaultAttrs.MaxThreads )
7392
- NumThreadsC.push_back (llvm::ConstantInt::get (Builder.getInt32Ty (), V));
7468
+ Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit .size () == 1
7469
+ ? InitMaxThreadsClause (RuntimeAttrs.MaxThreads )
7470
+ : nullptr ;
7471
+
7472
+ for (auto [TeamsVal, TargetVal] : llvm::zip_equal (
7473
+ RuntimeAttrs.TeamsThreadLimit , RuntimeAttrs.TargetThreadLimit )) {
7474
+ Value *TeamsThreadLimitClause = InitMaxThreadsClause (TeamsVal);
7475
+ Value *NumThreads = InitMaxThreadsClause (TargetVal);
7476
+
7477
+ CombineMaxThreadsClauses (TeamsThreadLimitClause, NumThreads);
7478
+ CombineMaxThreadsClauses (MaxThreadsClause, NumThreads);
7479
+
7480
+ NumThreadsC.push_back (NumThreads ? NumThreads : Builder.getInt32 (0 ));
7481
+ }
7393
7482
7394
7483
unsigned NumTargetItems = Info.NumberOfPtrs ;
7395
7484
// TODO: Use correct device ID
@@ -7398,14 +7487,19 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7398
7487
Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr (SrcLocStrSize);
7399
7488
Value *RTLoc = OMPBuilder.getOrCreateIdent (SrcLocStr, SrcLocStrSize,
7400
7489
llvm::omp::IdentFlag (0 ), 0 );
7401
- // TODO: Use correct NumIterations
7402
- Value *NumIterations = Builder.getInt64 (0 );
7490
+
7491
+ Value *TripCount = RuntimeAttrs.LoopTripCount
7492
+ ? Builder.CreateIntCast (RuntimeAttrs.LoopTripCount ,
7493
+ Builder.getInt64Ty (),
7494
+ /* isSigned=*/ false )
7495
+ : Builder.getInt64 (0 );
7496
+
7403
7497
// TODO: Use correct DynCGGroupMem
7404
7498
Value *DynCGGroupMem = Builder.getInt32 (0 );
7405
7499
7406
- KArgs = OpenMPIRBuilder::TargetKernelArgs (
7407
- NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
7408
- DynCGGroupMem, HasNoWait);
7500
+ KArgs = OpenMPIRBuilder::TargetKernelArgs (NumTargetItems, RTArgs, TripCount,
7501
+ NumTeamsC, NumThreadsC,
7502
+ DynCGGroupMem, HasNoWait);
7409
7503
7410
7504
// The presence of certain clauses on the target directive require the
7411
7505
// explicit generation of the target task.
@@ -7427,13 +7521,17 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
7427
7521
}
7428
7522
7429
7523
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget (
7430
- const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
7431
- InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
7524
+ const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
7525
+ InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
7526
+ TargetRegionEntryInfo &EntryInfo,
7432
7527
const TargetKernelDefaultAttrs &DefaultAttrs,
7528
+ const TargetKernelRuntimeAttrs &RuntimeAttrs,
7433
7529
SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
7434
7530
OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
7435
7531
OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
7436
7532
SmallVector<DependData> Dependencies, bool HasNowait) {
7533
+ assert ((!RuntimeAttrs.LoopTripCount || IsSPMD) &&
7534
+ " trip count not expected if IsSPMD=false" );
7437
7535
7438
7536
if (!updateToLocation (Loc))
7439
7537
return InsertPointTy ();
@@ -7446,16 +7544,17 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
7446
7544
// the target region itself is generated using the callbacks CBFunc
7447
7545
// and ArgAccessorFuncCB
7448
7546
if (Error Err = emitTargetOutlinedFunction (
7449
- *this , Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn ,
7450
- OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
7547
+ *this , Builder, IsOffloadEntry, IsSPMD, EntryInfo, DefaultAttrs ,
7548
+ OutlinedFn, OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
7451
7549
return Err;
7452
7550
7453
7551
// If we are not on the target device, then we need to generate code
7454
7552
// to make a remote call (offload) to the previously outlined function
7455
7553
// that represents the target region. Do that now.
7456
7554
if (!Config.isTargetDevice ())
7457
- emitTargetCall (*this , Builder, AllocaIP, DefaultAttrs, OutlinedFn,
7458
- OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
7555
+ emitTargetCall (*this , Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
7556
+ OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
7557
+ HasNowait);
7459
7558
return Builder.saveIP ();
7460
7559
}
7461
7560
0 commit comments