@@ -45,12 +45,14 @@ using namespace NVVM;
 #include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
 #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
 
+static constexpr unsigned notIntrinsic = llvm::Intrinsic::not_intrinsic;
+
 //===----------------------------------------------------------------------===//
 // Verifier methods
 //===----------------------------------------------------------------------===//
 
 // This verifier is shared among the following Ops:
-// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
+// CpAsyncBulkTensorSharedCTAToGlobalOp (TMA Store)
 // CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
 static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
                                                      bool isIm2Col,
@@ -74,13 +76,6 @@ static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
   return success();
 }
 
-LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
-  size_t numIm2ColOffsets = getIm2colOffsets().size();
-  bool isIm2Col = numIm2ColOffsets > 0;
-  return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
-                                         numIm2ColOffsets, getLoc());
-}
-
 LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
   TMAStoreMode mode = getMode();
   // We lower through inline-ptx when getPredicate() is true.
@@ -158,6 +153,38 @@ LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
                              getMode(), getLoc());
 }
 
+LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
+  TMALoadMode mode = getMode();
+  bool isCTAOnly = getIsCTAOnly();
+  if (getPredicate()) { // Inline-asm based lowering
+    if (isCTAOnly)
+      return emitError("Predicate is supported only for shared::cluster mode.");
+    if (mode != TMALoadMode::TILE && mode != TMALoadMode::IM2COL)
+      return emitError(
+          "Predicate is supported only for Tile and Im2col modes.");
+  } else { // Intrinsics-based lowering
+    NVVMMemorySpace expectedAS =
+        isCTAOnly ? NVVMMemorySpace::Shared : NVVMMemorySpace::SharedCluster;
+    unsigned AS = llvm::cast<LLVM::LLVMPointerType>(getDstMem().getType())
+                      .getAddressSpace();
+    if (AS != expectedAS)
+      return emitError()
+             << (isCTAOnly
+                     ? "Shared::cta destination requires address-space 3."
+                     : "Shared::cluster destination requires address-space 7.");
+    // Checks specific to shared::cta mode
+    if (isCTAOnly) {
+      if (getMulticastMask())
+        return emitError("Multicast is not supported with shared::cta mode.");
+      if (getGroup())
+        return emitError("CTAGroup is not supported with shared::cta mode.");
+    }
+  }
+
+  return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
+                             getMode(), getLoc());
+}
+
 LogicalResult CpAsyncBulkTensorReduceOp::verify() {
   TMAStoreMode mode = getMode();
   size_t dims = getCoordinates().size();
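// A minimal standalone sketch (not part of the diff above) of the legality
// rules enforced by the new verifier, collapsed into one predicate. The names
// LoadConfig and isLegalLoadConfig are hypothetical and only for illustration;
// the TMALoadMode check on the predicate path is simplified to one boolean.
struct LoadConfig {
  bool isCTAOnly;        // shared::cta (true) vs shared::cluster (false)
  bool hasPredicate;     // a predicate selects the inline-PTX lowering path
  bool isTileOrIm2col;   // mode is TILE or IM2COL
  bool hasMulticast;     // multicast mask present
  bool hasCTAGroup;      // cta_group present
  unsigned dstAddrSpace; // address space of the destination pointer
};

constexpr bool isLegalLoadConfig(const LoadConfig &c) {
  if (c.hasPredicate) // inline-PTX path: shared::cluster only, Tile/Im2col only
    return !c.isCTAOnly && c.isTileOrIm2col;
  // Intrinsics path: address-space 3 for shared::cta, 7 for shared::cluster.
  if (c.dstAddrSpace != (c.isCTAOnly ? 3u : 7u))
    return false;
  // shared::cta accepts neither multicast nor cta_group.
  return !(c.isCTAOnly && (c.hasMulticast || c.hasCTAGroup));
}

static_assert(isLegalLoadConfig({false, false, true, true, true, 7u}),
              "shared::cluster with multicast and cta_group is accepted");
static_assert(!isLegalLoadConfig({true, false, true, true, false, 3u}),
              "shared::cta rejects multicast");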
@@ -1553,6 +1580,130 @@ mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
   return {id, std::move(args)};
 }
 
+bool CpAsyncBulkTensorGlobalToSharedClusterOp::getAsmValues(
+    RewriterBase &rewriter,
+    llvm::SmallVectorImpl<std::pair<mlir::Value, mlir::NVVM::PTXRegisterMod>>
+        &asmValues) {
+  // Add all the operands but not the attrs to the asmValues list.
+  // The attrs here are used to generate the right variants for
+  // intrinsics-lowering. So, we ignore them while generating inline-PTX.
+  for (auto val : getOperands())
+    asmValues.push_back({val, mlir::NVVM::PTXRegisterMod::Read});
+
+  return false;
+}
+
+mlir::NVVM::IDArgPair
+CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(op);
+  const bool isCTAOnly = thisOp.getIsCTAOnly();
+  llvm::SmallVector<llvm::Value *> args;
+
+  // Fill the Intrinsic Args
+  args.push_back(mt.lookupValue(thisOp.getDstMem()));
+  args.push_back(mt.lookupValue(thisOp.getMbar()));
+  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
+
+  // Coordinates and im2col-offsets
+  for (mlir::Value v : thisOp.getCoordinates())
+    args.push_back(mt.lookupValue(v));
+  for (mlir::Value v : thisOp.getIm2colOffsets())
+    args.push_back(mt.lookupValue(v));
+
+  // MulticastMask, if available
+  mlir::Value mcMask = thisOp.getMulticastMask();
+  const bool hasMC = static_cast<bool>(mcMask);
+  llvm::Value *i16Zero =
+      llvm::ConstantInt::get(llvm::Type::getInt16Ty(mt.getLLVMContext()), 0);
+
+  // CacheHint, if available
+  mlir::Value cacheHint = thisOp.getL2CacheHint();
+  const bool hasCacheHint = static_cast<bool>(cacheHint);
+  llvm::Value *i64Zero =
+      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
+
+  // Flag argument CTAGroup
+  // CTA_1/2 is mapped to values 1 and 2 for the intrinsics.
+  // Hence, the +1 to getGroup().
+  const int32_t val =
+      thisOp.getGroup() ? (static_cast<int32_t>(*thisOp.getGroup()) + 1) : 0;
+  llvm::Value *cg =
+      llvm::ConstantInt::get(llvm::Type::getInt32Ty(mt.getLLVMContext()), val);
+
+  if (!isCTAOnly) {
+    // For shared::cluster, all the arguments that we build are applicable.
+    args.push_back(hasMC ? mt.lookupValue(mcMask) : i16Zero);
+    args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
+    args.push_back(builder.getInt1(hasMC));
+    args.push_back(builder.getInt1(hasCacheHint));
+    args.push_back(cg);
+  } else {
+    // For shared::cta, only cache-hint is applicable.
+    args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Zero);
+    args.push_back(builder.getInt1(hasCacheHint));
+  }
+
+  constexpr size_t numDims = 5;  // 1D to 5D
+  constexpr size_t numModes = 5; // Tile, Im2col, w, w_128, gather4
+  using rowTy = std::array<llvm::Intrinsic::ID, numDims + 1>;
+  using TableTy = std::array<rowTy, numModes>;
+  static constexpr TableTy IDTable{
+      {{notIntrinsic, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d},
+       {notIntrinsic, notIntrinsic, notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d},
+       {notIntrinsic, notIntrinsic, notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_3d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_4d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_5d},
+       {notIntrinsic, notIntrinsic, notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_3d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_4d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_5d},
+       {notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d}}};
+
+  static constexpr TableTy IDTableCTA{
+      {{notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_1d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_2d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_3d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_4d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_5d},
+       {notIntrinsic, notIntrinsic, notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_3d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_4d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_5d},
+       {notIntrinsic, notIntrinsic, notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_3d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_4d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_5d},
+       {notIntrinsic, notIntrinsic, notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_3d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_4d,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_im2col_w_128_5d},
+       {notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic, notIntrinsic,
+        llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_cta_tile_gather4_2d}}};
+
+  static_assert(
+      (getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1) &&
+      (getMaxEnumValForTMALoadMode() == std::size(IDTableCTA) - 1),
+      "TMALoadModes must match number of rows in IDTable and IDTableCTA");
+  size_t mode = static_cast<size_t>(thisOp.getMode());
+  size_t dim = thisOp.getCoordinates().size();
+  auto id = isCTAOnly ? IDTableCTA[mode][dim] : IDTable[mode][dim];
+  assert(id != notIntrinsic &&
+         "Invalid intrinsic for CpAsyncBulkTensorGlobalToSharedClusterOp.");
+
+  return {id, std::move(args)};
+}
+
 mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
     Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
   auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
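// A minimal standalone sketch (not part of the diff above) of the intrinsic
// selection pattern used by getIntrinsicIDAndArgs: a constexpr [mode][dim]
// table with a sentinel for unsupported combinations, indexed by the load
// mode and by coordinates().size(). The numeric IDs and the names kNone and
// selectID are placeholders, not real NVVM intrinsic IDs; only two of the
// five modes are modeled. In the real code, the static_assert ties the row
// count to the TMALoadMode enum so a new mode cannot be added without
// updating both tables.
#include <array>
#include <cassert>
#include <cstddef>

constexpr unsigned kNone = 0; // stands in for llvm::Intrinsic::not_intrinsic

constexpr size_t kDims = 5;  // 1D to 5D (index 0 unused)
constexpr size_t kModes = 2; // tile, im2col (subset of the real table)
using RowTy = std::array<unsigned, kDims + 1>;
constexpr std::array<RowTy, kModes> kTable{{
    {kNone, 11, 12, 13, 14, 15},       // tile: 1D-5D
    {kNone, kNone, kNone, 23, 24, 25}, // im2col: 3D-5D only
}};

unsigned selectID(size_t mode, size_t dims) {
  assert(mode < kModes && dims <= kDims && "index out of range");
  unsigned id = kTable[mode][dims];
  assert(id != kNone && "unsupported (mode, dims) combination");
  return id;
}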