Skip to content

Commit 6bd88bb

Browse files
authored
[MLIR][NVVM] Add Ops for tcgen05 cp and shift (#127798)
PR #127669 adds intrinsics for tcgen05.cp/shift. This PR adds NVVM Dialect Ops for the same. lit tests are added to verify the lowering to the intrinsics. Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
1 parent 73ad78c commit 6bd88bb

File tree

5 files changed

+361
-0
lines changed

5 files changed

+361
-0
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2810,6 +2810,114 @@ def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit"> {
28102810
}];
28112811
}
28122812

2813+
// Op definition for `nvvm.tcgen05.shift`. It takes the tensor-memory pointer
// `taddr` plus a `group` attribute that defaults to CTA_1. The llvmBuilder
// below emits the `shift_down_cg1` intrinsic when `group` is CTA_1 and the
// `shift_down_cg2` intrinsic otherwise, passing `taddr` as the sole argument.
def NVVM_Tcgen05ShiftOp : NVVM_Op<"tcgen05.shift"> {
2814+
let summary = "Tcgen05 shift operation";
2815+
let description = [{
2816+
The `tcgen05.shift` is an asynchronous instruction which initiates
2817+
the shifting of 32-byte elements downwards across all the rows,
2818+
except the last, by one row. The operand `taddr` specifies the base
2819+
address of the matrix in Tensor Memory whose rows must be down shifted.
2820+
2821+
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-shift)
2822+
}];
2823+
2824+
let arguments = (ins LLVM_PointerTensor:$taddr,
2825+
DefaultValuedAttr<Tcgen05GroupKindAttr, "Tcgen05GroupKind::CTA_1">:$group);
2826+
2827+
let assemblyFormat = "$taddr attr-dict `:` type(operands)";
2828+
2829+
string llvmBuilder = [{
2830+
auto id = ($group == NVVM::Tcgen05GroupKind::CTA_1) ?
2831+
llvm::Intrinsic::nvvm_tcgen05_shift_down_cg1 :
2832+
llvm::Intrinsic::nvvm_tcgen05_shift_down_cg2;
2833+
createIntrinsicCall(builder, id, {$taddr});
2834+
}];
2835+
}
2836+
2837+
// Shapes selectable on `nvvm.tcgen05.cp`, spelled in the assembly as
// `#nvvm.tcgen05_cp_shape<...>`. The enum itself sets genSpecializedAttr = 0;
// the attribute is declared separately through the EnumAttr record that
// follows it. The Tcgen05CpOp verifier restricts which multicast modes each
// shape may be combined with.
def Shape128x256b : I32EnumAttrCase<"SHAPE_128x256b", 0, "shape_128x256b">;
2838+
def Shape4x256b : I32EnumAttrCase<"SHAPE_4x256b", 1, "shape_4x256b">;
2839+
def Shape128x128b : I32EnumAttrCase<"SHAPE_128x128b", 2, "shape_128x128b">;
2840+
def Shape64x128b : I32EnumAttrCase<"SHAPE_64x128b", 3, "shape_64x128b">;
2841+
def Shape32x128b : I32EnumAttrCase<"SHAPE_32x128b", 4, "shape_32x128b">;
2842+
2843+
def Tcgen05CpShape : I32EnumAttr<"Tcgen05CpShape", "tcgen05 cp shapes",
2844+
[Shape128x256b, Shape4x256b, Shape128x128b, Shape64x128b, Shape32x128b]> {
2845+
let cppNamespace = "::mlir::NVVM";
2846+
let genSpecializedAttr = 0;
2847+
}
2848+
def Tcgen05CpShapeAttr : EnumAttr<NVVM_Dialect, Tcgen05CpShape, "tcgen05_cp_shape"> {
2849+
let assemblyFormat = "`<` $value `>`";
2850+
}
2851+
2852+
// Multicast modes selectable on `nvvm.tcgen05.cp`, spelled in the assembly as
// `#nvvm.tcgen05_cp_multicast<...>`. NONE (value 0) is the value the op uses
// when the attribute is omitted. As with the shape enum, the attribute class
// comes from the separate EnumAttr record (genSpecializedAttr = 0).
def Tcgen05CpMulticastNone: I32EnumAttrCase<"NONE", 0, "none">;
2853+
def Tcgen05CpMulticastWarpx2_02_13: I32EnumAttrCase<"WARPX2_02_13", 1, "warpx2_02_13">;
2854+
def Tcgen05CpMulticastWarpx2_01_23: I32EnumAttrCase<"WARPX2_01_23", 2, "warpx2_01_23">;
2855+
def Tcgen05CpMulticastWarpx4: I32EnumAttrCase<"WARPX4", 3, "warpx4">;
2856+
2857+
def Tcgen05CpMulticast : I32EnumAttr<"Tcgen05CpMulticast", "tcgen05 cp multicast",
2858+
[Tcgen05CpMulticastNone, Tcgen05CpMulticastWarpx2_02_13,
2859+
Tcgen05CpMulticastWarpx2_01_23, Tcgen05CpMulticastWarpx4]> {
2860+
let cppNamespace = "::mlir::NVVM";
2861+
let genSpecializedAttr = 0;
2862+
}
2863+
def Tcgen05CpMulticastAttr : EnumAttr<NVVM_Dialect, Tcgen05CpMulticast, "tcgen05_cp_multicast"> {
2864+
let assemblyFormat = "`<` $value `>`";
2865+
}
2866+
2867+
// Source formats selectable on `nvvm.tcgen05.cp`, spelled in the assembly as
// `#nvvm.tcgen05_cp_src_fmt<...>`. The op declares this attribute as optional;
// when it is absent the intrinsic without a source-format suffix is chosen
// (see Tcgen05CpOp::getIntrinsicID).
def FormatB6x16_P32: I32EnumAttrCase<"B6x16_P32", 0, "b6x16_p32">;
2868+
def FormatB4x16_P64: I32EnumAttrCase<"B4x16_P64", 1, "b4x16_p64">;
2869+
2870+
def Tcgen05CpSrcFormat : I32EnumAttr<"Tcgen05CpSrcFormat", "tcgen05 cp source format",
2871+
[FormatB6x16_P32, FormatB4x16_P64]> {
2872+
let cppNamespace = "::mlir::NVVM";
2873+
let genSpecializedAttr = 0;
2874+
}
2875+
def Tcgen05CpSrcFormatAttr : EnumAttr<NVVM_Dialect, Tcgen05CpSrcFormat, "tcgen05_cp_src_fmt"> {
2876+
let assemblyFormat = "`<` $value `>`";
2877+
}
2878+
2879+
// Op definition for `nvvm.tcgen05.cp`. Operands are the tensor-memory pointer
// `taddr` and the 64-bit `smem_desc`. Attributes: `shape` is mandatory,
// `group` defaults to CTA_1, `multicast` defaults to NONE and `srcFormat` is
// optional. The legal shape/multicast pairs are enforced by the op verifier
// (hasVerifier = 1). The llvmBuilder asks Tcgen05CpOp::getIntrinsicID for the
// intrinsic and calls it with `taddr` and `smem_desc`.
def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp"> {
2880+
let summary = "Tcgen05 copy operation";
2881+
let description = [{
2882+
Instruction tcgen05.cp initiates an asynchronous copy operation from
2883+
shared memory to the location specified by the address operand `taddr`
2884+
in the Tensor Memory. The 64-bit register operand `smem_desc` specifies
2885+
the matrix descriptor representing the source matrix in the shared memory
2886+
that needs to be copied.
2887+
2888+
Example:
2889+
```mlir
2890+
nvvm.tcgen05.cp %taddr, %smem_desc {
2891+
group = #nvvm.tcgen05_group<cta_2>,
2892+
shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
2893+
multicast = #nvvm.tcgen05_cp_multicast<warpx2_01_23>,
2894+
srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
2895+
}
2896+
```
2897+
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tensorcore-5th-generation-instructions-tcgen05-cp)
2898+
}];
2899+
2900+
let arguments = (ins
2901+
Tcgen05CpShapeAttr:$shape,
2902+
DefaultValuedAttr<Tcgen05GroupKindAttr, "Tcgen05GroupKind::CTA_1">:$group,
2903+
DefaultValuedAttr<Tcgen05CpMulticastAttr, "Tcgen05CpMulticast::NONE">:$multicast,
2904+
OptionalAttr<Tcgen05CpSrcFormatAttr>:$srcFormat,
2905+
LLVM_PointerTensor:$taddr,
2906+
I64:$smem_desc);
2907+
2908+
let assemblyFormat = "$taddr`,` $smem_desc attr-dict";
2909+
let hasVerifier = 1;
2910+
2911+
let extraClassDeclaration = [{
2912+
static llvm::Intrinsic::ID getIntrinsicID(Operation &op);
2913+
}];
2914+
2915+
string llvmBuilder = [{
2916+
auto id = NVVM::Tcgen05CpOp::getIntrinsicID(*op);
2917+
createIntrinsicCall(builder, id, {$taddr, $smem_desc});
2918+
}];
2919+
}
2920+
28132921
//===----------------------------------------------------------------------===//
28142922
// NVVM target attribute.
28152923
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
7575

