[MLIR][NVVM] Add Op for TMA Store with reduction (#118853)
PR #116854 adds intrinsics for TMA Store with reduction. This patch adds an NVVM Dialect Op for the same.

* Lit tests are added to verify the lowering to LLVM intrinsics as well as the invalid cases.
* The common verifier method is updated to handle im2col modes without offsets. This helps Ops like TMA Store, TMA StoreReduce etc.
* Since the nvvmir.mlir test file is already large, this patch adds the tests for this Op in a new file under a separate "nvvm/" directory: mlir/test/Target/LLVMIR/nvvm/tma_store_reduce.mlir

PTX Spec reference:
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-reduce-async-bulk-tensor

Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
1 parent 22780f8 commit b4b819c

File tree: 4 files changed, +503 -9 lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 101 additions & 0 deletions
@@ -2029,6 +2029,107 @@ def NVVM_CpAsyncBulkTensorPrefetchOp :
   }];
 }
 
+// List of modes supported for TMA Store and Reduction Ops
+def TMAStoreModeTile : I32EnumAttrCase<"TILE", 0, "tile">;
+def TMAStoreModeIm2Col : I32EnumAttrCase<"IM2COL", 1, "im2col">;
+
+def TMAStoreMode : I32EnumAttr<"TMAStoreMode", "NVVM TMA Store Mode",
+    [TMAStoreModeTile, TMAStoreModeIm2Col]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def TMAStoreModeAttr : EnumAttr<NVVM_Dialect, TMAStoreMode, "tma_store_mode"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+// List of Reduction Ops supported with TMA Store
+def TMAReduxKindAdd : I32EnumAttrCase<"ADD", 0, "add">;
+def TMAReduxKindMin : I32EnumAttrCase<"MIN", 1, "min">;
+def TMAReduxKindMax : I32EnumAttrCase<"MAX", 2, "max">;
+def TMAReduxKindInc : I32EnumAttrCase<"INC", 3, "inc">;
+def TMAReduxKindDec : I32EnumAttrCase<"DEC", 4, "dec">;
+def TMAReduxKindAnd : I32EnumAttrCase<"AND", 5, "and">;
+def TMAReduxKindOr : I32EnumAttrCase<"OR", 6, "or">;
+def TMAReduxKindXor : I32EnumAttrCase<"XOR", 7, "xor">;
+
+def TMAReduxKind : I32EnumAttr<"TMAReduxKind", "NVVM TMA redux kind",
+    [TMAReduxKindAdd, TMAReduxKindMax, TMAReduxKindMin,
+     TMAReduxKindInc, TMAReduxKindDec, TMAReduxKindAnd,
+     TMAReduxKindOr, TMAReduxKindXor]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def TMAReduxKindAttr : EnumAttr<NVVM_Dialect, TMAReduxKind, "tma_redux_kind"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def NVVM_CpAsyncBulkTensorReduceOp :
+  NVVM_Op<"cp.async.bulk.tensor.reduce", [AttrSizedOperandSegments]> {
+  let arguments = (ins
+    LLVM_AnyPointer:$tmaDescriptor,
+    LLVM_PointerShared:$srcMem,
+    TMAReduxKindAttr:$redKind,
+    DefaultValuedAttr<TMAStoreModeAttr, "TMAStoreMode::TILE">:$mode,
+    Variadic<I32>:$coordinates,
+    Optional<I64>:$l2CacheHint);
+
+  let description = [{
+    Initiates an asynchronous reduction operation of tensor data in
+    global memory with tensor data in shared memory.
+
+    The `mode` attribute indicates whether the copy mode is tile or im2col.
+    The `redKind` attribute specifies the reduction operation applied.
+    The supported reduction operations are:
+    {add, min, max, inc, dec, and, or, xor}
+
+    The `l2CacheHint` operand is optional, and it is used to specify cache
+    eviction policy that may be used during the memory access.
+
+    [For more information, see PTX ISA]
+    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-reduce-async-bulk-tensor)
+  }];
+
+  let assemblyFormat = [{
+    $tmaDescriptor `,`
+    $srcMem `,`
+    `box` `[`$coordinates `]`
+    (`l2_cache_hint` `=` $l2CacheHint^ )?
+    attr-dict `:` type($tmaDescriptor) `,` type($srcMem)
+  }];
+
+  let extraClassDeclaration = [{
+    static llvm::Intrinsic::ID getIntrinsicID(int tensorDims,
+                                              NVVM::TMAReduxKind kind,
+                                              bool isIm2Col);
+  }];
+
+  let hasVerifier = 1;
+
+  string llvmBuilder = [{
+    // Arguments to the intrinsic:
+    // shared_mem_ptr, tmaDesc, tensorDims,
+    // cache_hint (if applicable) and flag (boolean)
+    llvm::SmallVector<llvm::Value *> translatedOperands;
+    translatedOperands.push_back($srcMem);
+    translatedOperands.push_back($tmaDescriptor);
+
+    for (auto v : op.getCoordinates())
+      translatedOperands.push_back(moduleTranslation.lookupValue(v));
+
+    llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
+    auto *i64Undef = llvm::UndefValue::get(llvm::IntegerType::get(ctx, 64));
+
+    bool isCacheHint = op.getL2CacheHint() ? true : false;
+    translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Undef);
+    translatedOperands.push_back(builder.getInt1(isCacheHint));
+
+    auto intId = NVVM::CpAsyncBulkTensorReduceOp::getIntrinsicID(
+        op.getCoordinates().size(), $redKind,
+        (op.getMode() == NVVM::TMAStoreMode::IM2COL));
+    createIntrinsicCall(builder, intId, translatedOperands);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM Wgmma Ops
 //===----------------------------------------------------------------------===//
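Based on the declared arguments and `assemblyFormat` above, the new op should read roughly as follows in NVVM dialect IR. This is a hedged sketch (SSA value names are made up; the actual syntax exercised by the patch lives in its lit tests), shown only to make the custom format concrete:

// Sketch: a 2-D tile-mode reduce-add through TMA descriptor %desc, sourcing
// data from shared memory %src, with an optional L2 cache hint.
llvm.func @tma_store_reduce_add_2d(%desc : !llvm.ptr, %src : !llvm.ptr<3>,
                                   %d0 : i32, %d1 : i32, %ch : i64) {
  // `mode` defaults to tile, so only the reduction kind is spelled out here.
  nvvm.cp.async.bulk.tensor.reduce %desc, %src, box [%d0, %d1]
      l2_cache_hint = %ch {redKind = #nvvm.tma_redux_kind<add>}
      : !llvm.ptr, !llvm.ptr<3>
  llvm.return
}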

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 73 additions & 9 deletions
@@ -75,30 +75,37 @@ ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
 
 void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
 
-// This verifier is shared across:
-// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load) and
-// CpAsyncBulkTensorPrefetchOp (TMA Prefetch) Ops.
+// This verifier is shared among the following Ops:
+// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
+// CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
+// CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
 static LogicalResult CpAsyncBulkTensorCommonVerifier(size_t tensorDims,
+                                                     bool isIm2Col,
                                                      size_t numIm2ColOffsets,
                                                      Location loc) {
   if (tensorDims < 1 || tensorDims > 5)
     return emitError(loc, "expects coordinates between 1 to 5 dimension");
 
-  if (numIm2ColOffsets) {
+  // For Im2Col mode, there are two constraints:
+  if (isIm2Col) {
+    // 1. Tensor must always be at least 3-d.
     if (tensorDims < 3)
       return emitError(
           loc,
           "to use im2col mode, the tensor has to be at least 3-dimensional");
-    if (tensorDims != (numIm2ColOffsets + 2))
+    // 2. When there are Im2ColOffsets, they must be (Dims - 2) in number.
+    if (numIm2ColOffsets && (tensorDims != (numIm2ColOffsets + 2)))
       return emitError(
           loc, "im2col offsets must be 2 less than number of coordinates");
   }
   return success();
 }
 
 LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
-  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(),
-                                         getIm2colOffsets().size(), getLoc());
+  size_t numIm2ColOffsets = getIm2colOffsets().size();
+  bool isIm2Col = numIm2ColOffsets > 0;
+  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
+                                         numIm2ColOffsets, getLoc());
 }
 
 LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
@@ -119,8 +126,16 @@ LogicalResult CpAsyncOp::verify() {
 }
 
 LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
-  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(),
-                                         getIm2colOffsets().size(), getLoc());
+  size_t numIm2ColOffsets = getIm2colOffsets().size();
+  bool isIm2Col = numIm2ColOffsets > 0;
+  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
+                                         numIm2ColOffsets, getLoc());
+}
+
+LogicalResult CpAsyncBulkTensorReduceOp::verify() {
+  bool isIm2Col = (getMode() == TMAStoreMode::IM2COL);
+  return CpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col, 0,
+                                         getLoc());
 }
 
 // Given the element type of an operand and whether or not it is an accumulator,
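To illustrate the relaxed verifier: the reduce op carries no im2col offsets, so the offset-count check is skipped for it, but the 3-D minimum for im2col mode still applies. A hedged, invalid-case-style sketch (not copied from the patch's tests; names are illustrative):

// Rejected by CpAsyncBulkTensorReduceOp::verify(): im2col mode with only two
// coordinates trips the shared check
// "to use im2col mode, the tensor has to be at least 3-dimensional".
nvvm.cp.async.bulk.tensor.reduce %desc, %src, box [%d0, %d1]
    {redKind = #nvvm.tma_redux_kind<and>, mode = #nvvm.tma_store_mode<im2col>}
    : !llvm.ptr, !llvm.ptr<3>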
@@ -1094,6 +1109,55 @@ llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
   }
 }
 
+#define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode)                       \
+  llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d
+
+#define CP_ASYNC_BULK_TENSOR_REDUCE(op, dim, is_im2col)                       \
+  is_im2col ? CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, im2col)               \
+            : CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, tile)
+
+#define GET_CP_ASYNC_BULK_TENSOR_ID(op, dims, is_im2col)                      \
+  [&]() -> auto {                                                             \
+    switch (dims) {                                                           \
+    case 1:                                                                   \
+      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 1, tile);                   \
+    case 2:                                                                   \
+      return CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, 2, tile);                   \
+    case 3:                                                                   \
+      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 3, is_im2col);                   \
+    case 4:                                                                   \
+      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 4, is_im2col);                   \
+    case 5:                                                                   \
+      return CP_ASYNC_BULK_TENSOR_REDUCE(op, 5, is_im2col);                   \
+    default:                                                                  \
+      llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorReduceOp.");    \
+    }                                                                         \
+  }()
+
+llvm::Intrinsic::ID CpAsyncBulkTensorReduceOp::getIntrinsicID(
+    int tensorDims, NVVM::TMAReduxKind kind, bool isIm2Col) {
+  using RedTy = NVVM::TMAReduxKind;
+  switch (kind) {
+  case RedTy::ADD:
+    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_add, tensorDims, isIm2Col);
+  case RedTy::MIN:
+    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_min, tensorDims, isIm2Col);
+  case RedTy::MAX:
+    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_max, tensorDims, isIm2Col);
+  case RedTy::INC:
+    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_inc, tensorDims, isIm2Col);
+  case RedTy::DEC:
+    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_dec, tensorDims, isIm2Col);
+  case RedTy::AND:
+    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_and, tensorDims, isIm2Col);
+  case RedTy::OR:
+    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_or, tensorDims, isIm2Col);
+  case RedTy::XOR:
+    return GET_CP_ASYNC_BULK_TENSOR_ID(reduce_xor, tensorDims, isIm2Col);
+  }
+  llvm_unreachable("Invalid Reduction Op for CpAsyncBulkTensorReduceOp");
+}
+
 //===----------------------------------------------------------------------===//
 // NVVMDialect initialization, type parsing, and registration.
 //===----------------------------------------------------------------------===//
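Putting the macros together with the llvmBuilder from the .td change: the intrinsic name is assembled as nvvm_cp_async_bulk_tensor_<redKind>_<mode>_<dim>d, and the call receives the shared-memory pointer, the TMA descriptor, the coordinates, then the cache hint (or an i64 undef) and an i1 flag indicating whether the hint is valid. Below is a hedged, lit-test-style sketch of what a 3-D im2col reduce-add might lower to; it is not copied from the patch's tests, and the exact CHECK line there may differ:

// CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.reduce.add.im2col.3d(ptr addrspace(3) %{{.*}}, ptr %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i64 undef, i1 false)
nvvm.cp.async.bulk.tensor.reduce %desc, %src, box [%d0, %d1, %d2]
    {redKind = #nvvm.tma_redux_kind<add>, mode = #nvvm.tma_store_mode<im2col>}
    : !llvm.ptr, !llvm.ptr<3>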
