Skip to content

[OMPIRBuilder] Support runtime number of teams and threads, and SPMD mode #116051

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 14, 2025

Conversation

skatrak
Copy link
Member

@skatrak skatrak commented Nov 13, 2024

This patch introduces a TargetKernelRuntimeAttrs structure to hold host-evaluated num_teams, thread_limit, num_threads and trip count values passed to the runtime kernel offloading call.

Additionally, kernel type information is used to influence target device code generation and the IsSPMD flag is replaced by ExecFlags, which provide more granularity.

@llvmbot
Copy link
Member

llvmbot commented Nov 13, 2024

@llvm/pr-subscribers-clang-codegen
@llvm/pr-subscribers-clang
@llvm/pr-subscribers-mlir-openmp

@llvm/pr-subscribers-mlir-llvm

Author: Sergio Afonso (skatrak)

Changes

This patch introduces a TargetKernelRuntimeAttrs structure to hold host- evaluated num_teams, thread_limit, num_threads and trip count values passed to the runtime kernel offloading call.

Additionally, createTarget is extended to take an IsSPMD flag, used to influence target device code generation.


Patch is 31.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116051.diff

4 Files Affected:

  • (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+25-1)
  • (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+118-19)
  • (modified) llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp (+271-10)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+6-4)
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index da450ef5adbc14..a85f41e586c514 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2237,6 +2237,26 @@ class OpenMPIRBuilder {
     int32_t MinThreads = 1;
   };
 
