-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[MLIR][NVVM] Update TMA tensor prefetch Op #153464
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -99,10 +99,41 @@ LogicalResult CpAsyncOp::verify() { | |
} | ||
|
||
LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
  // A TMA tensor prefetch operates on a 1D..5D tensor slice; the number of
  // coordinates supplied determines the dimensionality.
  size_t tensorDims = getCoordinates().size();
  if (tensorDims < 1 || tensorDims > 5)
    return emitError("expects coordinates between 1 to 5 dimension");

  // Shared checks across load modes: every im2col-family mode requires at
  // least a 3D tensor, and each mode prescribes an exact count of im2col
  // offsets (0 for the tile-family modes).
  auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col,
                                size_t expectedNumOffsets) -> LogicalResult {
    if (isIm2col && (tensorDims < 3))
      return emitError()
             << "to use " << stringifyEnum(mode)
             << " mode, the tensor has to be at least 3-dimensional";

    if (getIm2colOffsets().size() != expectedNumOffsets)
      return emitError() << " im2col offsets expected " << expectedNumOffsets
                         << " (provided " << getIm2colOffsets().size() << ")";

    return success();
  };

  auto mode = getMode();
  switch (mode) {
  case TMALoadMode::TILE:
    return checkTMALoadParams(mode, false, 0);
  case TMALoadMode::IM2COL:
    // im2col uses one offset per spatial dimension beyond the first two.
    return checkTMALoadParams(mode, true, tensorDims - 2);
  case TMALoadMode::IM2COL_W:
  case TMALoadMode::IM2COL_W_128:
    return checkTMALoadParams(mode, true, 2);
  case TMALoadMode::TILE_GATHER4:
    // gather4 addresses four rows of a 2D tensor, encoded as 5 coordinates.
    return (tensorDims == 5) ? checkTMALoadParams(mode, false, 0)
                             : emitError("Gather4 mode expects 5 coordinates");
  default:
    return emitError("Invalid LoadMode in CpAsyncBulkTensorPrefetchOp.");
  }
  // Note: every switch arm returns, so no fall-through success path is
  // needed here (the previous trailing `return success();` was unreachable).
}
|
||
LogicalResult CpAsyncBulkTensorReduceOp::verify() { | ||
|
@@ -1399,28 +1430,60 @@ mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs( | |
return {id, std::move(args)}; | ||
} | ||
|
||
llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims, | ||
bool isIm2Col) { | ||
switch (tensorDims) { | ||
case 1: | ||
return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d; | ||
case 2: | ||
return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d; | ||
case 3: | ||
return isIm2Col | ||
? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d | ||
: llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d; | ||
case 4: | ||
return isIm2Col | ||
? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d | ||
: llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d; | ||
case 5: | ||
return isIm2Col | ||
? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d | ||
: llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d; | ||
default: | ||
llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp."); | ||
} | ||
#define GET_TMA_OPCODE(op, mode, dim) \ | ||
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d | ||
|
||
mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs( | ||
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) { | ||
auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op); | ||
llvm::SmallVector<llvm::Value *> args; | ||
|
||
// Fill the Intrinsic Args | ||
args.push_back(mt.lookupValue(thisOp.getTmaDescriptor())); | ||
|
||
for (auto v : thisOp.getCoordinates()) | ||
args.push_back(mt.lookupValue(v)); | ||
for (auto v : thisOp.getIm2colOffsets()) | ||
args.push_back(mt.lookupValue(v)); | ||
|
||
mlir::Value cacheHint = thisOp.getL2CacheHint(); | ||
const bool hasCacheHint = static_cast<bool>(cacheHint); | ||
llvm::Value *i64Unused = | ||
llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0); | ||
args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused); | ||
args.push_back(builder.getInt1(hasCacheHint)); | ||
|
||
#define NI llvm::Intrinsic::not_intrinsic | ||
#define TILE(D) GET_TMA_OPCODE(prefetch, tile, D) | ||
#define IM2COL(D) GET_TMA_OPCODE(prefetch, im2col, D) | ||
#define IM2COLW(D) GET_TMA_OPCODE(prefetch, im2col_w, D) | ||
#define IM2COLW128(D) GET_TMA_OPCODE(prefetch, im2col_w_128, D) | ||
#define GATHER4(D) GET_TMA_OPCODE(prefetch, tile_gather4, D) | ||
|
||
static constexpr llvm::Intrinsic::ID IDTable[][6] = { | ||
{NI, TILE(1), TILE(2), TILE(3), TILE(4), TILE(5)}, // tile | ||
{NI, NI, NI, IM2COL(3), IM2COL(4), IM2COL(5)}, // im2col | ||
{NI, NI, NI, IM2COLW(3), IM2COLW(4), IM2COLW(5)}, // im2col_w | ||
{NI, NI, NI, IM2COLW128(3), IM2COLW128(4), IM2COLW128(5)}, // im2col_w128 | ||
{NI, NI, NI, NI, NI, GATHER4(2)}, // tile_gather4 | ||
}; | ||
static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1, | ||
"TMALoadModes must match number of rows in IDTable"); | ||
|
||
size_t mode = static_cast<size_t>(thisOp.getMode()); | ||
size_t dim = thisOp.getCoordinates().size(); | ||
llvm::Intrinsic::ID id = IDTable[mode][dim]; | ||
if (id == llvm::Intrinsic::not_intrinsic) | ||
llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp."); | ||
|
||
return {id, std::move(args)}; | ||
|
||
#undef GATHER4 | ||
#undef IM2COLW128 | ||
#undef IM2COLW | ||
#undef IM2COL | ||
#undef TILE | ||
#undef NI | ||
Comment on lines
+1433
to
+1486
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's not use MACROs. It's quite hard to read code with them, and they are error-prones. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can create constexpr tables and use them, they are zero-cost as well. I drafted an example below:
|
||
} | ||
|
||
#define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode) \ | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So we are implementing the load modes as follows:
`tile` and `im2col` are implemented as main modes, while `gather4` and `w` are their sub-details.
Is it possible to implement this distinction in the NVVM dialect itself?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Possible and that is what we did first internally. But the verifier logic is becoming harder to maintain as the modes evolve in their own ways. Hence, we resort to separate modes (like we have done for the TMAStore modes).