Skip to content

[MLIR][OpenMP] Add the host_eval clause #116048

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 14, 2025
Merged

Conversation

skatrak
Copy link
Member

@skatrak skatrak commented Nov 13, 2024

This patch adds the definition of a new entry block argument-defining host_eval clause. This is intended to implement the passthrough approach discussed in this RFC, for supporting host-evaluated clauses that apply to operations nested inside of omp.target.

@llvmbot
Copy link
Member

llvmbot commented Nov 13, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-openmp

Author: Sergio Afonso (skatrak)

Changes

This patch adds the definition of a new entry block argument-defining host_eval clause. This is intended to implement the passthrough approach discussed in this RFC, for supporting host-evaluated clauses that apply to operations nested inside of omp.target.


Full diff: https://github.com/llvm/llvm-project/pull/116048.diff

4 Files Affected:

  • (modified) mlir/docs/Dialects/OpenMPDialect/_index.md (+2-1)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td (+38)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td (+25-6)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+9)
diff --git a/mlir/docs/Dialects/OpenMPDialect/_index.md b/mlir/docs/Dialects/OpenMPDialect/_index.md
index 3d28fe7819129f..4e5d777d6c4f7f 100644
--- a/mlir/docs/Dialects/OpenMPDialect/_index.md
+++ b/mlir/docs/Dialects/OpenMPDialect/_index.md
@@ -297,7 +297,8 @@ arguments for the region of that MLIR operation. This enables, for example, the
 introduction of private copies of the same underlying variable defined outside
 the MLIR operation the clause is attached to. Currently, clauses with this
 property can be classified into three main categories:
-  - Map-like clauses: `map`, `use_device_addr` and `use_device_ptr`.
+  - Map-like clauses: `host_eval`, `map`, `use_device_addr` and
+`use_device_ptr`.
   - Reduction-like clauses: `in_reduction`, `reduction` and `task_reduction`.
   - Privatization clauses: `private`.
 
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 855deab94b2f16..0a06c2e0335768 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -444,6 +444,44 @@ class OpenMP_HintClauseSkip<
 
 def OpenMP_HintClause : OpenMP_HintClauseSkip<>;
 