7676
void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
7777

78+
//===----------------------------------------------------------------------===//
79+
// Verifier methods
80+
//===----------------------------------------------------------------------===//
81+
7882
// This verifier is shared among the following Ops:
7983
// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
8084
// CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
@@ -1107,6 +1111,38 @@ LogicalResult NVVM::BarrierOp::verify() {
11071111
return success();
11081112
}
11091113

1114+
/// Verifies that the multicast mode of a tcgen05.cp Op is legal for its shape:
///   - 128x256b, 128x128b and 4x256b accept no multicast (NONE only);
///   - 64x128b requires warpx2_01_23 or warpx2_02_13;
///   - 32x128b requires warpx4.
/// Emits an error and returns failure on the first violated rule.
LogicalResult NVVM::Tcgen05CpOp::verify() {
  auto mc = getMulticast();

  using SH = Tcgen05CpShape;
  using MC = Tcgen05CpMulticast;
  // Every enumerator is listed explicitly and there is deliberately no
  // `default:` label: a fully covered switch lets the compiler's -Wswitch
  // diagnostic flag this verifier whenever a new shape is added to the enum.
  switch (getShape()) {
  case SH::SHAPE_128x256b:
  case SH::SHAPE_128x128b:
  case SH::SHAPE_4x256b:
    if (mc != MC::NONE)
      return emitError("Invalid multicast type for tcgen05.cp Op");
    break;
  case SH::SHAPE_64x128b:
    if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
      return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
                       "warpx2_02_13 for tcgen05.cp Op");
    break;
  case SH::SHAPE_32x128b:
    if (mc != MC::WARPX4)
      return emitError(
          "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
    break;
  }
  return success();
}
1141+
1142+
//===----------------------------------------------------------------------===//
1143+
// getIntrinsicID/getIntrinsicIDAndArgs methods
1144+
//===----------------------------------------------------------------------===//
1145+
11101146
#define CP_ASYNC_ID_IMPL(mod, size, suffix) \
11111147
llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix
11121148

@@ -1314,6 +1350,46 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
13141350
return id;
13151351
}
13161352

1353+
// Builds the name of a tcgen05.cp intrinsic by token-pasting the shape (with
// its multicast suffix, if any), the source-format suffix (may be empty) and
// the cta-group suffix.
#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg)                                 \
  llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg

// Picks the _cg2 variant when `is_2cta` is true and the _cg1 variant
// otherwise. The expansion is fully parenthesized so that it stays a single
// operand when the macro is used inside a larger expression.
#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta)                            \
  ((is_2cta) ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2)                        \
             : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1))

// Maps `src_fmt` onto the matching intrinsic suffix. When it compares equal
// to neither source format (e.g. an empty optional), the intrinsic without a
// source-format suffix is returned.
#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)                          \
  [&]() -> llvm::Intrinsic::ID {                                               \
    if ((src_fmt) == Tcgen05CpSrcFormat::B6x16_P32)                            \
      return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta);                   \
    if ((src_fmt) == Tcgen05CpSrcFormat::B4x16_P64)                            \
      return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta);                   \
    return TCGEN05_CP_2CTA(shape_mc, , is_2cta);                               \
  }()
1368+
1369+
/// Returns the ID of the tcgen05.cp intrinsic matching the shape, multicast
/// mode, source format and cta-group of `op`. `op` must be a Tcgen05CpOp
/// (it is accessed through `cast`).
llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
  auto curOp = cast<NVVM::Tcgen05CpOp>(op);
  bool is2CTA = curOp.getGroup() == Tcgen05GroupKind::CTA_2;
  auto srcFmt = curOp.getSrcFormat();
  auto mc = curOp.getMulticast();

  // The switch lists every shape and intentionally has no `default:` label,
  // so -Wswitch reports any enumerator added later; the unreachable marker
  // sits after the switch instead.
  switch (curOp.getShape()) {
  case Tcgen05CpShape::SHAPE_128x256b:
    return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_128x128b:
    return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_4x256b:
    return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_32x128b:
    return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA);
  case Tcgen05CpShape::SHAPE_64x128b:
    // Anything other than warpx2_01_23 is treated as warpx2_02_13 here;
    // Tcgen05CpOp::verify only admits those two modes for this shape.
    return (mc == Tcgen05CpMulticast::WARPX2_01_23)
               ? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA)
               : GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA);
  }
  llvm_unreachable("Invalid shape in tcgen05 cp Op");
}
1392+
13171393
/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
13181394
/// have ConstantRangeAttr.
13191395
static void nvvmInferResultRanges(Operation *op, Value result,
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
2+
3+
// Lowering of nvvm.tcgen05.cp with shape_128x256b and no multicast attribute:
// default group, cta_2, and cta_2 combined with each of the two source formats.
// CHECK-LABEL: @nvvm_tcgen05_cp_128x256b
4+
llvm.func @nvvm_tcgen05_cp_128x256b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
5+
// CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.cg1(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
6+
nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x256b>}
7+
8+
// CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
9+
nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x256b>, group = #nvvm.tcgen05_group<cta_2>}
10+
11+
// CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.b4x16_p64.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
12+
nvvm.tcgen05.cp %taddr, %smem_desc {
13+
shape = #nvvm.tcgen05_cp_shape<shape_128x256b>,
14+
group = #nvvm.tcgen05_group<cta_2>,
15+
srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
16+
}
17+
// CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.b6x16_p32.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
18+
nvvm.tcgen05.cp %taddr, %smem_desc {
19+
shape = #nvvm.tcgen05_cp_shape<shape_128x256b>,
20+
group = #nvvm.tcgen05_group<cta_2>,
21+
srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
22+
}
23+
llvm.return
24+
}
25+
26+
// Lowering of nvvm.tcgen05.cp with shape_4x256b and no multicast attribute:
// default group, cta_2, and cta_2 combined with each of the two source formats.
// CHECK-LABEL: @nvvm_tcgen05_cp_4x256b
27+
llvm.func @nvvm_tcgen05_cp_4x256b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
28+
// CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.cg1(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
29+
nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_4x256b>}
30+
31+
// CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
32+
nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_4x256b>, group = #nvvm.tcgen05_group<cta_2>}
33+
34+
// CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.b4x16_p64.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
35+
nvvm.tcgen05.cp %taddr, %smem_desc {
36+
shape = #nvvm.tcgen05_cp_shape<shape_4x256b>,
37+
group = #nvvm.tcgen05_group<cta_2>,
38+
srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
39+
}
40+
// CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.b6x16_p32.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
41+
nvvm.tcgen05.cp %taddr, %smem_desc {
42+
shape = #nvvm.tcgen05_cp_shape<shape_4x256b>,
43+
group = #nvvm.tcgen05_group<cta_2>,
44+
srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
45+
}
46+
llvm.return
47+
}
48+
49+
// Lowering of nvvm.tcgen05.cp with shape_128x128b and no multicast attribute:
// default group, cta_2, and cta_2 combined with each of the two source formats.
// CHECK-LABEL: @nvvm_tcgen05_cp_128x128b
50+
llvm.func @nvvm_tcgen05_cp_128x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
51+
// CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.cg1(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
52+
nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x128b>}
53+
54+
// CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
55+
nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x128b>, group = #nvvm.tcgen05_group<cta_2>}
56+
57+
// CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.b4x16_p64.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
58+
nvvm.tcgen05.cp %taddr, %smem_desc {
59+
shape = #nvvm.tcgen05_cp_shape<shape_128x128b>,
60+
group = #nvvm.tcgen05_group<cta_2>,
61+
srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
62+
}
63+
// CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.b6x16_p32.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
64+
nvvm.tcgen05.cp %taddr, %smem_desc {
65+
shape = #nvvm.tcgen05_cp_shape<shape_128x128b>,
66+
group = #nvvm.tcgen05_group<cta_2>,
67+
srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
68+
}
69+
llvm.return
70+
}
71+
72+
// Lowering of nvvm.tcgen05.cp with shape_64x128b: warpx2_02_13 multicast with
// the default group, with cta_2, and with cta_1 plus b4x16_p64; then
// warpx2_01_23 multicast with cta_2 plus b6x16_p32.
// CHECK-LABEL: @nvvm_tcgen05_cp_64x128b
73+
llvm.func @nvvm_tcgen05_cp_64x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
74+
// CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_02_13.cg1(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
75+
nvvm.tcgen05.cp %taddr, %smem_desc {
76+
shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
77+
multicast = #nvvm.tcgen05_cp_multicast<warpx2_02_13>
78+
}
79+
80+
// CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_02_13.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
81+
nvvm.tcgen05.cp %taddr, %smem_desc {
82+
shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
83+
group = #nvvm.tcgen05_group<cta_2>,
84+
multicast = #nvvm.tcgen05_cp_multicast<warpx2_02_13>
85+
}
86+
87+
// CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_02_13.b4x16_p64.cg1(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
88+
nvvm.tcgen05.cp %taddr, %smem_desc {
89+
shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
90+
group = #nvvm.tcgen05_group<cta_1>,
91+
multicast = #nvvm.tcgen05_cp_multicast<warpx2_02_13>,
92+
srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
93+
}
94+
// CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_01_23.b6x16_p32.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
95+
nvvm.tcgen05.cp %taddr, %smem_desc {
96+
shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
97+
group = #nvvm.tcgen05_group<cta_2>,
98+
multicast = #nvvm.tcgen05_cp_multicast<warpx2_01_23>,
99+
srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
100+
}
101+
102+
llvm.return
103+
}
104+
105+
// Lowering of nvvm.tcgen05.cp with shape_32x128b and warpx4 multicast: default
// group, cta_2, cta_2 plus b4x16_p64, and cta_1 plus b6x16_p32.
// CHECK-LABEL: @nvvm_tcgen05_cp_32x128b
106+
llvm.func @nvvm_tcgen05_cp_32x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
107+
// CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.cg1(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
108+
nvvm.tcgen05.cp %taddr, %smem_desc {
109+
shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
110+
multicast = #nvvm.tcgen05_cp_multicast<warpx4>
111+
}
112+
113+
// CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
114+
nvvm.tcgen05.cp %taddr, %smem_desc {
115+
shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
116+
group = #nvvm.tcgen05_group<cta_2>,
117+
multicast = #nvvm.tcgen05_cp_multicast<warpx4>
118+
}
119+
120+
// CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.b4x16_p64.cg2(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
121+
nvvm.tcgen05.cp %taddr, %smem_desc {
122+
shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
123+
group = #nvvm.tcgen05_group<cta_2>,
124+
multicast = #nvvm.tcgen05_cp_multicast<warpx4>,
125+
srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
126+
}
127+
// CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.b6x16_p32.cg1(ptr addrspace(6) %{{.*}}, i64 %{{.*}})
128+
nvvm.tcgen05.cp %taddr, %smem_desc {
129+
shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
130+
group = #nvvm.tcgen05_group<cta_1>,
131+
multicast = #nvvm.tcgen05_cp_multicast<warpx4>,
132+
srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
133+
}
134+
135+
llvm.return
136+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
2+
3+
// Lowering of nvvm.tcgen05.shift: without a group attribute the cg1 intrinsic
// is expected, and with group = cta_2 the cg2 intrinsic is expected.
// CHECK-LABEL: @llvm_nvvm_tcgen05_shift
4+
llvm.func @llvm_nvvm_tcgen05_shift(%taddr : !llvm.ptr<6>) {
5+
// CHECK: call void @llvm.nvvm.tcgen05.shift.down.cg1(ptr addrspace(6) %{{.*}})
6+
nvvm.tcgen05.shift %taddr : !llvm.ptr<6>
7+
8+
// CHECK: call void @llvm.nvvm.tcgen05.shift.down.cg2(ptr addrspace(6) %{{.*}})
9+
nvvm.tcgen05.shift %taddr {group = #nvvm.tcgen05_group<cta_2>} : !llvm.ptr<6>
10+
llvm.return
11+
}

mlir/test/Target/LLVMIR/nvvmir-invalid.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,33 @@ llvm.func @convert_float_to_tf32_no_rnd_mode(%src : f32) -> i32 {
122122
%res = nvvm.cvt.float.to.tf32 %src
123123
llvm.return %res : i32
124124
}
125+
126+
// -----
127+
128+
// Negative test: shape_128x256b combined with a multicast mode other than
// none must be rejected with the "Invalid multicast type" diagnostic.
llvm.func @nvvm_tcgen05_cp_128x256b_mc(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
129+
// expected-error @below {{Invalid multicast type for tcgen05.cp Op}}
130+
nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x256b>, multicast = #nvvm.tcgen05_cp_multicast<warpx2_02_13>}
131+
llvm.return
132+
}
133+
134+
// -----
135+
136+
// Negative test: shape_32x128b with a warpx2 multicast mode must be rejected,
// since the diagnostic states that this shape requires warpx4.
llvm.func @nvvm_tcgen05_cp_32x128b_wx2(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
137+
// expected-error @below {{Shape 32x128b requires multicast warpx4 for tcgen05.cp Op}}
138+
nvvm.tcgen05.cp %taddr, %smem_desc {
139+
shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
140+
multicast = #nvvm.tcgen05_cp_multicast<warpx2_01_23>
141+
}
142+
llvm.return
143+
}
144+
145+
// -----
146+
147+
// Negative test: shape_64x128b with warpx4 multicast must be rejected, since
// the diagnostic states that this shape requires one of the warpx2 modes.
llvm.func @nvvm_tcgen05_cp_64x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
148+
// expected-error @below {{Shape 64x128b requires multicast warpx2_01_23 or warpx2_02_13 for tcgen05.cp Op}}
149+
nvvm.tcgen05.cp %taddr, %smem_desc {
150+
shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
151+
multicast = #nvvm.tcgen05_cp_multicast<warpx4>
152+
}
153+
llvm.return
154+
}

0 commit comments

Comments
 (0)