@@ -672,8 +672,10 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region ®ion,
672
672
return parser.parseRegion (region, entryBlockArgs);
673
673
}
674
674
675
- static ParseResult parseInReductionMapPrivateRegion (
675
+ static ParseResult parseHostEvalInReductionMapPrivateRegion (
676
676
OpAsmParser &parser, Region ®ion,
677
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &hostEvalVars,
678
+ SmallVectorImpl<Type> &hostEvalTypes,
677
679
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
678
680
SmallVectorImpl<Type> &inReductionTypes,
679
681
DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
@@ -682,6 +684,7 @@ static ParseResult parseInReductionMapPrivateRegion(
682
684
llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateVars,
683
685
llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms) {
684
686
AllRegionParseArgs args;
687
+ args.hostEvalArgs .emplace (hostEvalVars, hostEvalTypes);
685
688
args.inReductionArgs .emplace (inReductionVars, inReductionTypes,
686
689
inReductionByref, inReductionSyms);
687
690
args.mapArgs .emplace (mapVars, mapTypes);
@@ -896,12 +899,14 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region ®ion,
896
899
p.printRegion (region, /* printEntryBlockArgs=*/ false );
897
900
}
898
901
899
- static void printInReductionMapPrivateRegion (
900
- OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars,
902
+ static void printHostEvalInReductionMapPrivateRegion (
903
+ OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange hostEvalVars,
904
+ TypeRange hostEvalTypes, ValueRange inReductionVars,
901
905
TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
902
906
ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
903
907
ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms) {
904
908
AllRegionPrintArgs args;
909
+ args.hostEvalArgs .emplace (hostEvalVars, hostEvalTypes);
905
910
args.inReductionArgs .emplace (inReductionVars, inReductionTypes,
906
911
inReductionByref, inReductionSyms);
907
912
args.mapArgs .emplace (mapVars, mapTypes);
@@ -1685,7 +1690,8 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
1685
1690
// inReductionByref, inReductionSyms.
1686
1691
TargetOp::build (builder, state, /* allocate_vars=*/ {}, /* allocator_vars=*/ {},
1687
1692
makeArrayAttr (ctx, clauses.dependKinds ), clauses.dependVars ,
1688
- clauses.device , clauses.hasDeviceAddrVars , clauses.ifExpr ,
1693
+ clauses.device , clauses.hasDeviceAddrVars ,
1694
+ clauses.hostEvalVars , clauses.ifExpr ,
1689
1695
/* in_reduction_vars=*/ {}, /* in_reduction_byref=*/ nullptr ,
1690
1696
/* in_reduction_syms=*/ nullptr , clauses.isDevicePtrVars ,
1691
1697
clauses.mapVars , clauses.nowait , clauses.privateVars ,
@@ -1699,6 +1705,159 @@ LogicalResult TargetOp::verify() {
1699
1705
: verifyMapClause (*this , getMapVars ());
1700
1706
}
1701
1707
1708
+ LogicalResult TargetOp::verifyRegions () {
1709
+ auto teamsOps = getOps<TeamsOp>();
1710
+ if (std::distance (teamsOps.begin (), teamsOps.end ()) > 1 )
1711
+ return emitError (" target containing multiple 'omp.teams' nested ops" );
1712
+
1713
+ // Check that host_eval values are only used in legal ways.
1714
+ bool isTargetSPMD = isTargetSPMDLoop ();
1715
+ for (Value hostEvalArg :
1716
+ cast<BlockArgOpenMPOpInterface>(getOperation ()).getHostEvalBlockArgs ()) {
1717
+ for (Operation *user : hostEvalArg.getUsers ()) {
1718
+ if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
1719
+ if (llvm::is_contained ({teamsOp.getNumTeamsLower (),
1720
+ teamsOp.getNumTeamsUpper (),
1721
+ teamsOp.getThreadLimit ()},
1722
+ hostEvalArg))
1723
+ continue ;
1724
+
1725
+ return emitOpError () << " host_eval argument only legal as 'num_teams' "
1726
+ " and 'thread_limit' in 'omp.teams'" ;
1727
+ }
1728
+ if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
1729
+ if (isTargetSPMD && hostEvalArg == parallelOp.getNumThreads ())
1730
+ continue ;
1731
+
1732
+ return emitOpError ()
1733
+ << " host_eval argument only legal as 'num_threads' in "
1734
+ " 'omp.parallel' when representing target SPMD" ;
1735
+ }
1736
+ if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
1737
+ if (isTargetSPMD &&
1738
+ (llvm::is_contained (loopNestOp.getLoopLowerBounds (), hostEvalArg) ||
1739
+ llvm::is_contained (loopNestOp.getLoopUpperBounds (), hostEvalArg) ||
1740
+ llvm::is_contained (loopNestOp.getLoopSteps (), hostEvalArg)))
1741
+ continue ;
1742
+
1743
+ return emitOpError ()
1744
+ << " host_eval argument only legal as loop bounds and steps in "
1745
+ " 'omp.loop_nest' when representing target SPMD" ;
1746
+ }
1747
+
1748
+ return emitOpError () << " host_eval argument illegal use in '"
1749
+ << user->getName () << " ' operation" ;
1750
+ }
1751
+ }
1752
+ return success ();
1753
+ }
1754
+
1755
+ // / Only allow OpenMP terminators and non-OpenMP ops that have known memory
1756
+ // / effects, but don't include a memory write effect.
1757
+ static bool siblingAllowedInCapture (Operation *op) {
1758
+ if (!op)
1759
+ return false ;
1760
+
1761
+ bool isOmpDialect =
1762
+ op->getContext ()->getLoadedDialect <omp::OpenMPDialect>() ==
1763
+ op->getDialect ();
1764
+
1765
+ if (isOmpDialect)
1766
+ return op->hasTrait <OpTrait::IsTerminator>();
1767
+
1768
+ if (auto memOp = dyn_cast<MemoryEffectOpInterface>(op)) {
1769
+ SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4 > effects;
1770
+ memOp.getEffects (effects);
1771
+ return !llvm::any_of (effects, [&](MemoryEffects::EffectInstance &effect) {
1772
+ return isa<MemoryEffects::Write>(effect.getEffect ()) &&
1773
+ isa<SideEffects::AutomaticAllocationScopeResource>(
1774
+ effect.getResource ());
1775
+ });
1776
+ }
1777
+ return true ;
1778
+ }
1779
+
1780
+ Operation *TargetOp::getInnermostCapturedOmpOp () {
1781
+ Dialect *ompDialect = (*this )->getDialect ();
1782
+ Operation *capturedOp = nullptr ;
1783
+
1784
+ // Process in pre-order to check operations from outermost to innermost,
1785
+ // ensuring we only enter the region of an operation if it meets the criteria
1786
+ // for being captured. We stop the exploration of nested operations as soon as
1787
+ // we process a region holding no operations to be captured.
1788
+ walk<WalkOrder::PreOrder>([&](Operation *op) {
1789
+ if (op == *this )
1790
+ return WalkResult::advance ();
1791
+
1792
+ // Ignore operations of other dialects or omp operations with no regions,
1793
+ // because these will only be checked if they are siblings of an omp
1794
+ // operation that can potentially be captured.
1795
+ bool isOmpDialect = op->getDialect () == ompDialect;
1796
+ bool hasRegions = op->getNumRegions () > 0 ;
1797
+ if (!isOmpDialect || !hasRegions)
1798
+ return WalkResult::skip ();
1799
+
1800
+ // Don't capture this op if it has a not-allowed sibling, and stop recursing
1801
+ // into nested operations.
1802
+ for (Operation &sibling : op->getParentRegion ()->getOps ())
1803
+ if (&sibling != op && !siblingAllowedInCapture (&sibling))
1804
+ return WalkResult::interrupt ();
1805
+
1806
+ // Don't continue capturing nested operations if we reach an omp.loop_nest.
1807
+ // Otherwise, process the contents of this operation.
1808
+ capturedOp = op;
1809
+ return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt ()
1810
+ : WalkResult::advance ();
1811
+ });
1812
+
1813
+ return capturedOp;
1814
+ }
1815
+
1816
+ bool TargetOp::isTargetSPMDLoop () {
1817
+ // The expected MLIR representation for a target SPMD loop is:
1818
+ // omp.target {
1819
+ // omp.teams {
1820
+ // omp.parallel {
1821
+ // omp.distribute {
1822
+ // omp.wsloop {
1823
+ // omp.loop_nest ... { ... }
1824
+ // } {omp.composite}
1825
+ // } {omp.composite}
1826
+ // omp.terminator
1827
+ // } {omp.composite}
1828
+ // omp.terminator
1829
+ // }
1830
+ // omp.terminator
1831
+ // }
1832
+
1833
+ Operation *capturedOp = getInnermostCapturedOmpOp ();
1834
+ if (!isa_and_present<LoopNestOp>(capturedOp))
1835
+ return false ;
1836
+
1837
+ Operation *workshareOp = capturedOp->getParentOp ();
1838
+
1839
+ // Accept an optional omp.simd loop wrapper as part of the SPMD pattern.
1840
+ if (isa_and_present<SimdOp>(workshareOp))
1841
+ workshareOp = workshareOp->getParentOp ();
1842
+
1843
+ if (!isa_and_present<WsloopOp>(workshareOp))
1844
+ return false ;
1845
+
1846
+ Operation *distributeOp = workshareOp->getParentOp ();
1847
+ if (!isa_and_present<DistributeOp>(distributeOp))
1848
+ return false ;
1849
+
1850
+ Operation *parallelOp = distributeOp->getParentOp ();
1851
+ if (!isa_and_present<ParallelOp>(parallelOp))
1852
+ return false ;
1853
+
1854
+ Operation *teamsOp = parallelOp->getParentOp ();
1855
+ if (!isa_and_present<TeamsOp>(teamsOp))
1856
+ return false ;
1857
+
1858
+ return teamsOp->getParentOp () == (*this );
1859
+ }
1860
+
1702
1861
// ===----------------------------------------------------------------------===//
1703
1862
// ParallelOp
1704
1863
// ===----------------------------------------------------------------------===//
0 commit comments