+//===----------------------------------------------------------------------===//
+// Not in the spec: Clause-like structure to hold host-evaluated values.
+//===----------------------------------------------------------------------===//
+
+class OpenMP_HostEvalClauseSkip<
+    bit traits = false, bit arguments = false, bit assemblyFormat = false,
+    bit description = false, bit extraClassDeclaration = false
+  > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+                    extraClassDeclaration> {
+  let traits = [
+    BlockArgOpenMPOpInterface
+  ];
+
+  let arguments = (ins
+    Variadic<AnyType>:$host_eval_vars
+  );
+
+  let extraClassDeclaration = [{
+    unsigned numHostEvalBlockArgs() {
+      return getHostEvalVars().size();
+    }
+  }];
+
+  let description = [{
+    The optional `host_eval_vars` holds values defined outside of the region of
+    the `IsolatedFromAbove` operation for which a corresponding entry block
+    argument is defined. The only legal uses for these captured values are the
+    following:
+      - `num_teams` or `thread_limit` clause of an immediately nested
+      `omp.teams` operation.
+      - If the operation is the top-level `omp.target` of a target SPMD kernel:
+        - `num_threads` clause of the nested `omp.parallel` operation.
+        - Bounds and steps of the nested `omp.loop_nest` operation.
+  }];
+}
+
+def OpenMP_HostEvalClause : OpenMP_HostEvalClauseSkip<>;
+
 //===----------------------------------------------------------------------===//
 // V5.2: [3.4] `if` clause
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index 8b72689dc3fd87..c68d4c81986615 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -25,6 +25,10 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
 
   let methods = [
     // Default-implemented methods to be overriden by the corresponding clauses.
+    InterfaceMethod<"Get number of block arguments defined by `host_eval`.",
+                    "unsigned", "numHostEvalBlockArgs", (ins), [{}], [{
+      return 0;
+    }]>,
     InterfaceMethod<"Get number of block arguments defined by `in_reduction`.",
                     "unsigned", "numInReductionBlockArgs", (ins), [{}], [{
       return 0;
@@ -54,10 +58,16 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
       return 0;
     }]>,
 
-    // Unified access methods for clause-associated entry block arguments.
+    // Unified access methods for start indices of clause-associated entry block
+    // arguments.
+    InterfaceMethod<"Get start index of block arguments defined by `host_eval`.",
+                    "unsigned", "getHostEvalBlockArgsStart", (ins), [{
+      return 0;
+    }]>,
     InterfaceMethod<"Get start index of block arguments defined by `in_reduction`.",
                     "unsigned", "getInReductionBlockArgsStart", (ins), [{
-      return 0;
+      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+      return iface.getHostEvalBlockArgsStart() + $_op.numHostEvalBlockArgs();
     }]>,
     InterfaceMethod<"Get start index of block arguments defined by `map`.",
                     "unsigned", "getMapBlockArgsStart", (ins), [{
@@ -91,6 +101,14 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
       return iface.getUseDeviceAddrBlockArgsStart() + $_op.numUseDeviceAddrBlockArgs();
     }]>,
 
+    // Unified access methods for clause-associated entry block arguments.
+    InterfaceMethod<"Get block arguments defined by `host_eval`.",
+                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+                    "getHostEvalBlockArgs", (ins), [{
+      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+      return $_op->getRegion(0).getArguments().slice(
+          iface.getHostEvalBlockArgsStart(), $_op.numHostEvalBlockArgs());
+    }]>,
     InterfaceMethod<"Get block arguments defined by `in_reduction`.",
                     "::llvm::MutableArrayRef<::mlir::BlockArgument>",
                     "getInReductionBlockArgs", (ins), [{
@@ -147,10 +165,11 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
 
   let verify = [{
     auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>($_op);
-    unsigned expectedArgs = iface.numInReductionBlockArgs() +
-        iface.numMapBlockArgs() + iface.numPrivateBlockArgs() +
-        iface.numReductionBlockArgs() + iface.numTaskReductionBlockArgs() +
-        iface.numUseDeviceAddrBlockArgs() + iface.numUseDevicePtrBlockArgs();
+    unsigned expectedArgs = iface.numHostEvalBlockArgs() +
+        iface.numInReductionBlockArgs() + iface.numMapBlockArgs() +
+        iface.numPrivateBlockArgs() + iface.numReductionBlockArgs() +
+        iface.numTaskReductionBlockArgs() + iface.numUseDeviceAddrBlockArgs() +
+        iface.numUseDevicePtrBlockArgs();
     if ($_op->getRegion(0).getNumArguments() < expectedArgs)
       return $_op->emitOpError() << "expected at least " << expectedArgs
                                  << " entry block argument(s)";
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index a1de0831653e64..b3575b1ca4018e 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -502,6 +502,7 @@ struct ReductionParseArgs {
       : vars(vars), types(types), byref(byref), syms(syms) {}
 };
 struct AllRegionParseArgs {
+  std::optional<MapParseArgs> hostEvalArgs;
   std::optional<ReductionParseArgs> inReductionArgs;
   std::optional<MapParseArgs> mapArgs;
   std::optional<PrivateParseArgs> privateArgs;
@@ -628,6 +629,11 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
                                        AllRegionParseArgs args) {
   llvm::SmallVector<OpAsmParser::Argument> entryBlockArgs;
 
+  if (failed(parseBlockArgClause(parser, entryBlockArgs, "host_eval",
+                                 args.hostEvalArgs)))
+    return parser.emitError(parser.getCurrentLocation())
+           << "invalid `host_eval` format";
+
   if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction",
                                  args.inReductionArgs)))
     return parser.emitError(parser.getCurrentLocation())
@@ -789,6 +795,7 @@ struct ReductionPrintArgs {
       : vars(vars), types(types), byref(byref), syms(syms) {}
 };
 struct AllRegionPrintArgs {
+  std::optional<MapPrintArgs> hostEvalArgs;
   std::optional<ReductionPrintArgs> inReductionArgs;
   std::optional<MapPrintArgs> mapArgs;
   std::optional<PrivatePrintArgs> privateArgs;
@@ -867,6 +874,8 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
   auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
   MLIRContext *ctx = op->getContext();
 
+  printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(),
+                      args.hostEvalArgs);
   printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),
                       args.inReductionArgs);
   printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(),

Copy link
Contributor

@mjklemm mjklemm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Implementation LGTM, but please address the comment fort the docs.

@skatrak skatrak force-pushed the users/skatrak/host-eval-01-clause branch from 0a47845 to d74cf35 Compare December 4, 2024 12:23
This patch adds the definition of a new entry block argument-defining
`host_eval` clause. This is intended to implement the passthrough approach
discussed in [this RFC](https://discourse.llvm.org/t/rfc-openmp-dialect-representation-of-num-teams-thread-limit-and-target-spmd/81106),
for supporting host-evaluated clauses that apply to operations nested inside of
`omp.target`.
@skatrak skatrak force-pushed the users/skatrak/host-eval-01-clause branch from d74cf35 to 5efde4c Compare January 8, 2025 14:14
Copy link
Contributor

@kiranchandramohan kiranchandramohan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG.

@skatrak skatrak merged commit 8906343 into main Jan 14, 2025
8 checks passed
@skatrak skatrak deleted the users/skatrak/host-eval-01-clause branch January 14, 2025 10:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants