Skip to content

Commit 26fbb25

Browse files
committed
[MLIR][OpenMP] Add host_eval clause to omp.target
This patch adds the `host_eval` clause to the `omp.target` operation. Additionally, it updates its op verifier to make sure all uses of block arguments defined by this clause fall within one of the few cases where they are allowed. MLIR to LLVM IR translation fails on translation of this clause with a not-yet-implemented error.
1 parent 0a47845 commit 26fbb25

File tree

7 files changed

+369
-13
lines changed

7 files changed

+369
-13
lines changed

mlir/docs/Dialects/OpenMPDialect/_index.md

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,3 +523,58 @@ omp.parallel ... {
523523
omp.terminator
524524
} {omp.composite}
525525
```
526+
527+
## Host-Evaluated Clauses in Target Regions
528+
529+
The `omp.target` operation, which represents the OpenMP `target` construct, is
530+
marked with the `IsolatedFromAbove` trait. This means that, inside of its
531+
region, no MLIR values defined outside of the op itself can be used. This is
532+
consistent with the OpenMP specification of the `target` construct, which
533+
mandates that all host device values used inside of the `target` region must
534+
either be privatized (data-sharing) or mapped (data-mapping).
535+
536+
Normally, clauses applied to a construct are evaluated before entering that
537+
construct. Further, in some cases, the OpenMP specification stipulates that
538+
clauses be evaluated _on the host device_ on entry to a parent `target`
539+
construct. In particular, the `num_teams` and `thread_limit` clauses of the
540+
`teams` construct must be evaluated on the host device if it's nested inside or
541+
combined with a `target` construct.
542+
543+
Additionally, the runtime library targeted by the MLIR to LLVM IR translation of
544+
the OpenMP dialect supports the optimized launch of SPMD kernels (i.e.
545+
`target teams distribute parallel {do,for}` in OpenMP), which requires
546+
specifying in advance what the total trip count of the loop is. Consequently, it
547+
is also beneficial to evaluate the trip count on the host device prior to the
548+
kernel launch.
549+
550+
These host-evaluated values in MLIR would need to be placed outside of the
551+
`omp.target` region and also attached to the corresponding nested operations,
552+
which is not possible because of the `IsolatedFromAbove` trait. The solution
553+
implemented to address this problem has been to introduce the `host_eval`
554+
argument to the `omp.target` operation. It works similarly to a `map` clause,
555+
but its only intended use is to forward host-evaluated values to their
556+
corresponding operation inside of the region. Any uses outside of the previously
557+
described result in a verifier error.
558+
559+
```mlir
560+
// Initialize %0, %1, %2, %3...
561+
omp.target host_eval(%0 -> %nt, %1 -> %lb, %2 -> %ub, %3 -> %step : i32, i32, i32, i32) {
562+
omp.teams num_teams(to %nt : i32) {
563+
omp.parallel {
564+
omp.distribute {
565+
omp.wsloop {
566+
omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
567+
// ...
568+
omp.yield
569+
}
570+
omp.terminator
571+
} {omp.composite}
572+
omp.terminator
573+
} {omp.composite}
574+
omp.terminator
575+
} {omp.composite}
576+
omp.terminator
577+
}
578+
omp.terminator
579+
}
580+
```

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1166,9 +1166,10 @@ def TargetOp : OpenMP_Op<"target", traits = [
11661166
], clauses = [
11671167
// TODO: Complete clause list (defaultmap, uses_allocators).
11681168
OpenMP_AllocateClause, OpenMP_DependClause, OpenMP_DeviceClause,
1169-
OpenMP_HasDeviceAddrClause, OpenMP_IfClause, OpenMP_InReductionClause,
1170-
OpenMP_IsDevicePtrClause, OpenMP_MapClauseSkip<assemblyFormat = true>,
1171-
OpenMP_NowaitClause, OpenMP_PrivateClause, OpenMP_ThreadLimitClause
1169+
OpenMP_HasDeviceAddrClause, OpenMP_HostEvalClause, OpenMP_IfClause,
1170+
OpenMP_InReductionClause, OpenMP_IsDevicePtrClause,
1171+
OpenMP_MapClauseSkip<assemblyFormat = true>, OpenMP_NowaitClause,
1172+
OpenMP_PrivateClause, OpenMP_ThreadLimitClause
11721173
], singleRegion = true> {
11731174
let summary = "target construct";
11741175
let description = [{
@@ -1186,16 +1187,34 @@ def TargetOp : OpenMP_Op<"target", traits = [
11861187

11871188
let extraClassDeclaration = [{
11881189
unsigned numMapBlockArgs() { return getMapVars().size(); }
1190+
1191+
/// Returns the innermost OpenMP dialect operation captured by this target
1192+
/// construct. For an operation to be detected as captured, it must be
1193+
/// inside a (possibly multi-level) nest of OpenMP dialect operation's
1194+
/// regions where none of these levels contain other operations considered
1195+
/// not-allowed for these purposes (i.e. only terminator operations are
1196+
/// allowed from the OpenMP dialect, and other dialect's operations are
1197+
/// allowed as long as they don't have a memory write effect).
1198+
///
1199+
/// If there are omp.loop_nest operations in the sequence of nested
1200+
/// operations, the top level one will be the one captured.
1201+
Operation *getInnermostCapturedOmpOp();
1202+
1203+
/// Checks whether this target region represents the MLIR equivalent to a
1204+
/// 'target teams distribute parallel {do, for} [simd]' OpenMP construct.
1205+
bool isTargetSPMDLoop();
11891206
}] # clausesExtraClassDeclaration;
11901207

11911208
let assemblyFormat = clausesAssemblyFormat # [{
1192-
custom<InReductionMapPrivateRegion>(
1193-
$region, $in_reduction_vars, type($in_reduction_vars),
1194-
$in_reduction_byref, $in_reduction_syms, $map_vars, type($map_vars),
1195-
$private_vars, type($private_vars), $private_syms) attr-dict
1209+
custom<HostEvalInReductionMapPrivateRegion>(
1210+
$region, $host_eval_vars, type($host_eval_vars), $in_reduction_vars,
1211+
type($in_reduction_vars), $in_reduction_byref, $in_reduction_syms,
1212+
$map_vars, type($map_vars), $private_vars, type($private_vars),
1213+
$private_syms) attr-dict
11961214
}];
11971215

11981216
let hasVerifier = 1;
1217+
let hasRegionVerifier = 1;
11991218
}
12001219

12011220

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 163 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -672,8 +672,10 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
672672
return parser.parseRegion(region, entryBlockArgs);
673673
}
674674

675-
static ParseResult parseInReductionMapPrivateRegion(
675+
static ParseResult parseHostEvalInReductionMapPrivateRegion(
676676
OpAsmParser &parser, Region &region,
677+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &hostEvalVars,
678+
SmallVectorImpl<Type> &hostEvalTypes,
677679
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
678680
SmallVectorImpl<Type> &inReductionTypes,
679681
DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
@@ -682,6 +684,7 @@ static ParseResult parseInReductionMapPrivateRegion(
682684
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
683685
llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
684686
AllRegionParseArgs args;
687+
args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
685688
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
686689
inReductionByref, inReductionSyms);
687690
args.mapArgs.emplace(mapVars, mapTypes);
@@ -896,12 +899,14 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
896899
p.printRegion(region, /*printEntryBlockArgs=*/false);
897900
}
898901

899-
static void printInReductionMapPrivateRegion(
900-
OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
902+
static void printHostEvalInReductionMapPrivateRegion(
903+
OpAsmPrinter &p, Operation *op, Region &region, ValueRange hostEvalVars,
904+
TypeRange hostEvalTypes, ValueRange inReductionVars,
901905
TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
902906
ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
903907
ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms) {
904908
AllRegionPrintArgs args;
909+
args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
905910
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
906911
inReductionByref, inReductionSyms);
907912
args.mapArgs.emplace(mapVars, mapTypes);
@@ -1685,7 +1690,8 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
16851690
// inReductionByref, inReductionSyms.
16861691
TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
16871692
makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
1688-
clauses.device, clauses.hasDeviceAddrVars, clauses.ifExpr,
1693+
clauses.device, clauses.hasDeviceAddrVars,
1694+
clauses.hostEvalVars, clauses.ifExpr,
16891695
/*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
16901696
/*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
16911697
clauses.mapVars, clauses.nowait, clauses.privateVars,
@@ -1699,6 +1705,159 @@ LogicalResult TargetOp::verify() {
16991705
: verifyMapClause(*this, getMapVars());
17001706
}
17011707

1708+
LogicalResult TargetOp::verifyRegions() {
1709+
auto teamsOps = getOps<TeamsOp>();
1710+
if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
1711+
return emitError("target containing multiple 'omp.teams' nested ops");
1712+
1713+
// Check that host_eval values are only used in legal ways.
1714+
bool isTargetSPMD = isTargetSPMDLoop();
1715+
for (Value hostEvalArg :
1716+
cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
1717+
for (Operation *user : hostEvalArg.getUsers()) {
1718+
if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
1719+
if (llvm::is_contained({teamsOp.getNumTeamsLower(),
1720+
teamsOp.getNumTeamsUpper(),
1721+
teamsOp.getThreadLimit()},
1722+
hostEvalArg))
1723+
continue;
1724+
1725+
return emitOpError() << "host_eval argument only legal as 'num_teams' "
1726+
"and 'thread_limit' in 'omp.teams'";
1727+
}
1728+
if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
1729+
if (isTargetSPMD && hostEvalArg == parallelOp.getNumThreads())
1730+
continue;
1731+
1732+
return emitOpError()
1733+
<< "host_eval argument only legal as 'num_threads' in "
1734+
"'omp.parallel' when representing target SPMD";
1735+
}
1736+
if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
1737+
if (isTargetSPMD &&
1738+
(llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
1739+
llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
1740+
llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
1741+
continue;
1742+
1743+
return emitOpError()
1744+
<< "host_eval argument only legal as loop bounds and steps in "
1745+
"'omp.loop_nest' when representing target SPMD";
1746+
}
1747+
1748+
return emitOpError() << "host_eval argument illegal use in '"
1749+
<< user->getName() << "' operation";
1750+
}
1751+
}
1752+
return success();
1753+
}
1754+
1755+
/// Only allow OpenMP terminators and non-OpenMP ops that have known memory
1756+
/// effects, but don't include a memory write effect.
1757+
static bool siblingAllowedInCapture(Operation *op) {
1758+
if (!op)
1759+
return false;
1760+
1761+
bool isOmpDialect =
1762+
op->getContext()->getLoadedDialect<omp::OpenMPDialect>() ==
1763+
op->getDialect();
1764+
1765+
if (isOmpDialect)
1766+
return op->hasTrait<OpTrait::IsTerminator>();
1767+
1768+
if (auto memOp = dyn_cast<MemoryEffectOpInterface>(op)) {
1769+
SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4> effects;
1770+
memOp.getEffects(effects);
1771+
return !llvm::any_of(effects, [&](MemoryEffects::EffectInstance &effect) {
1772+
return isa<MemoryEffects::Write>(effect.getEffect()) &&
1773+
isa<SideEffects::AutomaticAllocationScopeResource>(
1774+
effect.getResource());
1775+
});
1776+
}
1777+
return true;
1778+
}
1779+
1780+
Operation *TargetOp::getInnermostCapturedOmpOp() {
1781+
Dialect *ompDialect = (*this)->getDialect();
1782+
Operation *capturedOp = nullptr;
1783+
1784+
// Process in pre-order to check operations from outermost to innermost,
1785+
// ensuring we only enter the region of an operation if it meets the criteria
1786+
// for being captured. We stop the exploration of nested operations as soon as
1787+
// we process a region holding no operations to be captured.
1788+
walk<WalkOrder::PreOrder>([&](Operation *op) {
1789+
if (op == *this)
1790+
return WalkResult::advance();
1791+
1792+
// Ignore operations of other dialects or omp operations with no regions,
1793+
// because these will only be checked if they are siblings of an omp
1794+
// operation that can potentially be captured.
1795+
bool isOmpDialect = op->getDialect() == ompDialect;
1796+
bool hasRegions = op->getNumRegions() > 0;
1797+
if (!isOmpDialect || !hasRegions)
1798+
return WalkResult::skip();
1799+
1800+
// Don't capture this op if it has a not-allowed sibling, and stop recursing
1801+
// into nested operations.
1802+
for (Operation &sibling : op->getParentRegion()->getOps())
1803+
if (&sibling != op && !siblingAllowedInCapture(&sibling))
1804+
return WalkResult::interrupt();
1805+
1806+
// Don't continue capturing nested operations if we reach an omp.loop_nest.
1807+
// Otherwise, process the contents of this operation.
1808+
capturedOp = op;
1809+
return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
1810+
: WalkResult::advance();
1811+
});
1812+
1813+
return capturedOp;
1814+
}
1815+
1816+
bool TargetOp::isTargetSPMDLoop() {
1817+
// The expected MLIR representation for a target SPMD loop is:
1818+
// omp.target {
1819+
// omp.teams {
1820+
// omp.parallel {
1821+
// omp.distribute {
1822+
// omp.wsloop {
1823+
// omp.loop_nest ... { ... }
1824+
// } {omp.composite}
1825+
// } {omp.composite}
1826+
// omp.terminator
1827+
// } {omp.composite}
1828+
// omp.terminator
1829+
// }
1830+
// omp.terminator
1831+
// }
1832+
1833+
Operation *capturedOp = getInnermostCapturedOmpOp();
1834+
if (!isa_and_present<LoopNestOp>(capturedOp))
1835+
return false;
1836+
1837+
Operation *workshareOp = capturedOp->getParentOp();
1838+
1839+
// Accept an optional omp.simd loop wrapper as part of the SPMD pattern.
1840+
if (isa_and_present<SimdOp>(workshareOp))
1841+
workshareOp = workshareOp->getParentOp();
1842+
1843+
if (!isa_and_present<WsloopOp>(workshareOp))
1844+
return false;
1845+
1846+
Operation *distributeOp = workshareOp->getParentOp();
1847+
if (!isa_and_present<DistributeOp>(distributeOp))
1848+
return false;
1849+
1850+
Operation *parallelOp = distributeOp->getParentOp();
1851+
if (!isa_and_present<ParallelOp>(parallelOp))
1852+
return false;
1853+
1854+
Operation *teamsOp = parallelOp->getParentOp();
1855+
if (!isa_and_present<TeamsOp>(teamsOp))
1856+
return false;
1857+
1858+
return teamsOp->getParentOp() == (*this);
1859+
}
1860+
17021861
//===----------------------------------------------------------------------===//
17031862
// ParallelOp
17041863
//===----------------------------------------------------------------------===//

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,10 @@ static LogicalResult checkImplementationStatus(Operation &op) {
174174
if (op.getHint())
175175
op.emitWarning("hint clause discarded");
176176
};
177+
auto checkHostEval = [&todo](auto op, LogicalResult &result) {
178+
if (!op.getHostEvalVars().empty())
179+
result = todo("host_eval");
180+
};
177181
auto checkIf = [&todo](auto op, LogicalResult &result) {
178182
if (op.getIfExpr())
179183
result = todo("if");
@@ -291,6 +295,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
291295
checkAllocate(op, result);
292296
checkDevice(op, result);
293297
checkHasDeviceAddr(op, result);
298+
checkHostEval(op, result);
294299
checkIf(op, result);
295300
checkInReduction(op, result);
296301
checkIsDevicePtr(op, result);

0 commit comments

Comments
 (0)