+  /// Container to pass LLVM IR runtime values or constants related to the
+  /// number of teams and threads with which the kernel must be launched, as
+  /// well as the trip count of the SPMD loop, if it is an SPMD kernel. These
+  /// must be defined in the host prior to the call to the kernel launch OpenMP
+  /// RTL function.
+  struct TargetKernelRuntimeAttrs {
+    SmallVector<Value *, 3> MaxTeams = {nullptr};
+    Value *MinTeams = nullptr;
+    SmallVector<Value *, 3> TargetThreadLimit = {nullptr};
+    SmallVector<Value *, 3> TeamsThreadLimit = {nullptr};
+
+    /// 'parallel' construct 'num_threads' clause value, if present and it is a
+    /// target SPMD kernel.
+    Value *MaxThreads = nullptr;
+
+    /// Total number of iterations of the target SPMD kernel or null if it is a
+    /// generic kernel.
+    Value *LoopTripCount = nullptr;
+  };
+
   /// Data structure that contains the needed information to construct the
   /// kernel args vector.
   struct TargetKernelArgs {
@@ -2905,11 +2925,14 @@ class OpenMPIRBuilder {
   ///
   /// \param Loc where the target data construct was encountered.
   /// \param IsOffloadEntry whether it is an offload entry.
+  /// \param IsSPMD whether it is a target SPMD kernel.
   /// \param CodeGenIP The insertion point where the call to the outlined
   /// function should be emitted.
   /// \param EntryInfo The entry information about the function.
   /// \param DefaultAttrs Structure containing the default numbers of threads
   ///        and teams to launch the kernel with.
+  /// \param RuntimeAttrs Structure containing the runtime numbers of threads
+  ///        and teams to launch the kernel with.
   /// \param Inputs The input values to the region that will be passed.
   /// as arguments to the outlined function.
   /// \param BodyGenCB Callback that will generate the region code.
@@ -2919,11 +2942,12 @@ class OpenMPIRBuilder {
   // dependency information as passed in the depend clause
   // \param HasNowait Whether the target construct has a `nowait` clause or not.
   InsertPointOrErrorTy createTarget(
-      const LocationDescription &Loc, bool IsOffloadEntry,
+      const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
       OpenMPIRBuilder::InsertPointTy AllocaIP,
       OpenMPIRBuilder::InsertPointTy CodeGenIP,
       TargetRegionEntryInfo &EntryInfo,
       const TargetKernelDefaultAttrs &DefaultAttrs,
+      const TargetKernelRuntimeAttrs &RuntimeAttrs,
       SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
       TargetBodyGenCallbackTy BodyGenCB,
       TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 302d363965c940..f847f60386df85 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -6727,8 +6727,43 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
   return getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_dispatch_deinit);
 }
 
+static void emitUsed(StringRef Name, std::vector<llvm::WeakTrackingVH> &List,
+                     Module &M) {
+  if (List.empty())
+    return;
+
+  Type *PtrTy = PointerType::get(M.getContext(), /*AddressSpace=*/0);
+
+  // Convert List to what ConstantArray needs.
+  SmallVector<Constant *, 8> UsedArray;
+  UsedArray.reserve(List.size());
+  for (auto Item : List)
+    UsedArray.push_back(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
+        cast<Constant>(&*Item), PtrTy));
+
+  ArrayType *ArrTy = ArrayType::get(PtrTy, UsedArray.size());
+  auto *GV =
+      new GlobalVariable(M, ArrTy, false, llvm::GlobalValue::AppendingLinkage,
+                         llvm::ConstantArray::get(ArrTy, UsedArray), Name);
+
+  GV->setSection("llvm.metadata");
+}
+
+static void
+emitExecutionMode(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+                  StringRef FunctionName, OMPTgtExecModeFlags Mode,
+                  std::vector<llvm::WeakTrackingVH> &LLVMCompilerUsed) {
+  auto *Int8Ty = Type::getInt8Ty(Builder.getContext());
+  auto *GVMode = new llvm::GlobalVariable(
+      OMPBuilder.M, Int8Ty, /*isConstant=*/true,
+      llvm::GlobalValue::WeakAnyLinkage, llvm::ConstantInt::get(Int8Ty, Mode),
+      Twine(FunctionName, "_exec_mode"));
+  GVMode->setVisibility(llvm::GlobalVariable::ProtectedVisibility);
+  LLVMCompilerUsed.emplace_back(GVMode);
+}
+
 static Expected<Function *> createOutlinedFunction(
-    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsSPMD,
     const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
     StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
     OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
@@ -6758,6 +6793,27 @@ static Expected<Function *> createOutlinedFunction(
   auto Func =
       Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName, M);
 
+  // Forward target-cpu and target-features function attributes from the
+  // original function to the new outlined function.
+  Function *ParentFn = Builder.GetInsertBlock()->getParent();
+
+  auto TargetCpuAttr = ParentFn->getFnAttribute("target-cpu");
+  if (TargetCpuAttr.isStringAttribute())
+    Func->addFnAttr(TargetCpuAttr);
+
+  auto TargetFeaturesAttr = ParentFn->getFnAttribute("target-features");
+  if (TargetFeaturesAttr.isStringAttribute())
+    Func->addFnAttr(TargetFeaturesAttr);
+
+  if (OMPBuilder.Config.isTargetDevice()) {
+    std::vector<llvm::WeakTrackingVH> LLVMCompilerUsed;
+    emitExecutionMode(OMPBuilder, Builder, FuncName,
+                      IsSPMD ? OMP_TGT_EXEC_MODE_SPMD
+                             : OMP_TGT_EXEC_MODE_GENERIC,
+                      LLVMCompilerUsed);
+    emitUsed("llvm.compiler.used", LLVMCompilerUsed, OMPBuilder.M);
+  }
+
   // Save insert point.
   IRBuilder<>::InsertPointGuard IPG(Builder);
   // If there's a DISubprogram associated with current function, then
@@ -6798,7 +6854,7 @@ static Expected<Function *> createOutlinedFunction(
   // Insert target init call in the device compilation pass.
   if (OMPBuilder.Config.isTargetDevice())
     Builder.restoreIP(
-        OMPBuilder.createTargetInit(Builder, /*IsSPMD=*/false, DefaultAttrs));
+        OMPBuilder.createTargetInit(Builder, IsSPMD, DefaultAttrs));
 
   BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
 
@@ -6995,7 +7051,7 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
 
 static Error emitTargetOutlinedFunction(
     OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
-    TargetRegionEntryInfo &EntryInfo,
+    bool IsSPMD, TargetRegionEntryInfo &EntryInfo,
     const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
     Function *&OutlinedFn, Constant *&OutlinedFnID,
     SmallVectorImpl<Value *> &Inputs,
@@ -7004,7 +7060,7 @@ static Error emitTargetOutlinedFunction(
 
   OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
       [&](StringRef EntryFnName) {
-        return createOutlinedFunction(OMPBuilder, Builder, DefaultAttrs,
+        return createOutlinedFunction(OMPBuilder, Builder, IsSPMD, DefaultAttrs,
                                       EntryFnName, Inputs, CBFunc,
                                       ArgAccessorFuncCB);
       };
@@ -7304,6 +7360,7 @@ static void
 emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
                OpenMPIRBuilder::InsertPointTy AllocaIP,
                const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+               const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
                Function *OutlinedFn, Constant *OutlinedFnID,
                SmallVectorImpl<Value *> &Args,
                OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
@@ -7385,11 +7442,43 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
                                          /*ForEndCall=*/false);
 
   SmallVector<Value *, 3> NumTeamsC;
+  for (auto [DefaultVal, RuntimeVal] :
+       zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams))
+    NumTeamsC.push_back(RuntimeVal ? RuntimeVal : Builder.getInt32(DefaultVal));
+
+  // Calculate number of threads: 0 if no clauses specified, otherwise it is the
+  // minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
+  auto InitMaxThreadsClause = [&Builder](Value *Clause) {
+    if (Clause)
+      Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
+                                     /*isSigned=*/false);
+    return Clause;
+  };
+  auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
+    if (Clause)
+      Result = Result
+                   ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
+                                          Result, Clause)
+                   : Clause;
+  };
+
+  // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
+  // the NUM_THREADS clause is overriden by THREAD_LIMIT.
   SmallVector<Value *, 3> NumThreadsC;
-  for (auto V : DefaultAttrs.MaxTeams)
-    NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
-  for (auto V : DefaultAttrs.MaxThreads)
-    NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
+  Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit.size() == 1
+                                ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
+                                : nullptr;
+
+  for (auto [TeamsVal, TargetVal] : llvm::zip_equal(
+           RuntimeAttrs.TeamsThreadLimit, RuntimeAttrs.TargetThreadLimit)) {
+    Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
+    Value *NumThreads = InitMaxThreadsClause(TargetVal);
+
+    CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
+    CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
+
+    NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
+  }
 
   unsigned NumTargetItems = Info.NumberOfPtrs;
   // TODO: Use correct device ID
@@ -7398,14 +7487,19 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
   Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
   Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
                                              llvm::omp::IdentFlag(0), 0);
-  // TODO: Use correct NumIterations
-  Value *NumIterations = Builder.getInt64(0);
+
+  Value *TripCount = RuntimeAttrs.LoopTripCount
+                         ? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount,
+                                                 Builder.getInt64Ty(),
+                                                 /*isSigned=*/false)
+                         : Builder.getInt64(0);
+
   // TODO: Use correct DynCGGroupMem
   Value *DynCGGroupMem = Builder.getInt32(0);
 
