Skip to content

[MLIR][NVVM] Add Ops for tcgen05 cp and shift #127798

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2810,6 +2810,114 @@ def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit"> {
}];
}

// tcgen05.shift: asynchronously shifts 32-byte elements in Tensor Memory
// down by one row. Lowered to the nvvm_tcgen05_shift_down_{cg1,cg2}
// intrinsic, selected by the optional `group` attribute.
def NVVM_Tcgen05ShiftOp : NVVM_Op<"tcgen05.shift"> {
let summary = "Tcgen05 shift operation";
let description = [{
The `tcgen05.shift` is an asynchronous instruction which initiates
the shifting of 32-byte elements downwards across all the rows,
except the last, by one row. The operand `taddr` specifies the base
address of the matrix in Tensor Memory whose rows must be down shifted.

[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-shift)
}];

// `taddr` is a Tensor-Memory pointer; `group` selects the CTA group and
// defaults to CTA_1.
let arguments = (ins LLVM_PointerTensor:$taddr,
DefaultValuedAttr<Tcgen05GroupKindAttr, "Tcgen05GroupKind::CTA_1">:$group);

let assemblyFormat = "$taddr attr-dict `:` type(operands)";

// Pick the cg1 or cg2 intrinsic variant based on the `group` attribute.
string llvmBuilder = [{
auto id = ($group == NVVM::Tcgen05GroupKind::CTA_1) ?
llvm::Intrinsic::nvvm_tcgen05_shift_down_cg1 :
llvm::Intrinsic::nvvm_tcgen05_shift_down_cg2;
createIntrinsicCall(builder, id, {$taddr});
}];
}

// Matrix shapes accepted by tcgen05.cp, written as <rows>x<bits-per-row>.
// The integer values are only internal discriminators for the enum attr.
def Shape128x256b : I32EnumAttrCase<"SHAPE_128x256b", 0, "shape_128x256b">;
def Shape4x256b : I32EnumAttrCase<"SHAPE_4x256b", 1, "shape_4x256b">;
def Shape128x128b : I32EnumAttrCase<"SHAPE_128x128b", 2, "shape_128x128b">;
def Shape64x128b : I32EnumAttrCase<"SHAPE_64x128b", 3, "shape_64x128b">;
def Shape32x128b : I32EnumAttrCase<"SHAPE_32x128b", 4, "shape_32x128b">;

def Tcgen05CpShape : I32EnumAttr<"Tcgen05CpShape", "tcgen05 cp shapes",
[Shape128x256b, Shape4x256b, Shape128x128b, Shape64x128b, Shape32x128b]> {
let cppNamespace = "::mlir::NVVM";
// The specialized attribute is declared separately below so the dialect
// mnemonic (tcgen05_cp_shape) can be attached to it.
let genSpecializedAttr = 0;
}
def Tcgen05CpShapeAttr : EnumAttr<NVVM_Dialect, Tcgen05CpShape, "tcgen05_cp_shape"> {
let assemblyFormat = "`<` $value `>`";
}

// Warp-multicast modes for tcgen05.cp. Which modes are legal depends on the
// shape attribute; the op verifier enforces the pairing.
def Tcgen05CpMulticastNone: I32EnumAttrCase<"NONE", 0, "none">;
def Tcgen05CpMulticastWarpx2_02_13: I32EnumAttrCase<"WARPX2_02_13", 1, "warpx2_02_13">;
def Tcgen05CpMulticastWarpx2_01_23: I32EnumAttrCase<"WARPX2_01_23", 2, "warpx2_01_23">;
def Tcgen05CpMulticastWarpx4: I32EnumAttrCase<"WARPX4", 3, "warpx4">;

def Tcgen05CpMulticast : I32EnumAttr<"Tcgen05CpMulticast", "tcgen05 cp multicast",
[Tcgen05CpMulticastNone, Tcgen05CpMulticastWarpx2_02_13,
Tcgen05CpMulticastWarpx2_01_23, Tcgen05CpMulticastWarpx4]> {
let cppNamespace = "::mlir::NVVM";
// Specialized attribute declared below with the dialect mnemonic.
let genSpecializedAttr = 0;
}
def Tcgen05CpMulticastAttr : EnumAttr<NVVM_Dialect, Tcgen05CpMulticast, "tcgen05_cp_multicast"> {
let assemblyFormat = "`<` $value `>`";
}

// Optional source-data formats for tcgen05.cp. When present, the intrinsic
// name gains a .b6x16_p32 / .b4x16_p64 infix (see getIntrinsicID).
def FormatB6x16_P32: I32EnumAttrCase<"B6x16_P32", 0, "b6x16_p32">;
def FormatB4x16_P64: I32EnumAttrCase<"B4x16_P64", 1, "b4x16_p64">;

def Tcgen05CpSrcFormat : I32EnumAttr<"Tcgen05CpSrcFormat", "tcgen05 cp source format",
[FormatB6x16_P32, FormatB4x16_P64]> {
let cppNamespace = "::mlir::NVVM";
// Specialized attribute declared below with the dialect mnemonic.
let genSpecializedAttr = 0;
}
def Tcgen05CpSrcFormatAttr : EnumAttr<NVVM_Dialect, Tcgen05CpSrcFormat, "tcgen05_cp_src_fmt"> {
let assemblyFormat = "`<` $value `>`";
}

// tcgen05.cp: asynchronous shared-memory -> Tensor-Memory copy. The verifier
// (hasVerifier) checks the shape/multicast pairing; lowering resolves the
// intrinsic through Tcgen05CpOp::getIntrinsicID.
def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp"> {
let summary = "Tcgen05 copy operation";
let description = [{
Instruction tcgen05.cp initiates an asynchronous copy operation from
shared memory to the location specified by the address operand `taddr`
in the Tensor Memory. The 64-bit register operand `smem_desc` specifies
the matrix descriptor representing the source matrix in the shared memory
that needs to be copied.

Example:
```mlir
nvvm.tcgen05.cp %taddr, %smem_desc {
group = #nvvm.tcgen05_group<cta_2>,
shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
multicast = #nvvm.tcgen05_cp_multicast<warpx2_01_23>,
srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
}
```
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensorcore-5th-generation-instructions-tcgen05-cp)
}];

// `shape` is mandatory; `group` and `multicast` default to CTA_1 / NONE;
// `srcFormat` is optional and selects a decompression variant.
let arguments = (ins
Tcgen05CpShapeAttr:$shape,
DefaultValuedAttr<Tcgen05GroupKindAttr, "Tcgen05GroupKind::CTA_1">:$group,
DefaultValuedAttr<Tcgen05CpMulticastAttr, "Tcgen05CpMulticast::NONE">:$multicast,
OptionalAttr<Tcgen05CpSrcFormatAttr>:$srcFormat,
LLVM_PointerTensor:$taddr,
I64:$smem_desc);

let assemblyFormat = "$taddr`,` $smem_desc attr-dict";
let hasVerifier = 1;

// Intrinsic selection lives in NVVMDialect.cpp (one intrinsic per
// shape/multicast/format/group combination).
let extraClassDeclaration = [{
static llvm::Intrinsic::ID getIntrinsicID(Operation &op);
}];

string llvmBuilder = [{
auto id = NVVM::Tcgen05CpOp::getIntrinsicID(*op);
createIntrinsicCall(builder, id, {$taddr, $smem_desc});
}];
}

//===----------------------------------------------------------------------===//
// NVVM target attribute.
//===----------------------------------------------------------------------===//
Expand Down
76 changes: 76 additions & 0 deletions mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {

// Prints vote.ballot via the shared NVVM intrinsic-op printer.
void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }

//===----------------------------------------------------------------------===//
// Verifier methods
//===----------------------------------------------------------------------===//

// This verifier is shared among the following Ops:
// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
// CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
Expand Down Expand Up @@ -1107,6 +1111,38 @@ LogicalResult NVVM::BarrierOp::verify() {
return success();
}

// Verifies that the `multicast` attribute is compatible with `shape`:
//   - 128x256b, 128x128b, 4x256b : multicast must be `none` (the default)
//   - 64x128b                    : requires warpx2_01_23 or warpx2_02_13
//   - 32x128b                    : requires warpx4
// The switch intentionally has no `default:` — all enumerators are covered,
// so adding one would both trigger -Wcovered-switch-default and silence the
// missing-case warning if a new shape is ever added.
LogicalResult NVVM::Tcgen05CpOp::verify() {
  auto mc = getMulticast();

  using SH = Tcgen05CpShape;
  using MC = Tcgen05CpMulticast;
  switch (getShape()) {
  case SH::SHAPE_128x256b:
  case SH::SHAPE_128x128b:
  case SH::SHAPE_4x256b:
    if (mc != MC::NONE)
      return emitError("Invalid multicast type for tcgen05.cp Op");
    break;
  case SH::SHAPE_64x128b:
    if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
      return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
                       "warpx2_02_13 for tcgen05.cp Op");
    break;
  case SH::SHAPE_32x128b:
    if (mc != MC::WARPX4)
      return emitError(
          "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
    break;
  }
  return success();
}

//===----------------------------------------------------------------------===//
// getIntrinsicID/getIntrinsicIDAndArgs methods
//===----------------------------------------------------------------------===//

#define CP_ASYNC_ID_IMPL(mod, size, suffix) \
llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix

Expand Down Expand Up @@ -1314,6 +1350,46 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
return id;
}

// Helper macros that compose tcgen05.cp intrinsic names of the form
//   nvvm_tcgen05_cp<shape_mc><src_fmt><cg>
// TCGEN05_CP_IMPL pastes the three name fragments into one intrinsic ID.
#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg) \
llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg

// Selects the _cg2 (2-CTA group) or _cg1 suffix.
#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta) \
is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2) \
: TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)

// Dispatches on the optional srcFormat attribute; when it is absent the
// final branch selects the plain (no src_fmt infix) intrinsic — note the
// deliberately empty second macro argument.
#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta) \
[&]() -> auto { \
if (src_fmt == Tcgen05CpSrcFormat::B6x16_P32) \
return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta); \
if (src_fmt == Tcgen05CpSrcFormat::B4x16_P64) \
return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta); \
return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
}()

/// Maps a tcgen05.cp op's (shape, multicast, srcFormat, group) attributes to
/// the matching llvm.nvvm.tcgen05.cp.* intrinsic. The shape/multicast pairing
/// has already been checked by the verifier, so 32x128b implies warpx4 and
/// 64x128b implies one of the warpx2 modes.
llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
  auto curOp = cast<NVVM::Tcgen05CpOp>(op);
  bool is2CTA = curOp.getGroup() == Tcgen05GroupKind::CTA_2;
  auto srcFmt = curOp.getSrcFormat();
  auto mc = curOp.getMulticast();

  // No `default:` — the switch covers every Tcgen05CpShape enumerator, and a
  // default label would trigger -Wcovered-switch-default and mask the warning
  // for any shape added in the future.
  switch (curOp.getShape()) {
  case Tcgen05CpShape::SHAPE_128x256b:
    return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_128x128b:
    return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_4x256b:
    return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_32x128b:
    return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_64x128b:
    return (mc == Tcgen05CpMulticast::WARPX2_01_23)
               ? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA)
               : GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA);
  }
  llvm_unreachable("Invalid shape in tcgen05 cp Op");
}

/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
/// have ConstantRangeAttr.
static void nvvmInferResultRanges(Operation *op, Value result,
Expand Down
136 changes: 136 additions & 0 deletions mlir/test/Target/LLVMIR/nvvm/tcgen05-cp.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s

// Shape 128x256b: default group lowers to .cg1, cta_2 to .cg2, and the
// optional srcFormat adds a .b4x16_p64 / .b6x16_p32 infix.
// CHECK-LABEL: @nvvm_tcgen05_cp_128x256b
llvm.func @nvvm_tcgen05_cp_128x256b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
// CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.cg1(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x256b>}

// CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x256b>, group = #nvvm.tcgen05_group<cta_2>}

// CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.b4x16_p64.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {
shape = #nvvm.tcgen05_cp_shape<shape_128x256b>,
group = #nvvm.tcgen05_group<cta_2>,
srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
}
// CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.b6x16_p32.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {
shape = #nvvm.tcgen05_cp_shape<shape_128x256b>,
group = #nvvm.tcgen05_group<cta_2>,
srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
}
llvm.return
}

// Shape 4x256b: same group and srcFormat combinations as the 128x256b test.
// CHECK-LABEL: @nvvm_tcgen05_cp_4x256b
llvm.func @nvvm_tcgen05_cp_4x256b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
// CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.cg1(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_4x256b>}

// CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_4x256b>, group = #nvvm.tcgen05_group<cta_2>}

// CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.b4x16_p64.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {
shape = #nvvm.tcgen05_cp_shape<shape_4x256b>,
group = #nvvm.tcgen05_group<cta_2>,
srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
}
// CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.b6x16_p32.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {
shape = #nvvm.tcgen05_cp_shape<shape_4x256b>,
group = #nvvm.tcgen05_group<cta_2>,
srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
}
llvm.return
}

// Shape 128x128b: same group and srcFormat combinations as the 128x256b test.
// CHECK-LABEL: @nvvm_tcgen05_cp_128x128b
llvm.func @nvvm_tcgen05_cp_128x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
// CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.cg1(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x128b>}

// CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x128b>, group = #nvvm.tcgen05_group<cta_2>}

// CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.b4x16_p64.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {
shape = #nvvm.tcgen05_cp_shape<shape_128x128b>,
group = #nvvm.tcgen05_group<cta_2>,
srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
}
// CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.b6x16_p32.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {
shape = #nvvm.tcgen05_cp_shape<shape_128x128b>,
group = #nvvm.tcgen05_group<cta_2>,
srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
}
llvm.return
}

// Shape 64x128b: the multicast attribute (warpx2_02_13 or warpx2_01_23) is
// folded into the intrinsic name; group and srcFormat vary as above.
// CHECK-LABEL: @nvvm_tcgen05_cp_64x128b
llvm.func @nvvm_tcgen05_cp_64x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
// CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_02_13.cg1(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {
shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
multicast = #nvvm.tcgen05_cp_multicast<warpx2_02_13>
}

// CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_02_13.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {
shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
group = #nvvm.tcgen05_group<cta_2>,
multicast = #nvvm.tcgen05_cp_multicast<warpx2_02_13>
}

// CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_02_13.b4x16_p64.cg1(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {
shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
group = #nvvm.tcgen05_group<cta_1>,
multicast = #nvvm.tcgen05_cp_multicast<warpx2_02_13>,
srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
}
// CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_01_23.b6x16_p32.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {
shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
group = #nvvm.tcgen05_group<cta_2>,
multicast = #nvvm.tcgen05_cp_multicast<warpx2_01_23>,
srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
}

llvm.return
}

// Shape 32x128b: always uses the warpx4 multicast (the only legal mode for
// this shape); group and srcFormat vary as above.
// CHECK-LABEL: @nvvm_tcgen05_cp_32x128b
llvm.func @nvvm_tcgen05_cp_32x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
// CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.cg1(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {
shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
multicast = #nvvm.tcgen05_cp_multicast<warpx4>
}

// CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {
shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
group = #nvvm.tcgen05_group<cta_2>,
multicast = #nvvm.tcgen05_cp_multicast<warpx4>
}

// CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.b4x16_p64.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {
shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
group = #nvvm.tcgen05_group<cta_2>,
multicast = #nvvm.tcgen05_cp_multicast<warpx4>,
srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
}
// CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.b6x16_p32.cg1(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
nvvm.tcgen05.cp %taddr, %smem_desc {
shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
group = #nvvm.tcgen05_group<cta_1>,
multicast = #nvvm.tcgen05_cp_multicast<warpx4>,
srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
}

llvm.return
}
11 changes: 11 additions & 0 deletions mlir/test/Target/LLVMIR/nvvm/tcgen05-shift.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s

// tcgen05.shift lowers to shift.down.cg1 by default and to .cg2 when the
// group attribute is cta_2.
// CHECK-LABEL: @llvm_nvvm_tcgen05_shift
llvm.func @llvm_nvvm_tcgen05_shift(%taddr : !llvm.ptr<6>) {
// CHECK: call void @llvm.nvvm.tcgen05.shift.down.cg1(ptr addrspace(6) %{{.*}})
nvvm.tcgen05.shift %taddr : !llvm.ptr<6>

// CHECK: call void @llvm.nvvm.tcgen05.shift.down.cg2(ptr addrspace(6) %{{.*}})
nvvm.tcgen05.shift %taddr {group = #nvvm.tcgen05_group<cta_2>} : !llvm.ptr<6>
llvm.return
}
30 changes: 30 additions & 0 deletions mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,33 @@ llvm.func @convert_float_to_tf32_no_rnd_mode(%src : f32) -> i32 {
%res = nvvm.cvt.float.to.tf32 %src
llvm.return %res : i32
}

// -----

// Shapes 128x256b/128x128b/4x256b accept only the default (none) multicast,
// so any explicit multicast must be diagnosed by the verifier.
llvm.func @nvvm_tcgen05_cp_128x256b_mc(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
// expected-error @below {{Invalid multicast type for tcgen05.cp Op}}
nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x256b>, multicast = #nvvm.tcgen05_cp_multicast<warpx2_02_13>}
llvm.return
}

// -----

// Shape 32x128b only pairs with the warpx4 multicast; a warpx2 mode is
// rejected by the verifier.
llvm.func @nvvm_tcgen05_cp_32x128b_wx2(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
// expected-error @below {{Shape 32x128b requires multicast warpx4 for tcgen05.cp Op}}
nvvm.tcgen05.cp %taddr, %smem_desc {
shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
multicast = #nvvm.tcgen05_cp_multicast<warpx2_01_23>
}
llvm.return
}

// -----

// Shape 64x128b only pairs with the warpx2 multicast modes; warpx4 is
// rejected by the verifier.
llvm.func @nvvm_tcgen05_cp_64x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
// expected-error @below {{Shape 64x128b requires multicast warpx2_01_23 or warpx2_02_13 for tcgen05.cp Op}}
nvvm.tcgen05.cp %taddr, %smem_desc {
shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
multicast = #nvvm.tcgen05_cp_multicast<warpx4>
}
llvm.return
}