[MLIR][NVVM] Update TMA tensor prefetch Op #153464
Conversation
This patch updates the TMA Tensor prefetch Op with support for the im2col_w/w128 and tile_gather4 modes. This completes support for all modes available in Blackwell.

* lit tests are added for all possible combinations.
* The invalid tests are moved to a separate file with more coverage.

Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
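For illustration, a minimal MLIR sketch of the op in the new modes, mirroring the updated lit tests below (function and value names are placeholders; the tile_gather4 line follows the verifier rules since that part of the test diff is truncated):

llvm.func @tma_prefetch_new_modes(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %d4 : i32, %wHalo : i16, %wOffset : i16) {
  // im2col_w: exactly two im2col offsets (wHalo, wOffset); tensor must be at least 3D.
  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%wHalo, %wOffset] {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
  // im2col_w_128: same operand structure as im2col_w.
  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%wHalo, %wOffset] {mode = #nvvm.tma_load_mode<im2col_w_128>} : !llvm.ptr
  // tile_gather4: five coordinates (one column index followed by four row indices), no im2col offsets.
  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3, %d4] {mode = #nvvm.tma_load_mode<tile_gather4>} : !llvm.ptr
  llvm.return
}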
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir

Author: Durgadoss R (durga4github)

Changes: This patch updates the TMA Tensor prefetch Op
Patch is 30.41 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/153464.diff

5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 8d507268a3a15..272fb74a5fccd 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2253,6 +2253,56 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
// NVVM TMA Ops
//===----------------------------------------------------------------------===//
+// List of modes supported for TMA Load and Prefetch Ops
+def TMALoadModeTile : I32EnumAttrCase<"TILE", 0, "tile">;
+def TMALoadModeIm2Col : I32EnumAttrCase<"IM2COL", 1, "im2col">;
+def TMALoadModeIm2ColW : I32EnumAttrCase<"IM2COL_W", 2, "im2col_w">;
+def TMALoadModeIm2ColW128 : I32EnumAttrCase<"IM2COL_W_128", 3, "im2col_w_128">;
+def TMALoadModeTileGather4 : I32EnumAttrCase<"TILE_GATHER4", 4, "tile_gather4">;
+
+def TMALoadMode : I32EnumAttr<"TMALoadMode", "NVVM TMA Load Mode",
+ [TMALoadModeTile, TMALoadModeIm2Col,
+ TMALoadModeIm2ColW, TMALoadModeIm2ColW128,
+ TMALoadModeTileGather4]> {
+ let genSpecializedAttr = 0;
+ let cppNamespace = "::mlir::NVVM";
+}
+def TMALoadModeAttr : EnumAttr<NVVM_Dialect, TMALoadMode, "tma_load_mode"> {
+ let summary = "List of Load-Modes supported for TMA Tensor Ops";
+ let description = [{
+ TMA Tensor Ops support the following modes, when copying data from
+ global memory to shared memory (i.e. load):
+
+ Tile Mode: It's the default mode. The source multi-dimensional tensor
+ layout is preserved at the destination.
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-tiled-mode)
+
+ Im2col Mode: This mode is used when `im2colOffsets` operands are present.
+ The elements in the Bounding Box of the source tensor are rearranged into
+ columns at the destination. In this mode, the tensor has to be at least
+ 3-dimensional. The number of `im2colOffsets` is `dims - 2` where `dims`
+ is the dimension of the tensor.
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-im2col-mode)
+
+ Im2col_W Mode: This mode is similar to Im2Col mode with the restriction that
+ elements are accessed across the W dimension only. The number of `im2colOffsets`
+ are always two, referred as `wHalo` and `wOffset`.
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-im2col-w-w128-modes)
+
+ Im2col_W_128 Mode: This mode is similar to Im2Col_W mode with the number of
+ elements accessed across the W dimension is always 128 only.
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-im2col-w-w128-modes)
+
+ Tile_Gather4 Mode: This mode is similar to Tile mode but works only on 2D tensor.
+ In gather4 mode, four rows in the source 2D tensor are combined to form a single
+ 2D tensor at the destination. This mode requires five co-ordinates. The first one
+ represents the column-index followed by four row indices.
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-tiled-scatter4-gather4-modes)
+ }];
+
+ let assemblyFormat = "`<` $value `>`";
+}
+
def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">,
Arguments<(ins )> {
let assemblyFormat = "attr-dict";
@@ -2521,23 +2571,16 @@ def NVVM_CpAsyncBulkPrefetchOp : NVVM_Op<"cp.async.bulk.prefetch"> {
def NVVM_CpAsyncBulkTensorPrefetchOp :
NVVM_Op<"cp.async.bulk.tensor.prefetch", [AttrSizedOperandSegments]> {
let arguments = (ins
- LLVM_AnyPointer:$tmaDescriptor,
+ LLVM_PointerGeneric:$tmaDescriptor,
Variadic<I32>:$coordinates,
Variadic<I16>:$im2colOffsets,
+ DefaultValuedAttr<TMALoadModeAttr, "TMALoadMode::TILE">:$mode,
Optional<I64>:$l2CacheHint);
let description = [{
Initiates an asynchronous prefetch operation on the tensor data from global
- memory to L2 cache.
-
- The Op has two modes:
- 1) Tiled Mode: It's the default mode. The source multi-dimensional tensor
- layout is preserved at the destination.
-
- 2) Im2col Mode: This mode is used when `im2colOffsets` operands are present.
- the elements in the Bounding Box of the source tensor are rearranged into
- columns at the destination. In this mode, the tensor has to be at least
- 3-dimensional.
+ memory to L2 cache. This Op supports all the load modes specified in
+ `TMALoadMode`.
The `l2CacheHint` operand is optional, and it is used to specify cache
eviction policy that may be used during the memory access.
@@ -2554,34 +2597,17 @@ def NVVM_CpAsyncBulkTensorPrefetchOp :
}];
let extraClassDeclaration = [{
- static llvm::Intrinsic::ID getIntrinsicID(int tensorDims, bool isIm2Col);
+ static mlir::NVVM::IDArgPair
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase& builder);
}];
let hasVerifier = 1;
string llvmBuilder = [{
- // Arguments to the intrinsic:
- // tmaDesc, tensorDims, im2colOffsets
- // cache_hint(if applicable) and flag(boolean)
- llvm::SmallVector<llvm::Value *> translatedOperands;
- translatedOperands.push_back($tmaDescriptor);
-
- for (auto v : op.getCoordinates())
- translatedOperands.push_back(moduleTranslation.lookupValue(v));
-
- for (auto v : op.getIm2colOffsets())
- translatedOperands.push_back(moduleTranslation.lookupValue(v));
-
- llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
- auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);
-
- bool isCacheHint = op.getL2CacheHint() ? true : false;
- translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Unused);
- translatedOperands.push_back(builder.getInt1(isCacheHint));
-
- auto intId = NVVM::CpAsyncBulkTensorPrefetchOp::getIntrinsicID(
- op.getCoordinates().size(), op.getIm2colOffsets().size() > 0);
- createIntrinsicCall(builder, intId, translatedOperands);
+ auto [id, args] = NVVM::CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
+ *op, moduleTranslation, builder);
+ createIntrinsicCall(builder, id, args);
}];
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 7ad429efc9fad..74dbf42a2df79 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -99,10 +99,41 @@ LogicalResult CpAsyncOp::verify() {
}
LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
- size_t numIm2ColOffsets = getIm2colOffsets().size();
- bool isIm2Col = numIm2ColOffsets > 0;
- return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
- numIm2ColOffsets, getLoc());
+ size_t tensorDims = getCoordinates().size();
+ if (tensorDims < 1 || tensorDims > 5)
+ return emitError("expects coordinates between 1 to 5 dimension");
+
+ auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col,
+ size_t expectedNumOffsets) -> LogicalResult {
+ if (isIm2col && (tensorDims < 3))
+ return emitError()
+ << "to use " << stringifyEnum(mode)
+ << " mode, the tensor has to be at least 3-dimensional";
+
+ if (getIm2colOffsets().size() != expectedNumOffsets)
+ return emitError() << " im2col offsets expected " << expectedNumOffsets
+ << " (provided " << getIm2colOffsets().size() << ")";
+
+ return success();
+ };
+
+ auto mode = getMode();
+ switch (mode) {
+ case TMALoadMode::TILE:
+ return checkTMALoadParams(mode, false, 0);
+ case TMALoadMode::IM2COL:
+ return checkTMALoadParams(mode, true, tensorDims - 2);
+ case TMALoadMode::IM2COL_W:
+ case TMALoadMode::IM2COL_W_128:
+ return checkTMALoadParams(mode, true, 2);
+ case TMALoadMode::TILE_GATHER4:
+ return (tensorDims == 5) ? checkTMALoadParams(mode, false, 0)
+ : emitError("Gather4 mode expects 5 coordinates");
+ default:
+ return emitError("Invalid LoadMode in CpAsyncBulkTensorPrefetchOp.");
+ }
+
+ return success();
}
LogicalResult CpAsyncBulkTensorReduceOp::verify() {
@@ -1399,28 +1430,60 @@ mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
return {id, std::move(args)};
}
-llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
- bool isIm2Col) {
- switch (tensorDims) {
- case 1:
- return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
- case 2:
- return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
- case 3:
- return isIm2Col
- ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
- : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
- case 4:
- return isIm2Col
- ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
- : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
- case 5:
- return isIm2Col
- ? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
- : llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
- default:
- llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
- }
+#define GET_TMA_OPCODE(op, mode, dim) \
+ llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d
+
+mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
+ llvm::SmallVector<llvm::Value *> args;
+
+ // Fill the Intrinsic Args
+ args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
+
+ for (auto v : thisOp.getCoordinates())
+ args.push_back(mt.lookupValue(v));
+ for (auto v : thisOp.getIm2colOffsets())
+ args.push_back(mt.lookupValue(v));
+
+ mlir::Value cacheHint = thisOp.getL2CacheHint();
+ const bool hasCacheHint = static_cast<bool>(cacheHint);
+ llvm::Value *i64Unused =
+ llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
+ args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
+ args.push_back(builder.getInt1(hasCacheHint));
+
+#define NI llvm::Intrinsic::not_intrinsic
+#define TILE(D) GET_TMA_OPCODE(prefetch, tile, D)
+#define IM2COL(D) GET_TMA_OPCODE(prefetch, im2col, D)
+#define IM2COLW(D) GET_TMA_OPCODE(prefetch, im2col_w, D)
+#define IM2COLW128(D) GET_TMA_OPCODE(prefetch, im2col_w_128, D)
+#define GATHER4(D) GET_TMA_OPCODE(prefetch, tile_gather4, D)
+
+ static constexpr llvm::Intrinsic::ID IDTable[][6] = {
+ {NI, TILE(1), TILE(2), TILE(3), TILE(4), TILE(5)}, // tile
+ {NI, NI, NI, IM2COL(3), IM2COL(4), IM2COL(5)}, // im2col
+ {NI, NI, NI, IM2COLW(3), IM2COLW(4), IM2COLW(5)}, // im2col_w
+ {NI, NI, NI, IM2COLW128(3), IM2COLW128(4), IM2COLW128(5)}, // im2col_w128
+ {NI, NI, NI, NI, NI, GATHER4(2)}, // tile_gather4
+ };
+ static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
+ "TMALoadModes must match number of rows in IDTable");
+
+ size_t mode = static_cast<size_t>(thisOp.getMode());
+ size_t dim = thisOp.getCoordinates().size();
+ llvm::Intrinsic::ID id = IDTable[mode][dim];
+ if (id == llvm::Intrinsic::not_intrinsic)
+ llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");
+
+ return {id, std::move(args)};
+
+#undef GATHER4
+#undef IM2COLW128
+#undef IM2COLW
+#undef IM2COL
+#undef TILE
+#undef NI
}
#define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode) \
diff --git a/mlir/test/Target/LLVMIR/nvvm/tma_prefetch.mlir b/mlir/test/Target/LLVMIR/nvvm/tma_prefetch.mlir
index bfd952636ffbe..536b52b034db8 100644
--- a/mlir/test/Target/LLVMIR/nvvm/tma_prefetch.mlir
+++ b/mlir/test/Target/LLVMIR/nvvm/tma_prefetch.mlir
@@ -1,70 +1,123 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
-// CHECK-LABEL: @tma_bulk_prefetch
llvm.func @tma_bulk_prefetch(%src : !llvm.ptr<1>, %size : i32, %ch : i64) {
- // CHECK: call void @llvm.nvvm.cp.async.bulk.prefetch.L2(ptr addrspace(1) %{{.*}}, i32 %{{.*}}, i64 0, i1 false)
- // CHECK: call void @llvm.nvvm.cp.async.bulk.prefetch.L2(ptr addrspace(1) %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+ // CHECK-LABEL: define void @tma_bulk_prefetch(ptr addrspace(1) %0, i32 %1, i64 %2) {
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.prefetch.L2(ptr addrspace(1) %0, i32 %1, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.prefetch.L2(ptr addrspace(1) %0, i32 %1, i64 %2, i1 true)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
nvvm.cp.async.bulk.prefetch %src, %size : !llvm.ptr<1>
nvvm.cp.async.bulk.prefetch %src, %size l2_cache_hint = %ch : !llvm.ptr<1>
llvm.return
}
-// CHECK-LABEL: @tma_prefetch_1d
llvm.func @tma_prefetch_1d(%tma_desc : !llvm.ptr, %d0 : i32, %ch : i64) {
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %{{.*}}, i64 0, i1 false)
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+ // CHECK-LABEL: define void @tma_prefetch_1d(ptr %0, i32 %1, i64 %2) {
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %1, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.1d(ptr %0, i32 %1, i64 %2, i1 true)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0] : !llvm.ptr
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0] l2_cache_hint = %ch : !llvm.ptr
llvm.return
}
-// CHECK-LABEL: @tma_prefetch_2d
llvm.func @tma_prefetch_2d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %ch : i64) {
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i64 0, i1 false)
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
- nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] : !llvm.ptr
+ // CHECK-LABEL: define void @tma_prefetch_2d(ptr %0, i32 %1, i32 %2, i64 %3) {
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %1, i32 %2, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.2d(ptr %0, i32 %1, i32 %2, i64 %3, i1 true)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] {mode = #nvvm.tma_load_mode<tile>} : !llvm.ptr
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] l2_cache_hint = %ch : !llvm.ptr
llvm.return
}
-// CHECK-LABEL: @tma_prefetch_3d
-llvm.func @tma_prefetch_3d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %off0 : i16, %ch : i64) {
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 0, i1 false)
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+llvm.func @tma_prefetch_3d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %off0 : i16, %off1 : i16, %ch : i64) {
+ // CHECK-LABEL: define void @tma_prefetch_3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i16 %5, i64 %6) {
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %1, i32 %2, i32 %3, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.3d(ptr %0, i32 %1, i32 %2, i32 %3, i64 %6, i1 true)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i64 %6, i1 true)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i16 %5, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i16 %5, i64 %6, i1 true)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.128.3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i16 %5, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.128.3d(ptr %0, i32 %1, i32 %2, i32 %3, i16 %4, i16 %5, i64 %6, i1 true)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] : !llvm.ptr
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] l2_cache_hint = %ch : !llvm.ptr
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i64 0, i1 false)
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.3d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i64 %{{.*}}, i1 true)
- nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] : !llvm.ptr
- nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] l2_cache_hint = %ch : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr
+
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0, %off1] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
+
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0, %off1] {mode = #nvvm.tma_load_mode<im2col_w_128>} : !llvm.ptr
+ nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0, %off1] l2_cache_hint = %ch {mode = #nvvm.tma_load_mode<im2col_w_128>} : !llvm.ptr
llvm.return
}
-// CHECK-LABEL: @tma_prefetch_4d
llvm.func @tma_prefetch_4d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %d3 : i32, %off0 : i16, %off1 : i16, %ch : i64) {
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 0, i1 false)
- // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 %{{.*}}, i1 true)
+ // CHECK-LABEL: define void @tma_prefetch_4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 %7) {
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.tile.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i64 %7, i1 true)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 %7, i1 true)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 %7, i1 true)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.128.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 0, i1 false)
+ // CHECK-NEXT: call void @llvm.nvvm.cp.async.bulk.tensor.prefetch.im2col.w.128.4d(ptr %0, i32 %1, i32 %2, i32 %3, i32 %4, i16 %5, i16 %6, i64 %7, i1 true)
+ // CHECK-NEXT: ret void
+ // CHECK-NEXT: }
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] : !llvm.ptr
nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2, %d3] l2_cache_hint = %ch : !llvm.ptr
- // CHECK: call void @llvm.nvvm...
[truncated]
def TMALoadModeTile : I32EnumAttrCase<"TILE", 0, "tile">;
def TMALoadModeIm2Col : I32EnumAttrCase<"IM2COL", 1, "im2col">;
def TMALoadModeIm2ColW : I32EnumAttrCase<"IM2COL_W", 2, "im2col_w">;
def TMALoadModeIm2ColW128 : I32EnumAttrCase<"IM2COL_W_128", 3, "im2col_w_128">;
def TMALoadModeTileGather4 : I32EnumAttrCase<"TILE_GATHER4", 4, "tile_gather4">;
So we are implementing the load mode:
.load_mode = { .tile, .tile::gather4, .im2col, .im2col::w, .im2col::w::128 }
tile:: and im2col:: are implemented as main modes, while gather4 and w are their sub-details. Is it possible to implement this in the NVVM dialect?
It is possible, and that is what we did first internally. But the verifier logic becomes harder to maintain as the modes evolve in their own ways. Hence, we resorted to separate modes (as we did for the TMA Store modes).
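In other words, each PTX mode/sub-mode combination is flattened into a single enum value on the MLIR side; roughly (the PTX column is based on the spellings quoted above):

// NVVM attribute value               PTX load mode
// #nvvm.tma_load_mode<tile>          .tile
// #nvvm.tma_load_mode<tile_gather4>  .tile::gather4
// #nvvm.tma_load_mode<im2col>        .im2col
// #nvvm.tma_load_mode<im2col_w>      .im2col::w
// #nvvm.tma_load_mode<im2col_w_128>  .im2col::w::128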
  auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col,
                                size_t expectedNumOffsets) -> LogicalResult {
    if (isIm2col && (tensorDims < 3))
      return emitError()
             << "to use " << stringifyEnum(mode)
             << " mode, the tensor has to be at least 3-dimensional";

    if (getIm2colOffsets().size() != expectedNumOffsets)
      return emitError() << " im2col offsets expected " << expectedNumOffsets
                         << " (provided " << getIm2colOffsets().size() << ")";

    return success();
  };

  auto mode = getMode();
  switch (mode) {
  case TMALoadMode::TILE:
    return checkTMALoadParams(mode, false, 0);
  case TMALoadMode::IM2COL:
    return checkTMALoadParams(mode, true, tensorDims - 2);
  case TMALoadMode::IM2COL_W:
  case TMALoadMode::IM2COL_W_128:
    return checkTMALoadParams(mode, true, 2);
  case TMALoadMode::TILE_GATHER4:
    return (tensorDims == 5) ? checkTMALoadParams(mode, false, 0)
                             : emitError("Gather4 mode expects 5 coordinates");
  default:
    return emitError("Invalid LoadMode in CpAsyncBulkTensorPrefetchOp.");
  }
Is the verifier specific to this op? Can it be used more generally?
The intention is to make it common to both the TMA Load and TMA Prefetch Ops.
For now, I have kept it here; I plan to factor it out with my TMA Load Op changes, since there will be more than one use of this function then.
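As a sketch of what that verifier catches today, invalid IR in the style of the new invalid-tests file might look like this (function names are placeholders; the exact diagnostic wording follows the verifier code above):

llvm.func @tma_prefetch_gather4_2d(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32) {
  // Gather4 requires exactly five coordinates (one column index plus four row indices).
  // expected-error @below {{Gather4 mode expects 5 coordinates}}
  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1] {mode = #nvvm.tma_load_mode<tile_gather4>} : !llvm.ptr
  llvm.return
}

llvm.func @tma_prefetch_im2col_w_missing_offset(%tma_desc : !llvm.ptr, %d0 : i32, %d1 : i32, %d2 : i32, %off0 : i16) {
  // im2col_w takes exactly two offsets (wHalo and wOffset); only one is given here.
  // expected-error @below {{im2col offsets expected 2 (provided 1)}}
  nvvm.cp.async.bulk.tensor.prefetch %tma_desc, box[%d0, %d1, %d2] im2col[%off0] {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr
  llvm.return
}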
#define GET_TMA_OPCODE(op, mode, dim)                                          \
  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d

mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
  auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
  llvm::SmallVector<llvm::Value *> args;

  // Fill the Intrinsic Args
  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));

  for (auto v : thisOp.getCoordinates())
    args.push_back(mt.lookupValue(v));
  for (auto v : thisOp.getIm2colOffsets())
    args.push_back(mt.lookupValue(v));

  mlir::Value cacheHint = thisOp.getL2CacheHint();
  const bool hasCacheHint = static_cast<bool>(cacheHint);
  llvm::Value *i64Unused =
      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
  args.push_back(builder.getInt1(hasCacheHint));

#define NI llvm::Intrinsic::not_intrinsic
#define TILE(D) GET_TMA_OPCODE(prefetch, tile, D)
#define IM2COL(D) GET_TMA_OPCODE(prefetch, im2col, D)
#define IM2COLW(D) GET_TMA_OPCODE(prefetch, im2col_w, D)
#define IM2COLW128(D) GET_TMA_OPCODE(prefetch, im2col_w_128, D)
#define GATHER4(D) GET_TMA_OPCODE(prefetch, tile_gather4, D)

  static constexpr llvm::Intrinsic::ID IDTable[][6] = {
      {NI, TILE(1), TILE(2), TILE(3), TILE(4), TILE(5)},         // tile
      {NI, NI, NI, IM2COL(3), IM2COL(4), IM2COL(5)},             // im2col
      {NI, NI, NI, IM2COLW(3), IM2COLW(4), IM2COLW(5)},          // im2col_w
      {NI, NI, NI, IM2COLW128(3), IM2COLW128(4), IM2COLW128(5)}, // im2col_w128
      {NI, NI, NI, NI, NI, GATHER4(2)},                          // tile_gather4
  };
  static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
                "TMALoadModes must match number of rows in IDTable");

  size_t mode = static_cast<size_t>(thisOp.getMode());
  size_t dim = thisOp.getCoordinates().size();
  llvm::Intrinsic::ID id = IDTable[mode][dim];
  if (id == llvm::Intrinsic::not_intrinsic)
    llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");

  return {id, std::move(args)};

#undef GATHER4
#undef IM2COLW128
#undef IM2COLW
#undef IM2COL
#undef TILE
#undef NI
Let's not use macros. It's quite hard to read code with them, and they are error-prone.
We can create constexpr tables and use them; they are zero-cost as well. I drafted an example below:
constexpr llvm::Intrinsic::ID NI = llvm::Intrinsic::not_intrinsic;
static constexpr llvm::Intrinsic::ID IDTable[][6] = {
// tile
{
NI,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d,
},
// im2col
{
NI, NI, NI,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d,
},
// im2col_w
{
NI, NI, NI,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d,
},
// im2col_w_128
{
NI, NI, NI,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d,
},
// tile_gather4
{
NI, NI, NI, NI, NI,
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d,
},
};
static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
"TMALoadModes must match number of rows in IDTable");
const size_t mode = static_cast<size_t>(thisOp.getMode());
const size_t dim = thisOp.getCoordinates().size();
if (mode >= std::size(IDTable) || dim >= std::size(IDTable[0]))
llvm_unreachable("Mode or dimension out of range for CpAsyncBulkTensorPrefetchOp.");
llvm::Intrinsic::ID id = IDTable[mode][dim];
if (id == llvm::Intrinsic::not_intrinsic)
llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");