-  KArgs = OpenMPIRBuilder::TargetKernelArgs(
-      NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
-      DynCGGroupMem, HasNoWait);
+  KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
+                                            NumTeamsC, NumThreadsC,
+                                            DynCGGroupMem, HasNoWait);
 
   // The presence of certain clauses on the target directive require the
   // explicit generation of the target task.
@@ -7427,13 +7521,17 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
 }
 
 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
-    const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
-    InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
+    const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
+    InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
+    TargetRegionEntryInfo &EntryInfo,
     const TargetKernelDefaultAttrs &DefaultAttrs,
+    const TargetKernelRuntimeAttrs &RuntimeAttrs,
     SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
     OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
     SmallVector<DependData> Dependencies, bool HasNowait) {
+  assert((!RuntimeAttrs.LoopTripCount || IsSPMD) &&
+         "trip count not expected if IsSPMD=false");
 
   if (!updateToLocation(Loc))
     return InsertPointTy();
@@ -7446,16 +7544,17 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
   // the target region itself is generated using the callbacks CBFunc
   // and ArgAccessorFuncCB
   if (Error Err = emitTargetOutlinedFunction(
-          *this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn,
-          OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
+          *this, Builder, IsOffloadEntry, IsSPMD, EntryInfo, DefaultAttrs,
+          OutlinedFn, OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
     return Err;
 
   // If we are not on the target device, then we need to generate code
   // to make a remote call (offload) to the previously outlined function
   // that represents the target region. Do that now.
   if (!Config.isTargetDevice())
-    emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, OutlinedFn,
-                   OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
+    emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
+                   OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
+                   HasNowait);
   return Builder.saveIP();
 }
 
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index b0688d6215e42d..63be7e775b83c9 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6122,8 +6122,10 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
   OpenMPIRBuilderConfig Config(false, false, false, false, false, false, false);
   OMPBuilder.setConfig(Config);
   F->setName("func");
+  F->addFnAttr("target-cpu", "x86-64");
+  F->addFnAttr("target-features", "+mmx,+sse");
   IRBuilder<> Builder(BB);
-  auto Int32Ty = Builder.getInt32Ty();
+  auto *Int32Ty = Builder.getInt32Ty();
 
   AllocaInst *APtr = Builder.CreateAlloca(Int32Ty, nullptr, "a_ptr");
   AllocaInst *BPtr = Builder.CreateAlloca(Int32Ty, nullptr, "b_ptr");
@@ -6183,11 +6185,15 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
   TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
   OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
   OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
-      /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
-  OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
-      OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
-                              Builder.saveIP(), EntryInfo, DefaultAttrs, Inputs,
-                              GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+      /*MaxTeams=*/{10}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+  OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
+  RuntimeAttrs.TargetThreadLimit[0] = Builder.getInt32(20);
+  RuntimeAttrs.TeamsThreadLimit[0] = Builder.getInt32(30);
+  RuntimeAttrs.MaxThreads = Builder.getInt32(40);
+  OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
+      OmpLoc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, Builder.saveIP(),
+      Builder.saveIP(), EntryInfo, DefaultAttrs, RuntimeAttrs, Inputs,
+      GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
   assert(AfterIP && "unexpected error");
   Builder.restoreIP(*AfterIP);
   OMPBuilder.finalize();
@@ -6207,6 +6213,43 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
   StringRef FunctionName = KernelLaunchFunc->getName();
   EXPECT_TRUE(FunctionName.starts_with("__tgt_target_kernel"));
 
+  // Check num_teams and num_threads in call arguments
+  EXPECT_TRUE(Call->arg_size() >= 4);
+  Value *NumTeamsArg = Call->getArgOperand(2);
+  EXPECT_TRUE(isa<ConstantInt>(NumTeamsArg));
+  EXPECT_EQ(10U, cast<ConstantInt>(NumTeamsArg)->getZExtValue());
+  Value *NumThreadsArg = Call->getArgOperand(3);
+  EXPECT_TRUE(isa<ConstantInt>(NumThreadsArg));
+  EXPECT_EQ(20U, cast<ConstantInt>(NumThreadsArg)->getZExtValue());
+
+  // Check num_teams and num_threads kernel arguments (use number 5 starting
+  // from the end and counting the call to __tgt_target_kernel as the first use)
+  Value *KernelArgs = Call->getArgOperand(Call->arg_size() - 1);
+  EXPECT_TRUE(KernelArgs->getNumUses() >= 4);
+  Value *NumTeamsGetElemPtr = *std::next(KernelArgs->user_begin(), 3);
+  EXPECT_TRUE(isa<GetElementPtrInst>(NumTeamsGetElemPtr));
+  Value *NumTeamsStore = NumTeamsGetElemPtr->getUniqueUndroppableUser();
+  EXPECT_TRUE(isa<StoreInst>(NumTeamsStore));
+  Value *NumTeamsStoreArg = cast<StoreInst>(NumTeamsStore)->getValueOperand();
+  EXPECT_TRUE(isa<ConstantDataSequential>(NumTeamsStoreArg));
+  auto *NumTeamsStoreValue = cast<ConstantDataSequential>(NumTeamsStoreArg);
+  EXPECT_EQ(3U, NumTeamsStoreValue->getNumElements());
+  EXPECT_EQ(10U, NumTeamsStoreValue->getElementAsInteger(0));
+  EXPECT_EQ(0U, NumTeamsStoreValue->getElementAsInteger(1));
+  EXPECT_EQ(0U, NumTeamsStoreValue->getElementAsInteger(2));
+  Value *NumThreadsGetElemPtr = *std::next(KernelArgs->user_begin(), 2);
+  EXPECT_TRUE(isa<GetElementPtrInst>(NumThreadsGetElemPtr));
+  Value *NumThreadsStore = NumThreadsGetElemPtr->getUniqueUndroppableUser();
+  EXPECT_TRUE(isa<StoreInst>(NumThreadsStore));
+  Value *NumThreadsStoreArg =
+      cast<StoreInst>(NumThreadsStore)->getValueOperand();
+  EXPECT_TRUE(isa<ConstantDataSequential>(NumThreadsStoreArg));
+  auto *NumThreadsStoreValue = cast<ConstantDataSequential>(NumThreadsStoreArg);
+  EXPECT_EQ(3U, NumThreadsStoreValue->getNumElements());
+  EXPECT_EQ(20U, NumThreadsStoreValue->getElementAsInteger(0));
+  EXPECT_EQ(0U, NumThreadsStoreValue->getElementAsInteger(1));
+  EXPECT_EQ(0U, NumThreadsStoreValue->getElementAsInteger(2));
+
   // Check the fallback call
   BasicBlock *FallbackBlock = Branch->getSuccessor(0);
   Iter = FallbackBlock->rbegin();
@@ -6228,6 +6271,13 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
   StringRef FunctionName2 = OutlinedFunc->getName();
   EXPECT_TRUE(FunctionName2.starts_with("__omp_offloading"));
 
+  // Check that target-cpu and target-features were propagated to the outlined
+  // function
+  EXPECT_EQ(OutlinedFunc->getFnAttribute("target-cpu"),
+            F->getFnAttribute("target-cpu"));
+  EXPECT_EQ(OutlinedFunc->getFnAttribute("target-features"),
+            F->getFnAttribute("target-features"));
+
   EXPECT_FALSE(verifyModule(*M, &errs()));
 }
 
@@ -6238,6 +6288,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
   OMPBuilder.initialize();
 
   F->setName("func");
+  F->addFnAttr("target-cpu", "gfx90a");
+  F->addFnAttr("target-features", "+gfx9-insts,+wavefrontsize64");
   IRBuilder<> Builder(BB);
   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
 
@@ -6297,9 +6349,11 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
 
   OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
       /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+  OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
-      Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs,
-      CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+      Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, EntryIP, EntryIP,
+      EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB,
+      BodyGenCB, SimpleArgAccessorCB);
   assert(AfterIP && "unexpected error");
   Builder.restoreIP(*AfterIP);
 
@@ -6312,6 +6366,13 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
   Function *OutlinedFn = TargetStore->getFunction();
   EXPECT_NE(F, OutlinedFn);
 
+  // Check that target-cpu and target-features were propagated to the outlined
+  // function
+  EXPECT_EQ(OutlinedFn->getFnAttribute("target-cpu"),
+       ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Nov 13, 2024

@llvm/pr-subscribers-mlir

Author: Sergio Afonso (skatrak)

Changes

This patch introduces a TargetKernelRuntimeAttrs structure to hold host- evaluated num_teams, thread_limit, num_threads and trip count values passed to the runtime kernel offloading call.

Additionally, createTarget is extended to take an IsSPMD flag, used to influence target device code generation.


Patch is 31.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116051.diff

4 Files Affected:

  • (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+25-1)
  • (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+118-19)
  • (modified) llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp (+271-10)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+6-4)
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index da450ef5adbc14..a85f41e586c514 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2237,6 +2237,26 @@ class OpenMPIRBuilder {
     int32_t MinThreads = 1;
   };
 
