31
31
#include " llvm/ADT/StringRef.h"
32
32
#include " llvm/ADT/TypeSwitch.h"
33
33
#include " llvm/Frontend/OpenMP/OMPConstants.h"
34
+ #include " llvm/Frontend/OpenMP/OMPDeviceConstants.h"
34
35
#include < cstddef>
35
36
#include < iterator>
36
37
#include < optional>
@@ -691,8 +692,10 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region ®ion,
691
692
return parser.parseRegion (region, entryBlockArgs);
692
693
}
693
694
694
- static ParseResult parseInReductionMapPrivateRegion (
695
+ static ParseResult parseHostEvalInReductionMapPrivateRegion (
695
696
OpAsmParser &parser, Region ®ion,
697
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &hostEvalVars,
698
+ SmallVectorImpl<Type> &hostEvalTypes,
696
699
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
697
700
SmallVectorImpl<Type> &inReductionTypes,
698
701
DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
@@ -702,6 +705,7 @@ static ParseResult parseInReductionMapPrivateRegion(
702
705
llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
703
706
DenseI64ArrayAttr &privateMaps) {
704
707
AllRegionParseArgs args;
708
+ args.hostEvalArgs .emplace (hostEvalVars, hostEvalTypes);
705
709
args.inReductionArgs .emplace (inReductionVars, inReductionTypes,
706
710
inReductionByref, inReductionSyms);
707
711
args.mapArgs .emplace (mapVars, mapTypes);
@@ -931,13 +935,15 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region ®ion,
931
935
p.printRegion (region, /* printEntryBlockArgs=*/ false );
932
936
}
933
937
934
- static void printInReductionMapPrivateRegion (
935
- OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars,
938
+ static void printHostEvalInReductionMapPrivateRegion (
939
+ OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange hostEvalVars,
940
+ TypeRange hostEvalTypes, ValueRange inReductionVars,
936
941
TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
937
942
ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
938
943
ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms,
939
944
DenseI64ArrayAttr privateMaps) {
940
945
AllRegionPrintArgs args;
946
+ args.hostEvalArgs .emplace (hostEvalVars, hostEvalTypes);
941
947
args.inReductionArgs .emplace (inReductionVars, inReductionTypes,
942
948
inReductionByref, inReductionSyms);
943
949
args.mapArgs .emplace (mapVars, mapTypes);
@@ -1720,11 +1726,12 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
1720
1726
TargetOp::build (builder, state, /* allocate_vars=*/ {}, /* allocator_vars=*/ {},
1721
1727
clauses.bare , makeArrayAttr (ctx, clauses.dependKinds ),
1722
1728
clauses.dependVars , clauses.device , clauses.hasDeviceAddrVars ,
1723
- clauses.ifExpr , /* in_reduction_vars=*/ {},
1724
- /* in_reduction_byref=*/ nullptr , /* in_reduction_syms=*/ nullptr ,
1725
- clauses.isDevicePtrVars , clauses.mapVars , clauses.nowait ,
1726
- clauses.privateVars , makeArrayAttr (ctx, clauses.privateSyms ),
1727
- clauses.threadLimit , /* private_maps=*/ nullptr );
1729
+ clauses.hostEvalVars , clauses.ifExpr ,
1730
+ /* in_reduction_vars=*/ {}, /* in_reduction_byref=*/ nullptr ,
1731
+ /* in_reduction_syms=*/ nullptr , clauses.isDevicePtrVars ,
1732
+ clauses.mapVars , clauses.nowait , clauses.privateVars ,
1733
+ makeArrayAttr (ctx, clauses.privateSyms ), clauses.threadLimit ,
1734
+ /* private_maps=*/ nullptr );
1728
1735
}
1729
1736
1730
1737
LogicalResult TargetOp::verify () {
@@ -1742,6 +1749,189 @@ LogicalResult TargetOp::verify() {
1742
1749
return verifyPrivateVarsMapping (*this );
1743
1750
}
1744
1751
1752
/// Region-level verifier for the target operation.
///
/// Two things are checked here:
///  1. `getOps<TeamsOp>()` yields at most one operation; otherwise an error is
///     emitted.
///  2. Every user of every host_eval block argument is one of the accepted
///     operations (TeamsOp, ParallelOp, LoopNestOp) and uses the argument in
///     one of the operand positions accepted for that operation. For
///     ParallelOp and LoopNestOp the check also depends on the value returned
///     by `getKernelExecFlags()`.
///
/// Returns `success()` when no violation is found; the first violation
/// encountered is reported and returned as a failure.
+ LogicalResult TargetOp::verifyRegions () {
1753
+ auto teamsOps = getOps<TeamsOp>();
1754
+ if (std::distance (teamsOps.begin (), teamsOps.end ()) > 1 )
1755
+ return emitError (" target containing multiple 'omp.teams' nested ops" );
1756
+
1757
+ // Check that host_eval values are only used in legal ways.
// The execution-mode flags are queried once, before iterating over the
// host_eval arguments, and reused for every user below.
1758
+ llvm::omp::OMPTgtExecModeFlags execFlags = getKernelExecFlags ();
1759
+ for (Value hostEvalArg :
1760
+ cast<BlockArgOpenMPOpInterface>(getOperation ()).getHostEvalBlockArgs ()) {
1761
+ for (Operation *user : hostEvalArg.getUsers ()) {
// TeamsOp user: the argument has to be one of the values returned by
// getNumTeamsLower(), getNumTeamsUpper() or getThreadLimit().
1762
+ if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
1763
+ if (llvm::is_contained ({teamsOp.getNumTeamsLower (),
1764
+ teamsOp.getNumTeamsUpper (),
1765
+ teamsOp.getThreadLimit ()},
1766
+ hostEvalArg))
1767
+ continue ;
1768
+
1769
+ return emitOpError () << " host_eval argument only legal as 'num_teams' "
1770
+ " and 'thread_limit' in 'omp.teams'" ;
1771
+ }
// ParallelOp user: accepted only when the flags equal
// OMP_TGT_EXEC_MODE_SPMD and the argument is the value returned by
// getNumThreads().
1772
+ if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
1773
+ if (execFlags == llvm::omp::OMP_TGT_EXEC_MODE_SPMD &&
1774
+ hostEvalArg == parallelOp.getNumThreads ())
1775
+ continue ;
1776
+
1777
+ return emitOpError ()
1778
+ << " host_eval argument only legal as 'num_threads' in "
1779
+ " 'omp.parallel' when representing target SPMD" ;
1780
+ }
// LoopNestOp user: accepted only when the flags differ from
// OMP_TGT_EXEC_MODE_GENERIC and the argument appears among the loop lower
// bounds, upper bounds or steps.
1781
+ if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
1782
+ if (execFlags != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC &&
1783
+ (llvm::is_contained (loopNestOp.getLoopLowerBounds (), hostEvalArg) ||
1784
+ llvm::is_contained (loopNestOp.getLoopUpperBounds (), hostEvalArg) ||
1785
+ llvm::is_contained (loopNestOp.getLoopSteps (), hostEvalArg)))
1786
+ continue ;
1787
+
1788
+ return emitOpError () << " host_eval argument only legal as loop bounds "
1789
+ " and steps in 'omp.loop_nest' when "
1790
+ " representing target SPMD or Generic-SPMD" ;
1791
+ }
1792
+
// Any user that is not one of the three operation kinds handled above is
// rejected; the diagnostic includes the user's operation name.
1793
+ return emitOpError () << " host_eval argument illegal use in '"
1794
+ << user->getName () << " ' operation" ;
1795
+ }
1796
+ }
1797
+ return success ();
1798
+ }
1799
+
1800
+ // / Only allow OpenMP terminators and non-OpenMP ops that have known memory
1801
+ // / effects, but don't include a memory write effect.
1802
+ static bool siblingAllowedInCapture (Operation *op) {
1803
+ if (!op)
1804
+ return false ;
1805
+
1806
+ bool isOmpDialect =
1807
+ op->getContext ()->getLoadedDialect <omp::OpenMPDialect>() ==
1808
+ op->getDialect ();
1809
+
1810
+ if (isOmpDialect)
1811
+ return op->hasTrait <OpTrait::IsTerminator>();
1812
+
1813
+ if (auto memOp = dyn_cast<MemoryEffectOpInterface>(op)) {
1814
+ SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4 > effects;
1815
+ memOp.getEffects (effects);
1816
+ return !llvm::any_of (effects, [&](MemoryEffects::EffectInstance &effect) {
1817
+ return isa<MemoryEffects::Write>(effect.getEffect ()) &&
1818
+ isa<SideEffects::AutomaticAllocationScopeResource>(
1819
+ effect.getResource ());
1820
+ });
1821
+ }
1822
+ return true ;
1823
+ }
1824
+
1825
+ Operation *TargetOp::getInnermostCapturedOmpOp () {
1826
+ Dialect *ompDialect = (*this )->getDialect ();
1827
+ Operation *capturedOp = nullptr ;
1828
+ DominanceInfo domInfo;
1829
+
1830
+ // Process in pre-order to check operations from outermost to innermost,
1831
+ // ensuring we only enter the region of an operation if it meets the criteria
1832
+ // for being captured. We stop the exploration of nested operations as soon as
1833
+ // we process a region holding no operations to be captured.
1834
+ walk<WalkOrder::PreOrder>([&](Operation *op) {
1835
+ if (op == *this )
1836
+ return WalkResult::advance ();
1837
+
1838
+ // Ignore operations of other dialects or omp operations with no regions,
1839
+ // because these will only be checked if they are siblings of an omp
1840
+ // operation that can potentially be captured.
1841
+ bool isOmpDialect = op->getDialect () == ompDialect;
1842
+ bool hasRegions = op->getNumRegions () > 0 ;
1843
+ if (!isOmpDialect || !hasRegions)
1844
+ return WalkResult::skip ();
1845
+
1846
+ // This operation cannot be captured if it can be executed more than once
1847
+ // (i.e. its block's successors can reach it) or if it's not guaranteed to
1848
+ // be executed before all exits of the region (i.e. it doesn't dominate all
1849
+ // blocks with no successors reachable from the entry block).
1850
+ Region *parentRegion = op->getParentRegion ();
1851
+ Block *parentBlock = op->getBlock ();
1852
+
1853
+ for (Block *successor : parentBlock->getSuccessors ())
1854
+ if (successor->isReachable (parentBlock))
1855
+ return WalkResult::interrupt ();
1856
+
1857
+ for (Block &block : *parentRegion)
1858
+ if (domInfo.isReachableFromEntry (&block) && block.hasNoSuccessors () &&
1859
+ !domInfo.dominates (parentBlock, &block))
1860
+ return WalkResult::interrupt ();
1861
+
1862
+ // Don't capture this op if it has a not-allowed sibling, and stop recursing
1863
+ // into nested operations.
1864
+ for (Operation &sibling : op->getParentRegion ()->getOps ())
1865
+ if (&sibling != op && !siblingAllowedInCapture (&sibling))
1866
+ return WalkResult::interrupt ();
1867
+
1868
+ // Don't continue capturing nested operations if we reach an omp.loop_nest.
1869
+ // Otherwise, process the contents of this operation.
1870
+ capturedOp = op;
1871
+ return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt ()
1872
+ : WalkResult::advance ();
1873
+ });
1874
+
1875
+ return capturedOp;
1876
+ }
1877
+
1878
+ llvm::omp::OMPTgtExecModeFlags TargetOp::getKernelExecFlags () {
1879
+ using namespace llvm ::omp;
1880
+
1881
+ // Make sure this region is capturing a loop. Otherwise, it's a generic
1882
+ // kernel.
1883
+ Operation *capturedOp = getInnermostCapturedOmpOp ();
1884
+ if (!isa_and_present<LoopNestOp>(capturedOp))
1885
+ return OMP_TGT_EXEC_MODE_GENERIC;
1886
+
1887
+ SmallVector<LoopWrapperInterface> wrappers;
1888
+ cast<LoopNestOp>(capturedOp).gatherWrappers (wrappers);
1889
+ assert (!wrappers.empty ());
1890
+
1891
+ // Ignore optional SIMD leaf construct.
1892
+ auto *innermostWrapper = wrappers.begin ();
1893
+ if (isa<SimdOp>(innermostWrapper))
1894
+ innermostWrapper = std::next (innermostWrapper);
1895
+
1896
+ long numWrappers = std::distance (innermostWrapper, wrappers.end ());
1897
+
1898
+ // Detect Generic-SPMD: target-teams-distribute[-simd].
1899
+ if (numWrappers == 1 ) {
1900
+ if (!isa<DistributeOp>(innermostWrapper))
1901
+ return OMP_TGT_EXEC_MODE_GENERIC;
1902
+
1903
+ Operation *teamsOp = (*innermostWrapper)->getParentOp ();
1904
+ if (!isa_and_present<TeamsOp>(teamsOp))
1905
+ return OMP_TGT_EXEC_MODE_GENERIC;
1906
+
1907
+ if (teamsOp->getParentOp () == *this )
1908
+ return OMP_TGT_EXEC_MODE_GENERIC_SPMD;
1909
+ }
1910
+
1911
+ // Detect SPMD: target-teams-distribute-parallel-wsloop[-simd].
1912
+ if (numWrappers == 2 ) {
1913
+ if (!isa<WsloopOp>(innermostWrapper))
1914
+ return OMP_TGT_EXEC_MODE_GENERIC;
1915
+
1916
+ innermostWrapper = std::next (innermostWrapper);
1917
+ if (!isa<DistributeOp>(innermostWrapper))
1918
+ return OMP_TGT_EXEC_MODE_GENERIC;
1919
+
1920
+ Operation *parallelOp = (*innermostWrapper)->getParentOp ();
1921
+ if (!isa_and_present<ParallelOp>(parallelOp))
1922
+ return OMP_TGT_EXEC_MODE_GENERIC;
1923
+
1924
+ Operation *teamsOp = parallelOp->getParentOp ();
1925
+ if (!isa_and_present<TeamsOp>(teamsOp))
1926
+ return OMP_TGT_EXEC_MODE_GENERIC;
1927
+
1928
+ if (teamsOp->getParentOp () == *this )
1929
+ return OMP_TGT_EXEC_MODE_SPMD;
1930
+ }
1931
+
1932
+ return OMP_TGT_EXEC_MODE_GENERIC;
1933
+ }
1934
+
1745
1935
// ===----------------------------------------------------------------------===//
1746
1936
// ParallelOp
1747
1937
// ===----------------------------------------------------------------------===//
0 commit comments