
[MLIR][OpenMP] Add host_eval clause to omp.target #116049


Merged: 1 commit merged into main from users/skatrak/host-eval-02-mlir on Jan 14, 2025

Conversation

@skatrak (Member) commented Nov 13, 2024

This patch adds the host_eval clause to the omp.target operation. It also updates the op's verifier to make sure that every use of a block argument defined by this clause falls within one of the few cases where such uses are allowed.

MLIR to LLVM IR translation of this clause is not yet supported and fails with a not-yet-implemented error.
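For illustration, here is a minimal sketch of the resulting syntax, pieced together from the docs and tests in this patch; the value names %nt and %arg0 are placeholders, and the snippet omits the enclosing function:

```mlir
// A host-defined value %nt is forwarded into the isolated target region via
// the host_eval clause and becomes the entry block argument %arg0.
omp.target host_eval(%nt -> %arg0 : i32) {
  // %arg0 may only be used in a few places, e.g. the 'num_teams' and
  // 'thread_limit' clauses of a nested omp.teams; other uses are rejected by
  // the new verifier.
  omp.teams num_teams(to %arg0 : i32) thread_limit(%arg0 : i32) {
    omp.terminator
  }
  omp.terminator
}
```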

@llvmbot (Member) commented Nov 13, 2024

@llvm/pr-subscribers-flang-openmp
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Sergio Afonso (skatrak)

Changes

This patch adds the host_eval clause to the omp.target operation. It also updates the op's verifier to make sure that every use of a block argument defined by this clause falls within one of the few cases where such uses are allowed.

MLIR to LLVM IR translation of this clause is not yet supported and fails with a not-yet-implemented error.


Patch is 20.92 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116049.diff

7 Files Affected:

  • (modified) mlir/docs/Dialects/OpenMPDialect/_index.md (+55)
  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+26-7)
  • (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+163-4)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+5)
  • (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+69-1)
  • (modified) mlir/test/Dialect/OpenMP/ops.mlir (+37-1)
  • (modified) mlir/test/Target/LLVMIR/openmp-todo.mlir (+14)
diff --git a/mlir/docs/Dialects/OpenMPDialect/_index.md b/mlir/docs/Dialects/OpenMPDialect/_index.md
index 4e5d777d6c4f7f..e0dd3f598e84b6 100644
--- a/mlir/docs/Dialects/OpenMPDialect/_index.md
+++ b/mlir/docs/Dialects/OpenMPDialect/_index.md
@@ -523,3 +523,58 @@ omp.parallel ... {
   omp.terminator
 } {omp.composite}
 ```
+
+## Host-Evaluated Clauses in Target Regions
+
+The `omp.target` operation, which represents the OpenMP `target` construct, is
+marked with the `IsolatedFromAbove` trait. This means that, inside of its
+region, no MLIR values defined outside of the op itself can be used. This is
+consistent with the OpenMP specification of the `target` construct, which
+mandates that all host device values used inside of the `target` region must
+either be privatized (data-sharing) or mapped (data-mapping).
+
+Normally, clauses applied to a construct are evaluated before entering that
+construct. Further, in some cases, the OpenMP specification stipulates that
+clauses be evaluated _on the host device_ on entry to a parent `target`
+construct. In particular, the `num_teams` and `thread_limit` clauses of the
+`teams` construct must be evaluated on the host device if it's nested inside or
+combined with a `target` construct.
+
+Additionally, the runtime library targeted by the MLIR to LLVM IR translation of
+the OpenMP dialect supports the optimized launch of SPMD kernels (i.e.
+`target teams distribute parallel {do,for}` in OpenMP), which requires
+specifying in advance what the total trip count of the loop is. Consequently, it
+is also beneficial to evaluate the trip count on the host device prior to the
+kernel launch.
+
+These host-evaluated values in MLIR would need to be placed outside of the
+`omp.target` region and also attached to the corresponding nested operations,
+which is not possible because of the `IsolatedFromAbove` trait. The solution
+implemented to address this problem has been to introduce the `host_eval`
+argument to the `omp.target` operation. It works similarly to a `map` clause,
+but its only intended use is to forward host-evaluated values to their
+corresponding operation inside of the region. Any uses other than those
+previously described result in a verifier error.
+
+```mlir
+// Initialize %0, %1, %2, %3...
+omp.target host_eval(%0 -> %nt, %1 -> %lb, %2 -> %ub, %3 -> %step : i32, i32, i32, i32) {
+  omp.teams num_teams(to %nt : i32) {
+    omp.parallel {
+      omp.distribute {
+        omp.wsloop {
+          omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
+            // ...
+            omp.yield
+          }
+          omp.terminator
+        } {omp.composite}
+        omp.terminator
+      } {omp.composite}
+      omp.terminator
+    } {omp.composite}
+    omp.terminator
+  }
+  omp.terminator
+}
+```
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index a0da3db124d1f4..a99da1f0294d08 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1166,9 +1166,10 @@ def TargetOp : OpenMP_Op<"target", traits = [
   ], clauses = [
     // TODO: Complete clause list (defaultmap, uses_allocators).
     OpenMP_AllocateClause, OpenMP_DependClause, OpenMP_DeviceClause,
-    OpenMP_HasDeviceAddrClause, OpenMP_IfClause, OpenMP_InReductionClause,
-    OpenMP_IsDevicePtrClause, OpenMP_MapClauseSkip<assemblyFormat = true>,
-    OpenMP_NowaitClause, OpenMP_PrivateClause, OpenMP_ThreadLimitClause
+    OpenMP_HasDeviceAddrClause, OpenMP_HostEvalClause, OpenMP_IfClause,
+    OpenMP_InReductionClause, OpenMP_IsDevicePtrClause,
+    OpenMP_MapClauseSkip<assemblyFormat = true>, OpenMP_NowaitClause,
+    OpenMP_PrivateClause, OpenMP_ThreadLimitClause
   ], singleRegion = true> {
   let summary = "target construct";
   let description = [{
@@ -1186,16 +1187,34 @@ def TargetOp : OpenMP_Op<"target", traits = [
 
   let extraClassDeclaration = [{
     unsigned numMapBlockArgs() { return getMapVars().size(); }
+
+    /// Returns the innermost OpenMP dialect operation captured by this target
+    /// construct. For an operation to be detected as captured, it must be
+    /// inside a (possibly multi-level) nest of OpenMP dialect operations'
+    /// regions where none of these levels contain other operations considered
+    /// not-allowed for these purposes (i.e. only terminator operations are
+    /// allowed from the OpenMP dialect, and other dialects' operations are
+    /// allowed as long as they don't have a memory write effect).
+    ///
+    /// If there are omp.loop_nest operations in the sequence of nested
+    /// operations, the top level one will be the one captured.
+    Operation *getInnermostCapturedOmpOp();
+
+    /// Checks whether this target region represents the MLIR equivalent to a
+    /// 'target teams distribute parallel {do, for} [simd]' OpenMP construct.
+    bool isTargetSPMDLoop();
   }] # clausesExtraClassDeclaration;
 
   let assemblyFormat = clausesAssemblyFormat # [{
-    custom<InReductionMapPrivateRegion>(
-        $region, $in_reduction_vars, type($in_reduction_vars),
-        $in_reduction_byref, $in_reduction_syms, $map_vars, type($map_vars),
-        $private_vars, type($private_vars), $private_syms) attr-dict
+    custom<HostEvalInReductionMapPrivateRegion>(
+        $region, $host_eval_vars, type($host_eval_vars), $in_reduction_vars,
+        type($in_reduction_vars), $in_reduction_byref, $in_reduction_syms,
+        $map_vars, type($map_vars), $private_vars, type($private_vars),
+        $private_syms) attr-dict
   }];
 
   let hasVerifier = 1;
+  let hasRegionVerifier = 1;
 }
 
 
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index b3575b1ca4018e..e42ed9dc104981 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -672,8 +672,10 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
   return parser.parseRegion(region, entryBlockArgs);
 }
 
-static ParseResult parseInReductionMapPrivateRegion(
+static ParseResult parseHostEvalInReductionMapPrivateRegion(
     OpAsmParser &parser, Region &region,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &hostEvalVars,
+    SmallVectorImpl<Type> &hostEvalTypes,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
     SmallVectorImpl<Type> &inReductionTypes,
     DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
@@ -682,6 +684,7 @@ static ParseResult parseInReductionMapPrivateRegion(
     llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
     llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
   AllRegionParseArgs args;
+  args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
   args.mapArgs.emplace(mapVars, mapTypes);
@@ -896,12 +899,14 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
   p.printRegion(region, /*printEntryBlockArgs=*/false);
 }
 
-static void printInReductionMapPrivateRegion(
-    OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
+static void printHostEvalInReductionMapPrivateRegion(
+    OpAsmPrinter &p, Operation *op, Region &region, ValueRange hostEvalVars,
+    TypeRange hostEvalTypes, ValueRange inReductionVars,
     TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
     ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
     ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms) {
   AllRegionPrintArgs args;
+  args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
   args.mapArgs.emplace(mapVars, mapTypes);
@@ -1685,7 +1690,8 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
   // inReductionByref, inReductionSyms.
   TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
                   makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
-                  clauses.device, clauses.hasDeviceAddrVars, clauses.ifExpr,
+                  clauses.device, clauses.hasDeviceAddrVars,
+                  clauses.hostEvalVars, clauses.ifExpr,
                   /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
                   /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
                   clauses.mapVars, clauses.nowait, clauses.privateVars,
@@ -1699,6 +1705,159 @@ LogicalResult TargetOp::verify() {
                                   : verifyMapClause(*this, getMapVars());
 }
 
+LogicalResult TargetOp::verifyRegions() {
+  auto teamsOps = getOps<TeamsOp>();
+  if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
+    return emitError("target containing multiple 'omp.teams' nested ops");
+
+  // Check that host_eval values are only used in legal ways.
+  bool isTargetSPMD = isTargetSPMDLoop();
+  for (Value hostEvalArg :
+       cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
+    for (Operation *user : hostEvalArg.getUsers()) {
+      if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
+        if (llvm::is_contained({teamsOp.getNumTeamsLower(),
+                                teamsOp.getNumTeamsUpper(),
+                                teamsOp.getThreadLimit()},
+                               hostEvalArg))
+          continue;
+
+        return emitOpError() << "host_eval argument only legal as 'num_teams' "
+                                "and 'thread_limit' in 'omp.teams'";
+      }
+      if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
+        if (isTargetSPMD && hostEvalArg == parallelOp.getNumThreads())
+          continue;
+
+        return emitOpError()
+               << "host_eval argument only legal as 'num_threads' in "
+                  "'omp.parallel' when representing target SPMD";
+      }
+      if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
+        if (isTargetSPMD &&
+            (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
+             llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
+             llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
+          continue;
+
+        return emitOpError()
+               << "host_eval argument only legal as loop bounds and steps in "
+                  "'omp.loop_nest' when representing target SPMD";
+      }
+
+      return emitOpError() << "host_eval argument illegal use in '"
+                           << user->getName() << "' operation";
+    }
+  }
+  return success();
+}
+
+/// Only allow OpenMP terminators and non-OpenMP ops that have known memory
+/// effects, but don't include a memory write effect.
+static bool siblingAllowedInCapture(Operation *op) {
+  if (!op)
+    return false;
+
+  bool isOmpDialect =
+      op->getContext()->getLoadedDialect<omp::OpenMPDialect>() ==
+      op->getDialect();
+
+  if (isOmpDialect)
+    return op->hasTrait<OpTrait::IsTerminator>();
+
+  if (auto memOp = dyn_cast<MemoryEffectOpInterface>(op)) {
+    SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4> effects;
+    memOp.getEffects(effects);
+    return !llvm::any_of(effects, [&](MemoryEffects::EffectInstance &effect) {
+      return isa<MemoryEffects::Write>(effect.getEffect()) &&
+             isa<SideEffects::AutomaticAllocationScopeResource>(
+                 effect.getResource());
+    });
+  }
+  return true;
+}
+
+Operation *TargetOp::getInnermostCapturedOmpOp() {
+  Dialect *ompDialect = (*this)->getDialect();
+  Operation *capturedOp = nullptr;
+
+  // Process in pre-order to check operations from outermost to innermost,
+  // ensuring we only enter the region of an operation if it meets the criteria
+  // for being captured. We stop the exploration of nested operations as soon as
+  // we process a region holding no operations to be captured.
+  walk<WalkOrder::PreOrder>([&](Operation *op) {
+    if (op == *this)
+      return WalkResult::advance();
+
+    // Ignore operations of other dialects or omp operations with no regions,
+    // because these will only be checked if they are siblings of an omp
+    // operation that can potentially be captured.
+    bool isOmpDialect = op->getDialect() == ompDialect;
+    bool hasRegions = op->getNumRegions() > 0;
+    if (!isOmpDialect || !hasRegions)
+      return WalkResult::skip();
+
+    // Don't capture this op if it has a not-allowed sibling, and stop recursing
+    // into nested operations.
+    for (Operation &sibling : op->getParentRegion()->getOps())
+      if (&sibling != op && !siblingAllowedInCapture(&sibling))
+        return WalkResult::interrupt();
+
+    // Don't continue capturing nested operations if we reach an omp.loop_nest.
+    // Otherwise, process the contents of this operation.
+    capturedOp = op;
+    return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
+                                     : WalkResult::advance();
+  });
+
+  return capturedOp;
+}
+
+bool TargetOp::isTargetSPMDLoop() {
+  // The expected MLIR representation for a target SPMD loop is:
+  // omp.target {
+  //   omp.teams {
+  //     omp.parallel {
+  //       omp.distribute {
+  //         omp.wsloop {
+  //           omp.loop_nest ... { ... }
+  //         } {omp.composite}
+  //       } {omp.composite}
+  //       omp.terminator
+  //     } {omp.composite}
+  //     omp.terminator
+  //   }
+  //   omp.terminator
+  // }
+
+  Operation *capturedOp = getInnermostCapturedOmpOp();
+  if (!isa_and_present<LoopNestOp>(capturedOp))
+    return false;
+
+  Operation *workshareOp = capturedOp->getParentOp();
+
+  // Accept an optional omp.simd loop wrapper as part of the SPMD pattern.
+  if (isa_and_present<SimdOp>(workshareOp))
+    workshareOp = workshareOp->getParentOp();
+
+  if (!isa_and_present<WsloopOp>(workshareOp))
+    return false;
+
+  Operation *distributeOp = workshareOp->getParentOp();
+  if (!isa_and_present<DistributeOp>(distributeOp))
+    return false;
+
+  Operation *parallelOp = distributeOp->getParentOp();
+  if (!isa_and_present<ParallelOp>(parallelOp))
+    return false;
+
+  Operation *teamsOp = parallelOp->getParentOp();
+  if (!isa_and_present<TeamsOp>(teamsOp))
+    return false;
+
+  return teamsOp->getParentOp() == (*this);
+}
+
 //===----------------------------------------------------------------------===//
 // ParallelOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index da11ee9960e1f9..cbcbeea4ab9225 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -174,6 +174,10 @@ static LogicalResult checkImplementationStatus(Operation &op) {
     if (op.getHint())
       op.emitWarning("hint clause discarded");
   };
+  auto checkHostEval = [&todo](auto op, LogicalResult &result) {
+    if (!op.getHostEvalVars().empty())
+      result = todo("host_eval");
+  };
   auto checkIf = [&todo](auto op, LogicalResult &result) {
     if (op.getIfExpr())
       result = todo("if");
@@ -291,6 +295,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
         checkAllocate(op, result);
         checkDevice(op, result);
         checkHasDeviceAddr(op, result);
+        checkHostEval(op, result);
         checkIf(op, result);
         checkInReduction(op, result);
         checkIsDevicePtr(op, result);
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index aa41eea44f3ef4..216b3d6d806906 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2138,11 +2138,79 @@ func.func @omp_target_update_data_depend(%a: memref<?xi32>) {
 
 // -----
 
+func.func @omp_target_multiple_teams() {
+  // expected-error @below {{target containing multiple 'omp.teams' nested ops}}
+  omp.target {
+    omp.teams {
+      omp.terminator
+    }
+    omp.teams {
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @omp_target_host_eval(%x : !llvm.ptr) {
+  // expected-error @below {{op host_eval argument illegal use in 'llvm.load' operation}}
+  omp.target host_eval(%x -> %arg0 : !llvm.ptr) {
+    %0 = llvm.load %arg0 : !llvm.ptr -> f32
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @omp_target_host_eval_teams(%x : i1) {
+  // expected-error @below {{op host_eval argument only legal as 'num_teams' and 'thread_limit' in 'omp.teams'}}
+  omp.target host_eval(%x -> %arg0 : i1) {
+    omp.teams if(%arg0) {
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @omp_target_host_eval_parallel(%x : i32) {
+  // expected-error @below {{op host_eval argument only legal as 'num_threads' in 'omp.parallel' when representing target SPMD}}
+  omp.target host_eval(%x -> %arg0 : i32) {
+    omp.parallel num_threads(%arg0 : i32) {
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @omp_target_host_eval_loop(%x : i32) {
+  // expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when representing target SPMD}}
+  omp.target host_eval(%x -> %arg0 : i32) {
+    omp.wsloop {
+      omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
+        omp.yield
+      }
+    }
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
 func.func @omp_target_depend(%data_var: memref<i32>) {
   // expected-error @below {{op expected as many depend values as depend variables}}
     "omp.target"(%data_var) ({
       "omp.terminator"() : () -> ()
-    }) {depend_kinds = [], operandSegmentSizes = array<i32: 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>} : (memref<i32>) -> ()
+    }) {depend_kinds = [], operandSegmentSizes = array<i32: 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>} : (memref<i32>) -> ()
    "func.return"() : () -> ()
 }
 
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 4f5cc696cada81..5a1b184fd3b94a 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -770,7 +770,7 @@ func.func @omp_target(%if_cond : i1, %device : si32,  %num_threads : i32, %devic
     "omp.target"(%device, %if_cond, %num_threads) ({
        // CHECK: omp.terminator
        omp.terminator
-    }) {nowait, operandSegmentSizes = array<i32: 0,0,0,1,0,1,0,0,0,0,1>} : ( si32, i1, i32 ) -> ()
+    }) {nowait, operandSegmentSizes = array<i32: 0,0,0,1,0,0,1,0,0,0,0,1>} : ( si32, i1, i32 ) -> ()
 
     // Test with optional map clause.
     // CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>)   map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
@@ -2750,6 +2750,42 @@ func.func @omp_target_private(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_
   return
 }
 
+func.func @omp_target_host_eval(%x : i32) {
+  // CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) {
+  // CHECK: omp.teams num_teams( to %[[HOST_ARG]] : i32)
+  // CHECK-SAME: thread_limit(%[[HOST_ARG]] : i32)
+  omp.target host_eval(%x -> %arg0 : i32) {
+    omp.teams num_teams(to %arg0 : i32) thread_limit(%arg0 : i32) {
+      omp.terminator
+    }
+    omp.terminator
+  }
+
+  // CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) {
+  // CHECK: omp.teams
+  // CHECK: omp.parallel num_threads(%[[HOST_ARG]] : i32) {
+  // CHECK: omp.distribute {
+  // CHECK: omp.wsloop {
+  // CHECK: omp.loop_nest (%{{.*}}) : i32 = (%[[HOST_ARG]]) to (%[[HOST_ARG]]) step (%[[HOST_ARG]]) {
+  omp.target host_eval(%x -> %arg0 : i32) {
+    omp.teams {
+      omp.parallel num_threads(%arg0 : i32) {
+        omp.distribute {
+          omp.wsloop {
+            omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
+              omp.yield
+            }
+          } {omp.composite}
+        } {omp.composite}
+        omp.terminator
+      } {omp.composite}
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
 // CHECK-LABEL: omp_loop
 func.func @omp_loop(%lb : index, %ub : index, %step : index) {
   // CHECK: o...
[truncated]
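Because translation of the new clause is intentionally unimplemented, lowering a target op that carries it should be rejected by the checkHostEval handler added above. A hedged sketch of what such an input could look like (this is not the actual content of the truncated openmp-todo.mlir hunk; the function name and structure are invented for the example):

```mlir
// Translating this module (e.g. through mlir-translate -mlir-to-llvmir) is
// expected to fail with a not-yet-implemented error for host_eval.
llvm.func @host_eval_todo(%x : i32) {
  omp.target host_eval(%x -> %arg0 : i32) {
    omp.teams num_teams(to %arg0 : i32) {
      omp.terminator
    }
    omp.terminator
  }
  llvm.return
}
```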

@llvmbot (Member) commented Nov 13, 2024

@llvm/pr-subscribers-mlir-openmp


@mjklemm (Contributor) left a comment

LGTM

@skatrak force-pushed the users/skatrak/host-eval-01-clause branch from 0a47845 to d74cf35 on December 4, 2024 12:23
@skatrak force-pushed the users/skatrak/host-eval-02-mlir branch from 26fbb25 to 27ffa9f on December 4, 2024 13:16
@skatrak force-pushed the users/skatrak/host-eval-01-clause branch from d74cf35 to 5efde4c on January 8, 2025 14:14
@skatrak force-pushed the users/skatrak/host-eval-02-mlir branch from 27ffa9f to bd7fa37 on January 8, 2025 14:47
Base automatically changed from users/skatrak/host-eval-01-clause to main on January 14, 2025 10:19
This patch adds the `host_eval` clause to the `omp.target` operation.
Additionally, it updates its op verifier to make sure all uses of block
arguments defined by this clause fall within one of the few cases where they
are allowed.

MLIR to LLVM IR translation fails on translation of this clause with a
not-yet-implemented error.
@skatrak force-pushed the users/skatrak/host-eval-02-mlir branch from bd7fa37 to 5f57b94 on January 14, 2025 10:20
@skatrak merged commit 9d7d8d2 into main on Jan 14, 2025
5 of 7 checks passed
@skatrak deleted the users/skatrak/host-eval-02-mlir branch on January 14, 2025 10:21