[MLIR][NVVM] Update TMA tensor prefetch Op #153464

Open: wants to merge 1 commit into base: main

94 changes: 60 additions & 34 deletions in mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

@@ -2253,6 +2253,56 @@ def NVVM_MmaOp : NVVM_Op<"mma.sync", [AttrSizedOperandSegments]> {
// NVVM TMA Ops
//===----------------------------------------------------------------------===//

// List of modes supported for TMA Load and Prefetch Ops
def TMALoadModeTile : I32EnumAttrCase<"TILE", 0, "tile">;
def TMALoadModeIm2Col : I32EnumAttrCase<"IM2COL", 1, "im2col">;
def TMALoadModeIm2ColW : I32EnumAttrCase<"IM2COL_W", 2, "im2col_w">;
def TMALoadModeIm2ColW128 : I32EnumAttrCase<"IM2COL_W_128", 3, "im2col_w_128">;
def TMALoadModeTileGather4 : I32EnumAttrCase<"TILE_GATHER4", 4, "tile_gather4">;
Comment on lines +2257 to +2261
Member:

So we are implementing the load modes:

.load_mode = { .tile, .tile::gather4, .im2col, .im2col::w, .im2col::w::128 }

Here, tile:: and im2col:: are the main modes, while gather4 and w are their sub-details.

Is it possible to model this structure in the NVVM dialect?

Contributor Author:

That is possible, and it is what we did first internally. But the verifier logic becomes harder to maintain as the modes evolve in their own ways, so we resorted to separate modes (as we have done for the TMAStore modes).
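
For illustration only, the two encodings under discussion could be sketched roughly as follows in C++ terms; the enum and case names here are hypothetical, and the patch follows the second, flat form (mirroring the TMAStore modes):

  // Variant 1: a main mode plus a sub-detail, as suggested above. A verifier
  // must then reject illegal pairs such as {TILE, W} or {IM2COL, GATHER4},
  // and the set of legal pairs grows as the modes evolve.
  enum class TMAMainMode { TILE, IM2COL };
  enum class TMASubMode { NONE, GATHER4, W, W_128 };

  // Variant 2: one flat enumerator per legal combination (the approach taken
  // in this patch), so each mode carries its own verification rules.
  enum class TMALoadMode { TILE, IM2COL, IM2COL_W, IM2COL_W_128, TILE_GATHER4 };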


def TMALoadMode : I32EnumAttr<"TMALoadMode", "NVVM TMA Load Mode",
[TMALoadModeTile, TMALoadModeIm2Col,
TMALoadModeIm2ColW, TMALoadModeIm2ColW128,
TMALoadModeTileGather4]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::NVVM";
}
def TMALoadModeAttr : EnumAttr<NVVM_Dialect, TMALoadMode, "tma_load_mode"> {
let summary = "List of Load-Modes supported for TMA Tensor Ops";
let description = [{
TMA Tensor Ops support the following modes, when copying data from
global memory to shared memory (i.e. load):

Tile Mode: It's the default mode. The source multi-dimensional tensor
layout is preserved at the destination.
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-tiled-mode)

Im2col Mode: This mode is used when `im2colOffsets` operands are present.
The elements in the Bounding Box of the source tensor are rearranged into
columns at the destination. In this mode, the tensor has to be at least
3-dimensional. The number of `im2colOffsets` is `dims - 2` where `dims`
is the dimension of the tensor.
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-im2col-mode)

Im2col_W Mode: This mode is similar to Im2col mode, with the restriction that
elements are accessed across the W dimension only. The number of `im2colOffsets`
is always two, referred to as `wHalo` and `wOffset`.
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-im2col-w-w128-modes)

Im2col_W_128 Mode: This mode is similar to Im2col_W mode, except that the number
of elements accessed across the W dimension is always 128.
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-im2col-w-w128-modes)

Tile_Gather4 Mode: This mode is similar to Tile mode but works only on 2D tensors.
In gather4 mode, four rows of the source 2D tensor are combined to form a single
2D tensor at the destination. This mode requires five coordinates: the first one
represents the column index, followed by four row indices.
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensor-tiled-scatter4-gather4-modes)
}];

let assemblyFormat = "`<` $value `>`";
}
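
As a compact summary of the mode rules documented above, the expected number of `im2colOffsets` per mode can be written as a small standalone helper. This is a hypothetical sketch, not part of the patch; the enum is re-declared locally so the snippet is self-contained, and `std::nullopt` marks mode/rank combinations the verifier rejects:

  #include <cstddef>
  #include <optional>

  enum class TMALoadMode { TILE, IM2COL, IM2COL_W, IM2COL_W_128, TILE_GATHER4 };

  // Expected number of im2colOffsets for `mode`, given `dims` tensor
  // coordinates (1-5). Returns std::nullopt for invalid combinations.
  std::optional<std::size_t> expectedIm2colOffsets(TMALoadMode mode,
                                                   std::size_t dims) {
    switch (mode) {
    case TMALoadMode::TILE:
      return 0; // Tile mode takes no im2col offsets.
    case TMALoadMode::IM2COL:
      // Needs at least a 3-D tensor; one offset per dimension beyond two.
      return dims >= 3 ? std::optional<std::size_t>(dims - 2) : std::nullopt;
    case TMALoadMode::IM2COL_W:
    case TMALoadMode::IM2COL_W_128:
      // Always two offsets (wHalo and wOffset), on 3-D or larger tensors.
      return dims >= 3 ? std::optional<std::size_t>(2) : std::nullopt;
    case TMALoadMode::TILE_GATHER4:
      // 2-D tensors only, expressed as five coordinates (1 column + 4 rows).
      return dims == 5 ? std::optional<std::size_t>(0) : std::nullopt;
    }
    return std::nullopt;
  }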

def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">,
Arguments<(ins )> {
let assemblyFormat = "attr-dict";
@@ -2521,23 +2571,16 @@ def NVVM_CpAsyncBulkPrefetchOp : NVVM_Op<"cp.async.bulk.prefetch"> {
def NVVM_CpAsyncBulkTensorPrefetchOp :
NVVM_Op<"cp.async.bulk.tensor.prefetch", [AttrSizedOperandSegments]> {
let arguments = (ins
LLVM_AnyPointer:$tmaDescriptor,
LLVM_PointerGeneric:$tmaDescriptor,
Variadic<I32>:$coordinates,
Variadic<I16>:$im2colOffsets,
DefaultValuedAttr<TMALoadModeAttr, "TMALoadMode::TILE">:$mode,
Optional<I64>:$l2CacheHint);

let description = [{
Initiates an asynchronous prefetch operation on the tensor data from global
memory to L2 cache.

The Op has two modes:
1) Tiled Mode: It's the default mode. The source multi-dimensional tensor
layout is preserved at the destination.

2) Im2col Mode: This mode is used when `im2colOffsets` operands are present.
the elements in the Bounding Box of the source tensor are rearranged into
columns at the destination. In this mode, the tensor has to be at least
3-dimensional.
memory to L2 cache. This Op supports all the load modes specified in
`TMALoadMode`.

The `l2CacheHint` operand is optional, and it is used to specify the cache
eviction policy that may be used during the memory access.
@@ -2554,34 +2597,17 @@ def NVVM_CpAsyncBulkTensorPrefetchOp :
}];

let extraClassDeclaration = [{
static llvm::Intrinsic::ID getIntrinsicID(int tensorDims, bool isIm2Col);
static mlir::NVVM::IDArgPair
getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
llvm::IRBuilderBase& builder);
}];

let hasVerifier = 1;

string llvmBuilder = [{
// Arguments to the intrinsic:
// tmaDesc, tensorDims, im2colOffsets
// cache_hint(if applicable) and flag(boolean)
llvm::SmallVector<llvm::Value *> translatedOperands;
translatedOperands.push_back($tmaDescriptor);

for (auto v : op.getCoordinates())
translatedOperands.push_back(moduleTranslation.lookupValue(v));

for (auto v : op.getIm2colOffsets())
translatedOperands.push_back(moduleTranslation.lookupValue(v));

llvm::LLVMContext &ctx = moduleTranslation.getLLVMContext();
auto *i64Unused = llvm::ConstantInt::get(llvm::Type::getInt64Ty(ctx), 0);

bool isCacheHint = op.getL2CacheHint() ? true : false;
translatedOperands.push_back(isCacheHint ? $l2CacheHint : i64Unused);
translatedOperands.push_back(builder.getInt1(isCacheHint));

auto intId = NVVM::CpAsyncBulkTensorPrefetchOp::getIntrinsicID(
op.getCoordinates().size(), op.getIm2colOffsets().size() > 0);
createIntrinsicCall(builder, intId, translatedOperands);
auto [id, args] = NVVM::CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
*op, moduleTranslation, builder);
createIntrinsicCall(builder, id, args);
}];
}

115 changes: 89 additions & 26 deletions in mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

@@ -99,10 +99,41 @@ LogicalResult CpAsyncOp::verify() {
}

LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
size_t numIm2ColOffsets = getIm2colOffsets().size();
bool isIm2Col = numIm2ColOffsets > 0;
return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
numIm2ColOffsets, getLoc());
size_t tensorDims = getCoordinates().size();
if (tensorDims < 1 || tensorDims > 5)
return emitError("expects coordinates between 1 to 5 dimension");

auto checkTMALoadParams = [&](TMALoadMode mode, bool isIm2col,
size_t expectedNumOffsets) -> LogicalResult {
if (isIm2col && (tensorDims < 3))
return emitError()
<< "to use " << stringifyEnum(mode)
<< " mode, the tensor has to be at least 3-dimensional";

if (getIm2colOffsets().size() != expectedNumOffsets)
return emitError() << " im2col offsets expected " << expectedNumOffsets
<< " (provided " << getIm2colOffsets().size() << ")";

return success();
};

auto mode = getMode();
switch (mode) {
case TMALoadMode::TILE:
return checkTMALoadParams(mode, false, 0);
case TMALoadMode::IM2COL:
return checkTMALoadParams(mode, true, tensorDims - 2);
case TMALoadMode::IM2COL_W:
case TMALoadMode::IM2COL_W_128:
return checkTMALoadParams(mode, true, 2);
case TMALoadMode::TILE_GATHER4:
return (tensorDims == 5) ? checkTMALoadParams(mode, false, 0)
: emitError("Gather4 mode expects 5 coordinates");
default:
return emitError("Invalid LoadMode in CpAsyncBulkTensorPrefetchOp.");
}
Comment on lines +106 to +134
Member:

Is the verifier specific to this op? Can it be used more generally?

Contributor Author:

The intention is to make it common to both the TMA Load and TMA Prefetch Ops.
For now, I have kept it here; I plan to make it common with my TMA Load Op changes, since there will be more than one use of this function then.
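
A minimal sketch of what that shared helper might look like once factored out; the name `verifyTMALoadParams` and the exact signature are assumptions, and the snippet presumes the same context as NVVMDialect.cpp (MLIR diagnostics headers and the NVVM namespace):

  // Hypothetical shared verifier for TMA load-style ops (Load and Prefetch).
  static LogicalResult verifyTMALoadParams(size_t tensorDims,
                                           size_t numIm2colOffsets,
                                           TMALoadMode mode, Location loc) {
    if (tensorDims < 1 || tensorDims > 5)
      return emitError(loc, "expects coordinates between 1 and 5 dimensions");

    auto check = [&](bool isIm2col, size_t expectedOffsets) -> LogicalResult {
      if (isIm2col && tensorDims < 3)
        return emitError(loc) << "to use " << stringifyEnum(mode)
                              << " mode, the tensor has to be at least "
                                 "3-dimensional";
      if (numIm2colOffsets != expectedOffsets)
        return emitError(loc) << "im2col offsets expected " << expectedOffsets
                              << " (provided " << numIm2colOffsets << ")";
      return success();
    };

    switch (mode) {
    case TMALoadMode::TILE:
      return check(/*isIm2col=*/false, 0);
    case TMALoadMode::IM2COL:
      return check(/*isIm2col=*/true, tensorDims - 2);
    case TMALoadMode::IM2COL_W:
    case TMALoadMode::IM2COL_W_128:
      return check(/*isIm2col=*/true, 2);
    case TMALoadMode::TILE_GATHER4:
      return tensorDims == 5
                 ? check(/*isIm2col=*/false, 0)
                 : emitError(loc, "Gather4 mode expects 5 coordinates");
    }
    return success();
  }

Both the verifier above and the upcoming TMA Load verifier could then reduce to a call into such a helper.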


return success();
}

LogicalResult CpAsyncBulkTensorReduceOp::verify() {
@@ -1399,28 +1430,60 @@ mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
return {id, std::move(args)};
}

llvm::Intrinsic::ID CpAsyncBulkTensorPrefetchOp::getIntrinsicID(int tensorDims,
bool isIm2Col) {
switch (tensorDims) {
case 1:
return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d;
case 2:
return llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d;
case 3:
return isIm2Col
? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d
: llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d;
case 4:
return isIm2Col
? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d
: llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d;
case 5:
return isIm2Col
? llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d
: llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d;
default:
llvm_unreachable("Invalid TensorDim in CpAsyncBulkTensorPrefetchOp.");
}
#define GET_TMA_OPCODE(op, mode, dim) \
llvm::Intrinsic::nvvm_cp_async_bulk_tensor_##op##_##mode##_##dim##d

mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
llvm::SmallVector<llvm::Value *> args;

// Fill the Intrinsic Args
args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));

for (auto v : thisOp.getCoordinates())
args.push_back(mt.lookupValue(v));
for (auto v : thisOp.getIm2colOffsets())
args.push_back(mt.lookupValue(v));

mlir::Value cacheHint = thisOp.getL2CacheHint();
const bool hasCacheHint = static_cast<bool>(cacheHint);
llvm::Value *i64Unused =
llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
args.push_back(builder.getInt1(hasCacheHint));

#define NI llvm::Intrinsic::not_intrinsic
#define TILE(D) GET_TMA_OPCODE(prefetch, tile, D)
#define IM2COL(D) GET_TMA_OPCODE(prefetch, im2col, D)
#define IM2COLW(D) GET_TMA_OPCODE(prefetch, im2col_w, D)
#define IM2COLW128(D) GET_TMA_OPCODE(prefetch, im2col_w_128, D)
#define GATHER4(D) GET_TMA_OPCODE(prefetch, tile_gather4, D)

static constexpr llvm::Intrinsic::ID IDTable[][6] = {
{NI, TILE(1), TILE(2), TILE(3), TILE(4), TILE(5)}, // tile
{NI, NI, NI, IM2COL(3), IM2COL(4), IM2COL(5)}, // im2col
{NI, NI, NI, IM2COLW(3), IM2COLW(4), IM2COLW(5)}, // im2col_w
{NI, NI, NI, IM2COLW128(3), IM2COLW128(4), IM2COLW128(5)}, // im2col_w128
{NI, NI, NI, NI, NI, GATHER4(2)}, // tile_gather4
};
static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
"TMALoadModes must match number of rows in IDTable");

size_t mode = static_cast<size_t>(thisOp.getMode());
size_t dim = thisOp.getCoordinates().size();
llvm::Intrinsic::ID id = IDTable[mode][dim];
if (id == llvm::Intrinsic::not_intrinsic)
llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");

return {id, std::move(args)};

#undef GATHER4
#undef IM2COLW128
#undef IM2COLW
#undef IM2COL
#undef TILE
#undef NI
Comment on lines +1433 to +1486
Member:

Let's not use macros. Code that uses them is quite hard to read, and they are error-prone.

Member:

We can create constexpr tables and use them; they are zero-cost as well. I drafted an example below:

  constexpr llvm::Intrinsic::ID NI = llvm::Intrinsic::not_intrinsic;
  static constexpr llvm::Intrinsic::ID IDTable[][6] = {
      // tile
      {
          NI,
          llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_1d,
          llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_2d,
          llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_3d,
          llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_4d,
          llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_5d,
      },
      // im2col
      {
          NI, NI, NI,
          llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_3d,
          llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_4d,
          llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_5d,
      },
      // im2col_w
      {
          NI, NI, NI,
          llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_3d,
          llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_4d,
          llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_5d,
      },
      // im2col_w_128
      {
          NI, NI, NI,
          llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_3d,
          llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_4d,
          llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_im2col_w_128_5d,
      },
      // tile_gather4
      {
          NI, NI, NI, NI, NI,
          llvm::Intrinsic::nvvm_cp_async_bulk_tensor_prefetch_tile_gather4_2d,
      },
  };

  static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
                "TMALoadModes must match number of rows in IDTable");

  const size_t mode = static_cast<size_t>(thisOp.getMode());
  const size_t dim  = thisOp.getCoordinates().size();
  if (mode >= std::size(IDTable) || dim >= std::size(IDTable[0]))
    llvm_unreachable("Mode or dimension out of range for CpAsyncBulkTensorPrefetchOp.");

  llvm::Intrinsic::ID id = IDTable[mode][dim];
  if (id == llvm::Intrinsic::not_intrinsic)
    llvm_unreachable("Invalid intrinsic for CpAsyncBulkTensorPrefetchOp.");

  return {id, std::move(args)};

}

#define CP_ASYNC_BULK_TENSOR_REDUCE_MODE(op, dim, mode) \