+  /// Container to pass LLVM IR runtime values or constants related to the
+  /// number of teams and threads with which the kernel must be launched, as
+  /// well as the trip count of the SPMD loop, if it is an SPMD kernel. These
+  /// must be defined in the host prior to the call to the kernel launch OpenMP
+  /// RTL function.
+  struct TargetKernelRuntimeAttrs {
+    SmallVector<Value *, 3> MaxTeams = {nullptr};
+    Value *MinTeams = nullptr;
+    SmallVector<Value *, 3> TargetThreadLimit = {nullptr};
+    SmallVector<Value *, 3> TeamsThreadLimit = {nullptr};
+
+    /// 'parallel' construct 'num_threads' clause value, if present and it is a
+    /// target SPMD kernel.
+    Value *MaxThreads = nullptr;
+
+    /// Total number of iterations of the target SPMD kernel or null if it is a
+    /// generic kernel.
+    Value *LoopTripCount = nullptr;
+  };
+
   /// Data structure that contains the needed information to construct the
   /// kernel args vector.
   struct TargetKernelArgs {
@@ -2905,11 +2925,14 @@ class OpenMPIRBuilder {
   ///
   /// \param Loc where the target data construct was encountered.
   /// \param IsOffloadEntry whether it is an offload entry.
+  /// \param IsSPMD whether it is a target SPMD kernel.
   /// \param CodeGenIP The insertion point where the call to the outlined
   /// function should be emitted.
   /// \param EntryInfo The entry information about the function.
   /// \param DefaultAttrs Structure containing the default numbers of threads
   ///        and teams to launch the kernel with.
+  /// \param RuntimeAttrs Structure containing the runtime numbers of threads
+  ///        and teams to launch the kernel with.
   /// \param Inputs The input values to the region that will be passed.
   /// as arguments to the outlined function.
   /// \param BodyGenCB Callback that will generate the region code.
@@ -2919,11 +2942,12 @@ class OpenMPIRBuilder {
   // dependency information as passed in the depend clause
   // \param HasNowait Whether the target construct has a `nowait` clause or not.
   InsertPointOrErrorTy createTarget(
-      const LocationDescription &Loc, bool IsOffloadEntry,
+      const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
       OpenMPIRBuilder::InsertPointTy AllocaIP,
       OpenMPIRBuilder::InsertPointTy CodeGenIP,
       TargetRegionEntryInfo &EntryInfo,
       const TargetKernelDefaultAttrs &DefaultAttrs,
+      const TargetKernelRuntimeAttrs &RuntimeAttrs,
       SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
       TargetBodyGenCallbackTy BodyGenCB,
       TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 302d363965c940..f847f60386df85 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -6727,8 +6727,43 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
   return getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_dispatch_deinit);
 }
 
+static void emitUsed(StringRef Name, std::vector<llvm::WeakTrackingVH> &List,
+                     Module &M) {
+  if (List.empty())
+    return;
+
+  Type *PtrTy = PointerType::get(M.getContext(), /*AddressSpace=*/0);
+
+  // Convert List to what ConstantArray needs.
+  SmallVector<Constant *, 8> UsedArray;
+  UsedArray.reserve(List.size());
+  for (auto Item : List)
+    UsedArray.push_back(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
+        cast<Constant>(&*Item), PtrTy));
+
+  ArrayType *ArrTy = ArrayType::get(PtrTy, UsedArray.size());
+  auto *GV =
+      new GlobalVariable(M, ArrTy, false, llvm::GlobalValue::AppendingLinkage,
+                         llvm::ConstantArray::get(ArrTy, UsedArray), Name);
+
+  GV->setSection("llvm.metadata");
+}
+
+static void
+emitExecutionMode(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+                  StringRef FunctionName, OMPTgtExecModeFlags Mode,
+                  std::vector<llvm::WeakTrackingVH> &LLVMCompilerUsed) {
+  auto *Int8Ty = Type::getInt8Ty(Builder.getContext());
+  auto *GVMode = new llvm::GlobalVariable(
+      OMPBuilder.M, Int8Ty, /*isConstant=*/true,
+      llvm::GlobalValue::WeakAnyLinkage, llvm::ConstantInt::get(Int8Ty, Mode),
+      Twine(FunctionName, "_exec_mode"));
+  GVMode->setVisibility(llvm::GlobalVariable::ProtectedVisibility);
+  LLVMCompilerUsed.emplace_back(GVMode);
+}
+
 static Expected<Function *> createOutlinedFunction(
-    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsSPMD,
     const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
     StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
     OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
@@ -6758,6 +6793,27 @@ static Expected<Function *> createOutlinedFunction(
   auto Func =
       Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName, M);
 
+  // Forward target-cpu and target-features function attributes from the
+  // original function to the new outlined function.
+  Function *ParentFn = Builder.GetInsertBlock()->getParent();
+
+  auto TargetCpuAttr = ParentFn->getFnAttribute("target-cpu");
+  if (TargetCpuAttr.isStringAttribute())
+    Func->addFnAttr(TargetCpuAttr);
+
+  auto TargetFeaturesAttr = ParentFn->getFnAttribute("target-features");
+  if (TargetFeaturesAttr.isStringAttribute())
+    Func->addFnAttr(TargetFeaturesAttr);
+
+  if (OMPBuilder.Config.isTargetDevice()) {
+    std::vector<llvm::WeakTrackingVH> LLVMCompilerUsed;
+    emitExecutionMode(OMPBuilder, Builder, FuncName,
+                      IsSPMD ? OMP_TGT_EXEC_MODE_SPMD
+                             : OMP_TGT_EXEC_MODE_GENERIC,
+                      LLVMCompilerUsed);
+    emitUsed("llvm.compiler.used", LLVMCompilerUsed, OMPBuilder.M);
+  }
+
   // Save insert point.
   IRBuilder<>::InsertPointGuard IPG(Builder);
   // If there's a DISubprogram associated with current function, then
@@ -6798,7 +6854,7 @@ static Expected<Function *> createOutlinedFunction(
   // Insert target init call in the device compilation pass.
   if (OMPBuilder.Config.isTargetDevice())
     Builder.restoreIP(
-        OMPBuilder.createTargetInit(Builder, /*IsSPMD=*/false, DefaultAttrs));
+        OMPBuilder.createTargetInit(Builder, IsSPMD, DefaultAttrs));
 
   BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
 
@@ -6995,7 +7051,7 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
 
 static Error emitTargetOutlinedFunction(
     OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
-    TargetRegionEntryInfo &EntryInfo,
+    bool IsSPMD, TargetRegionEntryInfo &EntryInfo,
     const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
     Function *&OutlinedFn, Constant *&OutlinedFnID,
     SmallVectorImpl<Value *> &Inputs,
@@ -7004,7 +7060,7 @@ static Error emitTargetOutlinedFunction(
 
   OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
       [&](StringRef EntryFnName) {
-        return createOutlinedFunction(OMPBuilder, Builder, DefaultAttrs,
+        return createOutlinedFunction(OMPBuilder, Builder, IsSPMD, DefaultAttrs,
                                       EntryFnName, Inputs, CBFunc,
                                       ArgAccessorFuncCB);
       };
@@ -7304,6 +7360,7 @@ static void
 emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
                OpenMPIRBuilder::InsertPointTy AllocaIP,
                const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+               const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
                Function *OutlinedFn, Constant *OutlinedFnID,
                SmallVectorImpl<Value *> &Args,
                OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
@@ -7385,11 +7442,43 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
                                          /*ForEndCall=*/false);
 
   SmallVector<Value *, 3> NumTeamsC;
+  for (auto [DefaultVal, RuntimeVal] :
+       zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams))
+    NumTeamsC.push_back(RuntimeVal ? RuntimeVal : Builder.getInt32(DefaultVal));
+
+  // Calculate number of threads: 0 if no clauses specified, otherwise it is the
+  // minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
+  auto InitMaxThreadsClause = [&Builder](Value *Clause) {
+    if (Clause)
+      Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
+                                     /*isSigned=*/false);
+    return Clause;
+  };
+  auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
+    if (Clause)
+      Result = Result
+                   ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
+                                          Result, Clause)
+                   : Clause;
+  };
+
+  // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
+  // the NUM_THREADS clause is overriden by THREAD_LIMIT.
   SmallVector<Value *, 3> NumThreadsC;
-  for (auto V : DefaultAttrs.MaxTeams)
-    NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
-  for (auto V : DefaultAttrs.MaxThreads)
-    NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
+  Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit.size() == 1
+                                ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
+                                : nullptr;
+
+  for (auto [TeamsVal, TargetVal] : llvm::zip_equal(
+           RuntimeAttrs.TeamsThreadLimit, RuntimeAttrs.TargetThreadLimit)) {
+    Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
+    Value *NumThreads = InitMaxThreadsClause(TargetVal);
+
+    CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
+    CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
+
+    NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
+  }
 
   unsigned NumTargetItems = Info.NumberOfPtrs;
   // TODO: Use correct device ID
@@ -7398,14 +7487,19 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
   Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
   Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
                                              llvm::omp::IdentFlag(0), 0);
-  // TODO: Use correct NumIterations
-  Value *NumIterations = Builder.getInt64(0);
+
+  Value *TripCount = RuntimeAttrs.LoopTripCount
+                         ? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount,
+                                                 Builder.getInt64Ty(),
+                                                 /*isSigned=*/false)
+                         : Builder.getInt64(0);
+
   // TODO: Use correct DynCGGroupMem
   Value *DynCGGroupMem = Builder.getInt32(0);
 
-  KArgs = OpenMPIRBuilder::TargetKernelArgs(
-      NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
-      DynCGGroupMem, HasNoWait);
+  KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
+                                            NumTeamsC, NumThreadsC,
+                                            DynCGGroupMem, HasNoWait);
 
   // The presence of certain clauses on the target directive require the
   // explicit generation of the target task.
@@ -7427,13 +7521,17 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
 }
 
 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
-    const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
-    InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
+    const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
+    InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
+    TargetRegionEntryInfo &EntryInfo,
     const TargetKernelDefaultAttrs &DefaultAttrs,
+    const TargetKernelRuntimeAttrs &RuntimeAttrs,
     SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
     OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
     SmallVector<DependData> Dependencies, bool HasNowait) {
+  assert((!RuntimeAttrs.LoopTripCount || IsSPMD) &&
+         "trip count not expected if IsSPMD=false");
 
   if (!updateToLocation(Loc))
     return InsertPointTy();
@@ -7446,16 +7544,17 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
   // the target region itself is generated using the callbacks CBFunc
   // and ArgAccessorFuncCB
   if (Error Err = emitTargetOutlinedFunction(
-          *this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn,
-          OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
+          *this, Builder, IsOffloadEntry, IsSPMD, EntryInfo, DefaultAttrs,
+          OutlinedFn, OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
     return Err;
 
   // If we are not on the target device, then we need to generate code
   // to make a remote call (offload) to the previously outlined function
   // that represents the target region. Do that now.
   if (!Config.isTargetDevice())
-    emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, OutlinedFn,
-                   OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
+    emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
+                   OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
+                   HasNowait);
   return Builder.saveIP();
 }
 
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index b0688d6215e42d..63be7e775b83c9 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6122,8 +6122,10 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
   OpenMPIRBuilderConfig Config(false, false, false, false, false, false, false);
   OMPBuilder.setConfig(Config);
   F->setName("func");
+  F->addFnAttr("target-cpu", "x86-64");
+  F->addFnAttr("target-features", "+mmx,+sse");
   IRBuilder<> Builder(BB);
-  auto Int32Ty = Builder.getInt32Ty();
+  auto *Int32Ty = Builder.getInt32Ty();
 
   AllocaInst *APtr = Builder.CreateAlloca(Int32Ty, nullptr, "a_ptr");
   AllocaInst *BPtr = Builder.CreateAlloca(Int32Ty, nullptr, "b_ptr");
@@ -6183,11 +6185,15 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
   TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
   OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
   OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
-      /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
-  OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
-      OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
-                              Builder.saveIP(), EntryInfo, DefaultAttrs, Inputs,
-                              GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+      /*MaxTeams=*/{10}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+  OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
+  RuntimeAttrs.TargetThreadLimit[0] = Builder.getInt32(20);
+  RuntimeAttrs.TeamsThreadLimit[0] = Builder.getInt32(30);
+  RuntimeAttrs.MaxThreads = Builder.getInt32(40);
+  OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
+      OmpLoc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, Builder.saveIP(),
+      Builder.saveIP(), EntryInfo, DefaultAttrs, RuntimeAttrs, Inputs,
+      GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
   assert(AfterIP && "unexpected error");
   Builder.restoreIP(*AfterIP);
   OMPBuilder.finalize();
@@ -6207,6 +6213,43 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
   StringRef FunctionName = KernelLaunchFunc->getName();
   EXPECT_TRUE(FunctionName.starts_with("__tgt_target_kernel"));
 
+  // Check num_teams and num_threads in call arguments
+  EXPECT_TRUE(Call->arg_size() >= 4);
+  Value *NumTeamsArg = Call->getArgOperand(2);
+  EXPECT_TRUE(isa<ConstantInt>(NumTeamsArg));
+  EXPECT_EQ(10U, cast<ConstantInt>(NumTeamsArg)->getZExtValue());
+  Value *NumThreadsArg = Call->getArgOperand(3);
+  EXPECT_TRUE(isa<ConstantInt>(NumThreadsArg));
+  EXPECT_EQ(20U, cast<ConstantInt>(NumThreadsArg)->getZExtValue());
+
+  // Check num_teams and num_threads kernel arguments (use number 5 starting
+  // from the end and counting the call to __tgt_target_kernel as the first use)
+  Value *KernelArgs = Call->getArgOperand(Call->arg_size() - 1);
+  EXPECT_TRUE(KernelArgs->getNumUses() >= 4);
+  Value *NumTeamsGetElemPtr = *std::next(KernelArgs->user_begin(), 3);
+  EXPECT_TRUE(isa<GetElementPtrInst>(NumTeamsGetElemPtr));
+  Value *NumTeamsStore = NumTeamsGetElemPtr->getUniqueUndroppableUser();
+  EXPECT_TRUE(isa<StoreInst>(NumTeamsStore));
+  Value *NumTeamsStoreArg = cast<StoreInst>(NumTeamsStore)->getValueOperand();
+  EXPECT_TRUE(isa<ConstantDataSequential>(NumTeamsStoreArg));
+  auto *NumTeamsStoreValue = cast<ConstantDataSequential>(NumTeamsStoreArg);
+  EXPECT_EQ(3U, NumTeamsStoreValue->getNumElements());
+  EXPECT_EQ(10U, NumTeamsStoreValue->getElementAsInteger(0));
+  EXPECT_EQ(0U, NumTeamsStoreValue->getElementAsInteger(1));
+  EXPECT_EQ(0U, NumTeamsStoreValue->getElementAsInteger(2));
+  Value *NumThreadsGetElemPtr = *std::next(KernelArgs->user_begin(), 2);
+  EXPECT_TRUE(isa<GetElementPtrInst>(NumThreadsGetElemPtr));
+  Value *NumThreadsStore = NumThreadsGetElemPtr->getUniqueUndroppableUser();
+  EXPECT_TRUE(isa<StoreInst>(NumThreadsStore));
+  Value *NumThreadsStoreArg =
+      cast<StoreInst>(NumThreadsStore)->getValueOperand();
+  EXPECT_TRUE(isa<ConstantDataSequential>(NumThreadsStoreArg));
+  auto *NumThreadsStoreValue = cast<ConstantDataSequential>(NumThreadsStoreArg);
+  EXPECT_EQ(3U, NumThreadsStoreValue->getNumElements());
+  EXPECT_EQ(20U, NumThreadsStoreValue->getElementAsInteger(0));
+  EXPECT_EQ(0U, NumThreadsStoreValue->getElementAsInteger(1));
+  EXPECT_EQ(0U, NumThreadsStoreValue->getElementAsInteger(2));
+
   // Check the fallback call
   BasicBlock *FallbackBlock = Branch->getSuccessor(0);
   Iter = FallbackBlock->rbegin();
@@ -6228,6 +6271,13 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
   StringRef FunctionName2 = OutlinedFunc->getName();
   EXPECT_TRUE(FunctionName2.starts_with("__omp_offloading"));
 
+  // Check that target-cpu and target-features were propagated to the outlined
+  // function
+  EXPECT_EQ(OutlinedFunc->getFnAttribute("target-cpu"),
+            F->getFnAttribute("target-cpu"));
+  EXPECT_EQ(OutlinedFunc->getFnAttribute("target-features"),
+            F->getFnAttribute("target-features"));
+
   EXPECT_FALSE(verifyModule(*M, &errs()));
 }
 
@@ -6238,6 +6288,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
   OMPBuilder.initialize();
 
   F->setName("func");
+  F->addFnAttr("target-cpu", "gfx90a");
+  F->addFnAttr("target-features", "+gfx9-insts,+wavefrontsize64");
   IRBuilder<> Builder(BB);
   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
 
@@ -6297,9 +6349,11 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
 
   OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
       /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+  OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
-      Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs,
-      CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+      Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, EntryIP, EntryIP,
+      EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB,
+      BodyGenCB, SimpleArgAccessorCB);
   assert(AfterIP && "unexpected error");
   Builder.restoreIP(*AfterIP);
 
@@ -6312,6 +6366,13 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
   Function *OutlinedFn = TargetStore->getFunction();
   EXPECT_NE(F, OutlinedFn);
 
