Commit bd7fa37

[MLIR][OpenMP] Add host_eval clause to omp.target
This patch adds the `host_eval` clause to the `omp.target` operation. It also updates the op verifier to check that every use of a block argument defined by this clause falls within one of the few cases where such uses are allowed. MLIR to LLVM IR translation of this clause is not yet supported and fails with a not-yet-implemented error.
1 parent 5efde4c commit bd7fa37

8 files changed: 446 additions & 19 deletions

mlir/docs/Dialects/OpenMPDialect/_index.md

Lines changed: 57 additions & 1 deletion
@@ -298,7 +298,8 @@ introduction of private copies of the same underlying variable defined outside
298298
the MLIR operation the clause is attached to. Currently, clauses with this
299299
property can be classified into three main categories:
300300
- Map-like clauses: `host_eval` (compiler internal, not defined by the OpenMP
301-
specification), `map`, `use_device_addr` and `use_device_ptr`.
301+
specification: [see more](#host-evaluated-clauses-in-target-regions)), `map`,
302+
`use_device_addr` and `use_device_ptr`.
302303
- Reduction-like clauses: `in_reduction`, `reduction` and `task_reduction`.
303304
- Privatization clauses: `private`.
304305

@@ -523,3 +524,58 @@ omp.parallel ... {
523524
omp.terminator
524525
} {omp.composite}
525526
```
527+
528+
## Host-Evaluated Clauses in Target Regions
529+
530+
The `omp.target` operation, which represents the OpenMP `target` construct, is
531+
marked with the `IsolatedFromAbove` trait. This means that, inside of its
532+
region, no MLIR values defined outside of the op itself can be used. This is
533+
consistent with the OpenMP specification of the `target` construct, which
534+
mandates that all host device values used inside of the `target` region must
535+
either be privatized (data-sharing) or mapped (data-mapping).
536+
537+
Normally, clauses applied to a construct are evaluated before entering that
538+
construct. Further, in some cases, the OpenMP specification stipulates that
539+
clauses be evaluated _on the host device_ on entry to a parent `target`
540+
construct. In particular, the `num_teams` and `thread_limit` clauses of the
541+
`teams` construct must be evaluated on the host device if it's nested inside or
542+
combined with a `target` construct.
543+
544+
Additionally, the runtime library targeted by the MLIR to LLVM IR translation of
545+
the OpenMP dialect supports the optimized launch of SPMD kernels (i.e.
546+
`target teams distribute parallel {do,for}` in OpenMP), which requires
547+
specifying in advance what the total trip count of the loop is. Consequently, it
548+
is also beneficial to evaluate the trip count on the host device prior to the
549+
kernel launch.
550+
551+
These host-evaluated values in MLIR would need to be placed outside of the
552+
`omp.target` region and also attached to the corresponding nested operations,
553+
which is not possible because of the `IsolatedFromAbove` trait. The solution
554+
implemented to address this problem has been to introduce the `host_eval`
555+
argument to the `omp.target` operation. It works similarly to a `map` clause,
556+
but its only intended use is to forward host-evaluated values to their
557+
corresponding operation inside of the region. Any use other than the ones
558+
described above results in a verifier error.
559+
560+
```mlir
561+
// Initialize %0, %1, %2, %3...
562+
omp.target host_eval(%0 -> %nt, %1 -> %lb, %2 -> %ub, %3 -> %step : i32, i32, i32, i32) {
563+
omp.teams num_teams(to %nt : i32) {
564+
omp.parallel {
565+
omp.distribute {
566+
omp.wsloop {
567+
omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
568+
// ...
569+
omp.yield
570+
}
571+
omp.terminator
572+
} {omp.composite}
573+
omp.terminator
574+
} {omp.composite}
575+
omp.terminator
576+
} {omp.composite}
577+
omp.terminator
578+
}
579+
omp.terminator
580+
}
581+
```
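
The restriction described above (a `host_eval` value may only be forwarded to a handful of operands) can be illustrated with a small standalone model. This is only a sketch based on the verifier logic added in this commit, not MLIR code: `ExecMode`, `UseSite` and `isLegalHostEvalUse` are hypothetical names introduced for illustration, and the real check operates on MLIR operations and values.

```cpp
#include <cassert>

// Hypothetical, simplified model of the `host_eval` legality rules.
// None of these names exist in MLIR; they only mirror the cases the
// verifier distinguishes.
enum class ExecMode { Generic, SPMD, GenericSPMD };

enum class UseSite {
  TeamsNumTeamsLower,  // lower bound of num_teams on omp.teams
  TeamsNumTeamsUpper,  // upper bound of num_teams on omp.teams
  TeamsThreadLimit,    // thread_limit on omp.teams
  ParallelNumThreads,  // num_threads on omp.parallel
  LoopNestLowerBound,  // lower bound of omp.loop_nest
  LoopNestUpperBound,  // upper bound of omp.loop_nest
  LoopNestStep,        // step of omp.loop_nest
  Other                // any other use
};

// Returns true if a host_eval block argument may be used at `site`
// inside a target region whose inferred kernel type is `mode`.
bool isLegalHostEvalUse(UseSite site, ExecMode mode) {
  switch (site) {
  case UseSite::TeamsNumTeamsLower:
  case UseSite::TeamsNumTeamsUpper:
  case UseSite::TeamsThreadLimit:
    return true; // allowed regardless of kernel type
  case UseSite::ParallelNumThreads:
    return mode == ExecMode::SPMD; // only for target SPMD
  case UseSite::LoopNestLowerBound:
  case UseSite::LoopNestUpperBound:
  case UseSite::LoopNestStep:
    return mode != ExecMode::Generic; // SPMD or Generic-SPMD
  case UseSite::Other:
    return false;
  }
  return false;
}
```

Under this model, the loop bounds in the MLIR example above are legal because the `target teams distribute parallel wsloop` nest is classified as SPMD. The same bounds in a Generic kernel would be rejected.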

mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@
2222
#include "mlir/IR/SymbolTable.h"
2323
#include "mlir/Interfaces/ControlFlowInterfaces.h"
2424
#include "mlir/Interfaces/SideEffectInterfaces.h"
25+
#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
2526

2627
#define GET_TYPEDEF_CLASSES
2728
#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.h.inc"

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 25 additions & 8 deletions
@@ -1224,10 +1224,10 @@ def TargetOp : OpenMP_Op<"target", traits = [
12241224
], clauses = [
12251225
// TODO: Complete clause list (defaultmap, uses_allocators).
12261226
OpenMP_AllocateClause, OpenMP_BareClause, OpenMP_DependClause,
1227-
OpenMP_DeviceClause, OpenMP_HasDeviceAddrClause, OpenMP_IfClause,
1228-
OpenMP_InReductionClause, OpenMP_IsDevicePtrClause,
1227+
OpenMP_DeviceClause, OpenMP_HasDeviceAddrClause, OpenMP_HostEvalClause,
1228+
OpenMP_IfClause, OpenMP_InReductionClause, OpenMP_IsDevicePtrClause,
12291229
OpenMP_MapClauseSkip<assemblyFormat = true>, OpenMP_NowaitClause,
1230-
OpenMP_PrivateClause, OpenMP_ThreadLimitClause,
1230+
OpenMP_PrivateClause, OpenMP_ThreadLimitClause
12311231
], singleRegion = true> {
12321232
let summary = "target construct";
12331233
let description = [{
@@ -1269,17 +1269,34 @@ def TargetOp : OpenMP_Op<"target", traits = [
12691269

12701270
return getMapVars()[mapInfoOpIdx];
12711271
}
1272+
1273+
/// Returns the innermost OpenMP dialect operation captured by this target
1274+
/// construct. For an operation to be detected as captured, it must be
1275+
/// inside a (possibly multi-level) nest of OpenMP dialect operations'
1276+
/// regions where none of these levels contain other operations considered
1277+
/// not-allowed for these purposes (i.e. only terminator operations are
1278+
/// allowed from the OpenMP dialect, and other dialects' operations are
1279+
/// allowed as long as they don't have a memory write effect).
1280+
///
1281+
/// If there are omp.loop_nest operations in the sequence of nested
1282+
/// operations, the top level one will be the one captured.
1283+
Operation *getInnermostCapturedOmpOp();
1284+
1285+
/// Infers the kernel type (Generic, SPMD or Generic-SPMD) based on the
1286+
/// contents of the target region.
1287+
llvm::omp::OMPTgtExecModeFlags getKernelExecFlags();
12721288
}] # clausesExtraClassDeclaration;
12731289

12741290
let assemblyFormat = clausesAssemblyFormat # [{
1275-
custom<InReductionMapPrivateRegion>(
1276-
$region, $in_reduction_vars, type($in_reduction_vars),
1277-
$in_reduction_byref, $in_reduction_syms, $map_vars, type($map_vars),
1278-
$private_vars, type($private_vars), $private_syms, $private_maps)
1279-
attr-dict
1291+
custom<HostEvalInReductionMapPrivateRegion>(
1292+
$region, $host_eval_vars, type($host_eval_vars), $in_reduction_vars,
1293+
type($in_reduction_vars), $in_reduction_byref, $in_reduction_syms,
1294+
$map_vars, type($map_vars), $private_vars, type($private_vars),
1295+
$private_syms, $private_maps) attr-dict
12801296
}];
12811297

12821298
let hasVerifier = 1;
1299+
let hasRegionVerifier = 1;
12831300
}
12841301

12851302
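
The doc comment on `getKernelExecFlags` above only states that the kernel type is inferred from the contents of the target region. As a rough illustration of the classification implemented in `OpenMPDialect.cpp` in this commit, the following standalone sketch models the decision on plain vectors. `OpKind`, `ExecMode` and `classifyKernel`, along with the two input lists, are assumptions of this model and not part of the MLIR API. The real code works on the captured `omp.loop_nest`, its loop wrappers and their parent operations.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for the relevant OpenMP dialect operations.
enum class OpKind { Simd, Wsloop, Distribute, Parallel, Teams };
enum class ExecMode { Generic, SPMD, GenericSPMD };

// `wrappers`:  loop wrappers around the captured loop nest, innermost first.
// `enclosing`: operations between the outermost wrapper and the target op,
//              innermost first.
// An empty `wrappers` list models "no loop nest captured".
ExecMode classifyKernel(const std::vector<OpKind> &wrappers,
                        const std::vector<OpKind> &enclosing) {
  if (wrappers.empty())
    return ExecMode::Generic;

  // Ignore an optional SIMD leaf construct.
  std::size_t first = wrappers.front() == OpKind::Simd ? 1 : 0;
  std::size_t numWrappers = wrappers.size() - first;

  // Generic-SPMD: target-teams-distribute[-simd].
  if (numWrappers == 1 && wrappers[first] == OpKind::Distribute &&
      enclosing == std::vector<OpKind>{OpKind::Teams})
    return ExecMode::GenericSPMD;

  // SPMD: target-teams-distribute-parallel-wsloop[-simd].
  if (numWrappers == 2 && wrappers[first] == OpKind::Wsloop &&
      wrappers[first + 1] == OpKind::Distribute &&
      enclosing == std::vector<OpKind>{OpKind::Parallel, OpKind::Teams})
    return ExecMode::SPMD;

  // Everything else falls back to a generic kernel.
  return ExecMode::Generic;
}
```

For example, `classifyKernel({OpKind::Wsloop, OpKind::Distribute}, {OpKind::Parallel, OpKind::Teams})` yields `ExecMode::SPMD` in this model. Any extra operation between `teams` and the target makes the result `ExecMode::Generic`.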

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 198 additions & 8 deletions
@@ -31,6 +31,7 @@
3131
#include "llvm/ADT/StringRef.h"
3232
#include "llvm/ADT/TypeSwitch.h"
3333
#include "llvm/Frontend/OpenMP/OMPConstants.h"
34+
#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
3435
#include <cstddef>
3536
#include <iterator>
3637
#include <optional>
@@ -691,8 +692,10 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
691692
return parser.parseRegion(region, entryBlockArgs);
692693
}
693694

694-
static ParseResult parseInReductionMapPrivateRegion(
695+
static ParseResult parseHostEvalInReductionMapPrivateRegion(
695696
OpAsmParser &parser, Region &region,
697+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &hostEvalVars,
698+
SmallVectorImpl<Type> &hostEvalTypes,
696699
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
697700
SmallVectorImpl<Type> &inReductionTypes,
698701
DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
@@ -702,6 +705,7 @@ static ParseResult parseInReductionMapPrivateRegion(
702705
llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
703706
DenseI64ArrayAttr &privateMaps) {
704707
AllRegionParseArgs args;
708+
args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
705709
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
706710
inReductionByref, inReductionSyms);
707711
args.mapArgs.emplace(mapVars, mapTypes);
@@ -931,13 +935,15 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
931935
p.printRegion(region, /*printEntryBlockArgs=*/false);
932936
}
933937

934-
static void printInReductionMapPrivateRegion(
935-
OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
938+
static void printHostEvalInReductionMapPrivateRegion(
939+
OpAsmPrinter &p, Operation *op, Region &region, ValueRange hostEvalVars,
940+
TypeRange hostEvalTypes, ValueRange inReductionVars,
936941
TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
937942
ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
938943
ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms,
939944
DenseI64ArrayAttr privateMaps) {
940945
AllRegionPrintArgs args;
946+
args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
941947
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
942948
inReductionByref, inReductionSyms);
943949
args.mapArgs.emplace(mapVars, mapTypes);
@@ -1720,11 +1726,12 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
17201726
TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
17211727
clauses.bare, makeArrayAttr(ctx, clauses.dependKinds),
17221728
clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,
1723-
clauses.ifExpr, /*in_reduction_vars=*/{},
1724-
/*in_reduction_byref=*/nullptr, /*in_reduction_syms=*/nullptr,
1725-
clauses.isDevicePtrVars, clauses.mapVars, clauses.nowait,
1726-
clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
1727-
clauses.threadLimit, /*private_maps=*/nullptr);
1729+
clauses.hostEvalVars, clauses.ifExpr,
1730+
/*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
1731+
/*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
1732+
clauses.mapVars, clauses.nowait, clauses.privateVars,
1733+
makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit,
1734+
/*private_maps=*/nullptr);
17281735
}
17291736

17301737
LogicalResult TargetOp::verify() {
@@ -1742,6 +1749,189 @@ LogicalResult TargetOp::verify() {
17421749
return verifyPrivateVarsMapping(*this);
17431750
}
17441751

1752+
LogicalResult TargetOp::verifyRegions() {
1753+
auto teamsOps = getOps<TeamsOp>();
1754+
if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
1755+
return emitError("target containing multiple 'omp.teams' nested ops");
1756+
1757+
// Check that host_eval values are only used in legal ways.
1758+
llvm::omp::OMPTgtExecModeFlags execFlags = getKernelExecFlags();
1759+
for (Value hostEvalArg :
1760+
cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
1761+
for (Operation *user : hostEvalArg.getUsers()) {
1762+
if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
1763+
if (llvm::is_contained({teamsOp.getNumTeamsLower(),
1764+
teamsOp.getNumTeamsUpper(),
1765+
teamsOp.getThreadLimit()},
1766+
hostEvalArg))
1767+
continue;
1768+
1769+
return emitOpError() << "host_eval argument only legal as 'num_teams' "
1770+
"and 'thread_limit' in 'omp.teams'";
1771+
}
1772+
if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
1773+
if (execFlags == llvm::omp::OMP_TGT_EXEC_MODE_SPMD &&
1774+
hostEvalArg == parallelOp.getNumThreads())
1775+
continue;
1776+
1777+
return emitOpError()
1778+
<< "host_eval argument only legal as 'num_threads' in "
1779+
"'omp.parallel' when representing target SPMD";
1780+
}
1781+
if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
1782+
if (execFlags != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC &&
1783+
(llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
1784+
llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
1785+
llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
1786+
continue;
1787+
1788+
return emitOpError() << "host_eval argument only legal as loop bounds "
1789+
"and steps in 'omp.loop_nest' when "
1790+
"representing target SPMD or Generic-SPMD";
1791+
}
1792+
1793+
return emitOpError() << "host_eval argument illegal use in '"
1794+
<< user->getName() << "' operation";
1795+
}
1796+
}
1797+
return success();
1798+
}
1799+
1800+
/// Only allow OpenMP terminators and non-OpenMP ops that have known memory
1801+
/// effects, but don't include a memory write effect.
1802+
static bool siblingAllowedInCapture(Operation *op) {
1803+
if (!op)
1804+
return false;
1805+
1806+
bool isOmpDialect =
1807+
op->getContext()->getLoadedDialect<omp::OpenMPDialect>() ==
1808+
op->getDialect();
1809+
1810+
if (isOmpDialect)
1811+
return op->hasTrait<OpTrait::IsTerminator>();
1812+
1813+
if (auto memOp = dyn_cast<MemoryEffectOpInterface>(op)) {
1814+
SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4> effects;
1815+
memOp.getEffects(effects);
1816+
return !llvm::any_of(effects, [&](MemoryEffects::EffectInstance &effect) {
1817+
return isa<MemoryEffects::Write>(effect.getEffect()) &&
1818+
isa<SideEffects::AutomaticAllocationScopeResource>(
1819+
effect.getResource());
1820+
});
1821+
}
1822+
return true;
1823+
}
1824+
1825+
Operation *TargetOp::getInnermostCapturedOmpOp() {
1826+
Dialect *ompDialect = (*this)->getDialect();
1827+
Operation *capturedOp = nullptr;
1828+
DominanceInfo domInfo;
1829+
1830+
// Process in pre-order to check operations from outermost to innermost,
1831+
// ensuring we only enter the region of an operation if it meets the criteria
1832+
// for being captured. We stop the exploration of nested operations as soon as
1833+
// we process a region holding no operations to be captured.
1834+
walk<WalkOrder::PreOrder>([&](Operation *op) {
1835+
if (op == *this)
1836+
return WalkResult::advance();
1837+
1838+
// Ignore operations of other dialects or omp operations with no regions,
1839+
// because these will only be checked if they are siblings of an omp
1840+
// operation that can potentially be captured.
1841+
bool isOmpDialect = op->getDialect() == ompDialect;
1842+
bool hasRegions = op->getNumRegions() > 0;
1843+
if (!isOmpDialect || !hasRegions)
1844+
return WalkResult::skip();
1845+
1846+
// This operation cannot be captured if it can be executed more than once
1847+
// (i.e. its block's successors can reach it) or if it's not guaranteed to
1848+
// be executed before all exits of the region (i.e. it doesn't dominate all
1849+
// blocks with no successors reachable from the entry block).
1850+
Region *parentRegion = op->getParentRegion();
1851+
Block *parentBlock = op->getBlock();
1852+
1853+
for (Block *successor : parentBlock->getSuccessors())
1854+
if (successor->isReachable(parentBlock))
1855+
return WalkResult::interrupt();
1856+
1857+
for (Block &block : *parentRegion)
1858+
if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
1859+
!domInfo.dominates(parentBlock, &block))
1860+
return WalkResult::interrupt();
1861+
1862+
// Don't capture this op if it has a not-allowed sibling, and stop recursing
1863+
// into nested operations.
1864+
for (Operation &sibling : op->getParentRegion()->getOps())
1865+
if (&sibling != op && !siblingAllowedInCapture(&sibling))
1866+
return WalkResult::interrupt();
1867+
1868+
// Don't continue capturing nested operations if we reach an omp.loop_nest.
1869+
// Otherwise, process the contents of this operation.
1870+
capturedOp = op;
1871+
return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
1872+
: WalkResult::advance();
1873+
});
1874+
1875+
return capturedOp;
1876+
}
1877+
1878+
llvm::omp::OMPTgtExecModeFlags TargetOp::getKernelExecFlags() {
1879+
using namespace llvm::omp;
1880+
1881+
// Make sure this region is capturing a loop. Otherwise, it's a generic
1882+
// kernel.
1883+
Operation *capturedOp = getInnermostCapturedOmpOp();
1884+
if (!isa_and_present<LoopNestOp>(capturedOp))
1885+
return OMP_TGT_EXEC_MODE_GENERIC;
1886+
1887+
SmallVector<LoopWrapperInterface> wrappers;
1888+
cast<LoopNestOp>(capturedOp).gatherWrappers(wrappers);
1889+
assert(!wrappers.empty());
1890+
1891+
// Ignore optional SIMD leaf construct.
1892+
auto *innermostWrapper = wrappers.begin();
1893+
if (isa<SimdOp>(innermostWrapper))
1894+
innermostWrapper = std::next(innermostWrapper);
1895+
1896+
long numWrappers = std::distance(innermostWrapper, wrappers.end());
1897+
1898+
// Detect Generic-SPMD: target-teams-distribute[-simd].
1899+
if (numWrappers == 1) {
1900+
if (!isa<DistributeOp>(innermostWrapper))
1901+
return OMP_TGT_EXEC_MODE_GENERIC;
1902+
1903+
Operation *teamsOp = (*innermostWrapper)->getParentOp();
1904+
if (!isa_and_present<TeamsOp>(teamsOp))
1905+
return OMP_TGT_EXEC_MODE_GENERIC;
1906+
1907+
if (teamsOp->getParentOp() == *this)
1908+
return OMP_TGT_EXEC_MODE_GENERIC_SPMD;
1909+
}
1910+
1911+
// Detect SPMD: target-teams-distribute-parallel-wsloop[-simd].
1912+
if (numWrappers == 2) {
1913+
if (!isa<WsloopOp>(innermostWrapper))
1914+
return OMP_TGT_EXEC_MODE_GENERIC;
1915+
1916+
innermostWrapper = std::next(innermostWrapper);
1917+
if (!isa<DistributeOp>(innermostWrapper))
1918+
return OMP_TGT_EXEC_MODE_GENERIC;
1919+
1920+
Operation *parallelOp = (*innermostWrapper)->getParentOp();
1921+
if (!isa_and_present<ParallelOp>(parallelOp))
1922+
return OMP_TGT_EXEC_MODE_GENERIC;
1923+
1924+
Operation *teamsOp = parallelOp->getParentOp();
1925+
if (!isa_and_present<TeamsOp>(teamsOp))
1926+
return OMP_TGT_EXEC_MODE_GENERIC;
1927+
1928+
if (teamsOp->getParentOp() == *this)
1929+
return OMP_TGT_EXEC_MODE_SPMD;
1930+
}
1931+
1932+
return OMP_TGT_EXEC_MODE_GENERIC;
1933+
}
1934+
17451935
//===----------------------------------------------------------------------===//
17461936
// ParallelOp
17471937
//===----------------------------------------------------------------------===//
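
The sibling filter used while searching for the captured operation (`siblingAllowedInCapture` in the diff above) can be summarized with a small standalone model. This is a sketch under stated assumptions: `OpModel`, `EffectInstance`, `Effect`, `Resource` and `siblingAllowed` are hypothetical types and names, not MLIR classes. As the helper is written in this diff, only write effects on the automatic allocation scope resource disqualify a non-OpenMP sibling, and operations that do not implement the memory-effect interface are accepted.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Hypothetical model of memory effects; not MLIR types.
enum class Effect { Read, Write, Allocate, Free };
enum class Resource { Default, AutomaticAllocationScope };

struct EffectInstance {
  Effect effect;
  Resource resource;
};

struct OpModel {
  bool isOmpDialect;  // op belongs to the OpenMP dialect
  bool isTerminator;  // op is a block terminator
  // std::nullopt models an op without the memory-effect interface.
  std::optional<std::vector<EffectInstance>> effects;
};

// Returns true if `op` may sit next to an OpenMP operation that is being
// considered for capture by the enclosing target region.
bool siblingAllowed(const OpModel &op) {
  // OpenMP dialect siblings are only allowed if they are terminators.
  if (op.isOmpDialect)
    return op.isTerminator;

  // Other ops are rejected if they report a write effect on the
  // automatic allocation scope resource.
  if (op.effects) {
    for (const EffectInstance &e : *op.effects)
      if (e.effect == Effect::Write &&
          e.resource == Resource::AutomaticAllocationScope)
        return false;
  }
  return true;
}
```

In this model a plain arithmetic operation (no effects reported) is an allowed sibling, while a second non-terminator OpenMP operation next to the candidate stops the capture.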