+  // Check that target-cpu and target-features were propagated to the outlined
+  // function
+  EXPECT_EQ(OutlinedFn->getFnAttribute("target-cpu"),
+       ...
[truncated]

@skatrak
Copy link
Member Author

skatrak commented Nov 13, 2024

Buildbot failure seems to be some temporary issue unrelated to the PR.

@skatrak skatrak force-pushed the users/skatrak/host-eval-04-ompirbuilder-teams-threads branch from cc5c5cc to e2b3ac4 Compare November 27, 2024 12:26
Copy link
Member Author

@skatrak skatrak left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @jdoerfert for the review. Your comments should be addressed now.

@skatrak skatrak force-pushed the users/skatrak/host-eval-03-ompirbuilder-struct branch from 1fcfe48 to e3cdc93 Compare December 4, 2024 13:41
@skatrak skatrak force-pushed the users/skatrak/host-eval-04-ompirbuilder-teams-threads branch from e2b3ac4 to b1e4eb5 Compare December 4, 2024 14:16
@llvmbot llvmbot added clang Clang issues not falling into any other category clang:codegen IR generation bugs: mangling, exceptions, etc. labels Dec 4, 2024
@skatrak
Copy link
Member Author

skatrak commented Dec 4, 2024

@jdoerfert, can you check whether your concerns have been addressed?

Copy link
Contributor

@DominikAdamski DominikAdamski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@skatrak skatrak force-pushed the users/skatrak/host-eval-03-ompirbuilder-struct branch from e3cdc93 to 45c6667 Compare January 8, 2025 16:35
@skatrak skatrak force-pushed the users/skatrak/host-eval-04-ompirbuilder-teams-threads branch from b1e4eb5 to 76b2b9f Compare January 9, 2025 11:53
Copy link

github-actions bot commented Jan 9, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@skatrak skatrak force-pushed the users/skatrak/host-eval-04-ompirbuilder-teams-threads branch from 76b2b9f to 47a6495 Compare January 9, 2025 12:01
@skatrak
Copy link
Member Author

skatrak commented Jan 10, 2025

The PR stack should be almost ready to be merged in the next few days (buildbot failures are unrelated). If there are any remaining blockers from your side, let me know @jdoerfert.

@skatrak skatrak force-pushed the users/skatrak/host-eval-03-ompirbuilder-struct branch from 45c6667 to 219d430 Compare January 14, 2025 10:28
Base automatically changed from users/skatrak/host-eval-03-ompirbuilder-struct to main January 14, 2025 11:08
@skatrak skatrak force-pushed the users/skatrak/host-eval-04-ompirbuilder-teams-threads branch from 47a6495 to 1f5cd91 Compare January 14, 2025 11:14
…mode

This patch introduces a `TargetKernelRuntimeAttrs` structure to hold
host-evaluated `num_teams`, `thread_limit`, `num_threads` and trip count values
passed to the runtime kernel offloading call.

Additionally, kernel type information is used to influence target device code
generation and the `IsSPMD` flag is replaced by `ExecFlags`, which provide more
granularity.
@skatrak skatrak force-pushed the users/skatrak/host-eval-04-ompirbuilder-teams-threads branch from 1f5cd91 to 0c19f71 Compare January 14, 2025 11:25
@skatrak skatrak merged commit fabc443 into main Jan 14, 2025
8 checks passed
@skatrak skatrak deleted the users/skatrak/host-eval-04-ompirbuilder-teams-threads branch January 14, 2025 12:34
@llvm-ci
Copy link
Collaborator

llvm-ci commented Jan 14, 2025

LLVM Buildbot has detected a new failure on builder llvm-clang-x86_64-expensive-checks-debian running on gribozavr4 while building clang,llvm,mlir at step 6 "test-build-unified-tree-check-all".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/16/builds/11964

Here is the relevant piece of the build log for the reference
Step 6 (test-build-unified-tree-check-all) failure: test (failure)
******************** TEST 'LLVM :: tools/llvm-gsymutil/ARM_AArch64/macho-merged-funcs-dwarf.yaml' FAILED ********************
Exit Code: 1

Command Output (stdout):
--
Input file: /b/1/llvm-clang-x86_64-expensive-checks-debian/build/test/tools/llvm-gsymutil/ARM_AArch64/Output/macho-merged-funcs-dwarf.yaml.tmp.dSYM
Output file (aarch64): /b/1/llvm-clang-x86_64-expensive-checks-debian/build/test/tools/llvm-gsymutil/ARM_AArch64/Output/macho-merged-funcs-dwarf.yaml.tmp.default.gSYM
Loaded 3 functions from DWARF.
Loaded 3 functions from symbol table.
warning: same address range contains different debug info. Removing:
[0x0000000000000248 - 0x0000000000000270): Name=0x00000001
addr=0x0000000000000248, file=  1, line=  5
addr=0x0000000000000254, file=  1, line=  7
addr=0x0000000000000258, file=  1, line=  9
addr=0x000000000000025c, file=  1, line=  8
addr=0x0000000000000260, file=  1, line= 11
addr=0x0000000000000264, file=  1, line= 10
addr=0x0000000000000268, file=  1, line=  6


In favor of this one:
[0x0000000000000248 - 0x0000000000000270): Name=0x00000047
addr=0x0000000000000248, file=  3, line=  5
addr=0x0000000000000254, file=  3, line=  7
addr=0x0000000000000258, file=  3, line=  9
addr=0x000000000000025c, file=  3, line=  8
addr=0x0000000000000260, file=  3, line= 11
addr=0x0000000000000264, file=  3, line= 10
addr=0x0000000000000268, file=  3, line=  6


warning: same address range contains different debug info. Removing:
[0x0000000000000248 - 0x0000000000000270): Name=0x00000047
addr=0x0000000000000248, file=  3, line=  5
addr=0x0000000000000254, file=  3, line=  7
addr=0x0000000000000258, file=  3, line=  9
addr=0x000000000000025c, file=  3, line=  8
addr=0x0000000000000260, file=  3, line= 11
addr=0x0000000000000264, file=  3, line= 10
addr=0x0000000000000268, file=  3, line=  6


In favor of this one:
[0x0000000000000248 - 0x0000000000000270): Name=0x00000030
addr=0x0000000000000248, file=  2, line=  5
addr=0x0000000000000254, file=  2, line=  7
addr=0x0000000000000258, file=  2, line=  9
addr=0x000000000000025c, file=  2, line=  8
addr=0x0000000000000260, file=  2, line= 11
addr=0x0000000000000264, file=  2, line= 10
...

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:codegen IR generation bugs: mangling, exceptions, etc. clang:openmp OpenMP related changes to Clang clang Clang issues not falling into any other category flang:openmp mlir:llvm mlir:openmp mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants