-
Notifications
You must be signed in to change notification settings - Fork 13.9k
[MLIR][NVVM] Add Ops for tcgen05 cp and shift #127798
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][NVVM] Add Ops for tcgen05 cp and shift #127798
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: Durgadoss R (durga4github) Changes: PR #127669 adds intrinsics for tcgen05.cp/shift. This PR adds NVVM Dialect Ops for the same. lit tests are added to verify the lowering to the intrinsics. Full diff: https://github.com/llvm/llvm-project/pull/127798.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 0de5a87e72c3f..0479ecf14a129 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2810,6 +2810,113 @@ def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit"> {
}];
}
+def NVVM_Tcgen05ShiftOp : NVVM_Op<"tcgen05.shift"> {
+ let summary = "Tcgen05 shift operation";
+ let description = [{
+ The `tcgen05.shift` is an asynchronous instruction which initiates
+ the shifting of 32-byte elements downwards across all the rows,
+ except the last, by one row. The operand `taddr` specifies the base
+ address of the matrix in Tensor Memory whose rows must be down shifted.
+ [For more information refer to the PTX ISA]
+ (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-shift)
+ }];
+
+ let arguments = (ins LLVM_PointerTensor:$taddr,
+ DefaultValuedAttr<Tcgen05GroupKindAttr, "Tcgen05GroupKind::CTA_1">:$group);
+
+ let assemblyFormat = "$taddr attr-dict `:` type(operands)";
+
+ string llvmBuilder = [{
+ auto id = ($group == NVVM::Tcgen05GroupKind::CTA_1) ?
+ llvm::Intrinsic::nvvm_tcgen05_shift_down_cg1 :
+ llvm::Intrinsic::nvvm_tcgen05_shift_down_cg2;
+ createIntrinsicCall(builder, id, {$taddr});
+ }];
+}
+
+def Shape128x256b : I32EnumAttrCase<"SHAPE_128x256b", 0, "shape_128x256b">;
+def Shape4x256b : I32EnumAttrCase<"SHAPE_4x256b", 1, "shape_4x256b">;
+def Shape128x128b : I32EnumAttrCase<"SHAPE_128x128b", 2, "shape_128x128b">;
+def Shape64x128b : I32EnumAttrCase<"SHAPE_64x128b", 3, "shape_64x128b">;
+def Shape32x128b : I32EnumAttrCase<"SHAPE_32x128b", 4, "shape_32x128b">;
+
+def Tcgen05CpShape : I32EnumAttr<"Tcgen05CpShape", "tcgen05 cp shapes",
+ [Shape128x256b, Shape4x256b, Shape128x128b, Shape64x128b, Shape32x128b]> {
+ let cppNamespace = "::mlir::NVVM";
+ let genSpecializedAttr = 0;
+}
+def Tcgen05CpShapeAttr : EnumAttr<NVVM_Dialect, Tcgen05CpShape, "tcgen05_cp_shape"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def Tcgen05CpMulticastNone: I32EnumAttrCase<"NONE", 0, "none">;
+def Tcgen05CpMulticastWarpx2_02_13: I32EnumAttrCase<"WARPX2_02_13", 1, "warpx2_02_13">;
+def Tcgen05CpMulticastWarpx2_01_23: I32EnumAttrCase<"WARPX2_01_23", 2, "warpx2_01_23">;
+def Tcgen05CpMulticastWarpx4: I32EnumAttrCase<"WARPX4", 3, "warpx4">;
+
+def Tcgen05CpMulticast : I32EnumAttr<"Tcgen05CpMulticast", "tcgen05 cp multicast",
+ [Tcgen05CpMulticastNone, Tcgen05CpMulticastWarpx2_02_13,
+ Tcgen05CpMulticastWarpx2_01_23, Tcgen05CpMulticastWarpx4]> {
+ let cppNamespace = "::mlir::NVVM";
+ let genSpecializedAttr = 0;
+}
+def Tcgen05CpMulticastAttr : EnumAttr<NVVM_Dialect, Tcgen05CpMulticast, "tcgen05_cp_multicast"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def FormatB6x16_P32: I32EnumAttrCase<"B6x16_P32", 0, "b6x16_p32">;
+def FormatB4x16_P64: I32EnumAttrCase<"B4x16_P64", 1, "b4x16_p64">;
+
+def Tcgen05CpSrcFormat : I32EnumAttr<"Tcgen05CpSrcFormat", "tcgen05 cp source format",
+ [FormatB6x16_P32, FormatB4x16_P64]> {
+ let cppNamespace = "::mlir::NVVM";
+ let genSpecializedAttr = 0;
+}
+def Tcgen05CpSrcFormatAttr : EnumAttr<NVVM_Dialect, Tcgen05CpSrcFormat, "tcgen05_cp_src_fmt"> {
+ let assemblyFormat = "`<` $value `>`";
+}
+
+def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp"> {
+ let summary = "Tcgen05 copy operation";
+ let description = [{
+ Instruction tcgen05.cp initiates an asynchronous copy operation from
+ shared memory to the location specified by the address operand `taddr`
+ in the Tensor Memory. The 64-bit register operand `smem_desc` specifies
+ the matrix descriptor representing the source matrix in the shared memory
+ that needs to be copied.
+
+ usage:
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ group = #nvvm.tcgen05_group<cta_2>,
+ shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
+ multicast = #nvvm.tcgen05_cp_multicast<warpx2_01_23>,
+      srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
+ }
+ [For more information refer to the PTX ISA]
+ (https://docs.nvidia.com/cuda/parallel-thread-execution/#tensorcore-5th-generation-instructions-tcgen05-cp)
+ }];
+
+ let arguments = (ins
+ Tcgen05CpShapeAttr:$shape,
+ DefaultValuedAttr<Tcgen05GroupKindAttr, "Tcgen05GroupKind::CTA_1">:$group,
+ DefaultValuedAttr<Tcgen05CpMulticastAttr, "Tcgen05CpMulticast::NONE">:$multicast,
+ OptionalAttr<Tcgen05CpSrcFormatAttr>:$srcFormat,
+ LLVM_PointerTensor:$taddr,
+ I64:$smem_desc);
+
+ let assemblyFormat = "$taddr`,` $smem_desc attr-dict";
+ let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ static llvm::Intrinsic::ID getIntrinsicID(Operation &op);
+ }];
+
+ string llvmBuilder = [{
+ auto id = NVVM::Tcgen05CpOp::getIntrinsicID(*op);
+ createIntrinsicCall(builder, id, {$taddr, $smem_desc});
+ }];
+}
+
//===----------------------------------------------------------------------===//
// NVVM target attribute.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 62f0c21338111..b145ffde73b29 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -75,6 +75,10 @@ ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
+//===----------------------------------------------------------------------===//
+// Verifier methods for NVVMDialect Ops
+//===----------------------------------------------------------------------===//
+
// This verifier is shared among the following Ops:
// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
// CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
@@ -1107,6 +1111,38 @@ LogicalResult NVVM::BarrierOp::verify() {
return success();
}
+LogicalResult NVVM::Tcgen05CpOp::verify() {
+ auto mc = getMulticast();
+
+ using SH = Tcgen05CpShape;
+ using MC = Tcgen05CpMulticast;
+ switch (getShape()) {
+ case SH::SHAPE_128x256b:
+ case SH::SHAPE_128x128b:
+ case SH::SHAPE_4x256b:
+ if (mc != MC::NONE)
+ return emitError("Invalid multicast type for tcgen05.cp Op");
+ break;
+ case SH::SHAPE_64x128b:
+ if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
+ return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
+ "warpx2_02_13 for tcgen05.cp Op");
+ break;
+ case SH::SHAPE_32x128b:
+ if (mc != MC::WARPX4)
+ return emitError(
+ "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
+ break;
+ default:
+ return emitError("Invalid shape for tcgen05.cp Op");
+ }
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// NVVMDialect: getIntrinsicID/getIntrinsicIDAndArgs methods
+//===----------------------------------------------------------------------===//
+
#define CP_ASYNC_ID_IMPL(mod, size, suffix) \
llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix
@@ -1314,6 +1350,47 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
return id;
}
+#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg) \
+ llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg
+
+#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta) \
+ is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2) \
+ : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)
+
+#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta) \
+ [&]() -> auto { \
+ if (src_fmt == Tcgen05CpSrcFormat::B6x16_P32) \
+ return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta); \
+ if (src_fmt == Tcgen05CpSrcFormat::B4x16_P64) \
+ return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta); \
+ return TCGEN05_CP_2CTA(shape_mc, , is_2cta); \
+ } \
+ ()
+
+llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
+ auto curOp = cast<NVVM::Tcgen05CpOp>(op);
+ bool is2CTA = curOp.getGroup() == Tcgen05GroupKind::CTA_2;
+ auto srcFmt = curOp.getSrcFormat();
+ auto mc = curOp.getMulticast();
+
+ switch (curOp.getShape()) {
+ case Tcgen05CpShape::SHAPE_128x256b:
+ return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA);
+ case Tcgen05CpShape::SHAPE_128x128b:
+ return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA);
+ case Tcgen05CpShape::SHAPE_4x256b:
+ return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA);
+ case Tcgen05CpShape::SHAPE_32x128b:
+ return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA);
+ case Tcgen05CpShape::SHAPE_64x128b:
+ return (mc == Tcgen05CpMulticast::WARPX2_01_23)
+ ? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA)
+ : GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA);
+ default:
+ llvm_unreachable("Invalid shape in tcgen05 cp Op");
+ }
+}
+
/// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
/// have ConstantRangeAttr.
static void nvvmInferResultRanges(Operation *op, Value result,
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-cp.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-cp.mlir
new file mode 100644
index 0000000000000..b6fb8c29572a1
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-cp.mlir
@@ -0,0 +1,137 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: @nvvm_tcgen05_cp_128x256b
+llvm.func @nvvm_tcgen05_cp_128x256b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.cg1(ptr addrspace(6) %0, i64 %1)
+ nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x256b>}
+
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.cg2(ptr addrspace(6) %0, i64 %1)
+ nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x256b>, group = #nvvm.tcgen05_group<cta_2>}
+
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.b4x16_p64.cg2(ptr addrspace(6) %0, i64 %1)
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ shape = #nvvm.tcgen05_cp_shape<shape_128x256b>,
+ group = #nvvm.tcgen05_group<cta_2>,
+ srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
+ }
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.b6x16_p32.cg2(ptr addrspace(6) %0, i64 %1)
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ shape = #nvvm.tcgen05_cp_shape<shape_128x256b>,
+ group = #nvvm.tcgen05_group<cta_2>,
+ srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
+ }
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_cp_4x256b
+llvm.func @nvvm_tcgen05_cp_4x256b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.cg1(ptr addrspace(6) %0, i64 %1)
+ nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_4x256b>}
+
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.cg2(ptr addrspace(6) %0, i64 %1)
+ nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_4x256b>, group = #nvvm.tcgen05_group<cta_2>}
+
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.b4x16_p64.cg2(ptr addrspace(6) %0, i64 %1)
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ shape = #nvvm.tcgen05_cp_shape<shape_4x256b>,
+ group = #nvvm.tcgen05_group<cta_2>,
+ srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
+ }
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.b6x16_p32.cg2(ptr addrspace(6) %0, i64 %1)
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ shape = #nvvm.tcgen05_cp_shape<shape_4x256b>,
+ group = #nvvm.tcgen05_group<cta_2>,
+ srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
+ }
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_cp_128x128b
+llvm.func @nvvm_tcgen05_cp_128x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.cg1(ptr addrspace(6) %0, i64 %1)
+ nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x128b>}
+
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.cg2(ptr addrspace(6) %0, i64 %1)
+ nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x128b>, group = #nvvm.tcgen05_group<cta_2>}
+
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.b4x16_p64.cg2(ptr addrspace(6) %0, i64 %1)
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ shape = #nvvm.tcgen05_cp_shape<shape_128x128b>,
+ group = #nvvm.tcgen05_group<cta_2>,
+ srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
+ }
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.b6x16_p32.cg2(ptr addrspace(6) %0, i64 %1)
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ shape = #nvvm.tcgen05_cp_shape<shape_128x128b>,
+ group = #nvvm.tcgen05_group<cta_2>,
+ srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
+ }
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_cp_64x128b
+llvm.func @nvvm_tcgen05_cp_64x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_02_13.cg1(ptr addrspace(6) %0, i64 %1)
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
+ multicast = #nvvm.tcgen05_cp_multicast<warpx2_02_13>
+ }
+
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_02_13.cg2(ptr addrspace(6) %0, i64 %1)
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
+ group = #nvvm.tcgen05_group<cta_2>,
+ multicast = #nvvm.tcgen05_cp_multicast<warpx2_02_13>
+ }
+
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_02_13.b4x16_p64.cg1(ptr addrspace(6) %0, i64 %1)
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
+ group = #nvvm.tcgen05_group<cta_1>,
+ multicast = #nvvm.tcgen05_cp_multicast<warpx2_02_13>,
+ srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
+ }
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_01_23.b6x16_p32.cg2(ptr addrspace(6) %0, i64 %1)
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
+ group = #nvvm.tcgen05_group<cta_2>,
+ multicast = #nvvm.tcgen05_cp_multicast<warpx2_01_23>,
+ srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
+ }
+
+ llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_cp_32x128b
+llvm.func @nvvm_tcgen05_cp_32x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.cg1(ptr addrspace(6) %0, i64 %1)
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
+ multicast = #nvvm.tcgen05_cp_multicast<warpx4>
+ }
+
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.cg2(ptr addrspace(6) %0, i64 %1)
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
+ group = #nvvm.tcgen05_group<cta_2>,
+ multicast = #nvvm.tcgen05_cp_multicast<warpx4>
+ }
+
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.b4x16_p64.cg2(ptr addrspace(6) %0, i64 %1)
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
+ group = #nvvm.tcgen05_group<cta_2>,
+ multicast = #nvvm.tcgen05_cp_multicast<warpx4>,
+ srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
+ }
+ // CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.b6x16_p32.cg1(ptr addrspace(6) %0, i64 %1)
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
+ group = #nvvm.tcgen05_group<cta_1>,
+ multicast = #nvvm.tcgen05_cp_multicast<warpx4>,
+ srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
+ }
+
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-shift.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-shift.mlir
new file mode 100644
index 0000000000000..23f45f6866b06
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-shift.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: @llvm_nvvm_tcgen05_shift
+llvm.func @llvm_nvvm_tcgen05_shift(%taddr : !llvm.ptr<6>) {
+ // CHECK: call void @llvm.nvvm.tcgen05.shift.down.cg1(ptr addrspace(6) %{{.*}})
+ nvvm.tcgen05.shift %taddr : !llvm.ptr<6>
+
+ // CHECK: call void @llvm.nvvm.tcgen05.shift.down.cg2(ptr addrspace(6) %{{.*}})
+ nvvm.tcgen05.shift %taddr {group = #nvvm.tcgen05_group<cta_2>} : !llvm.ptr<6>
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 8957377607dad..4fca7fd801dbe 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -122,3 +122,33 @@ llvm.func @convert_float_to_tf32_no_rnd_mode(%src : f32) -> i32 {
%res = nvvm.cvt.float.to.tf32 %src
llvm.return %res : i32
}
+
+// -----
+
+llvm.func @nvvm_tcgen05_cp_128x256b_mc(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
+ // expected-error @below {{Invalid multicast type for tcgen05.cp Op}}
+ nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x256b>, multicast = #nvvm.tcgen05_cp_multicast<warpx2_02_13>}
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_tcgen05_cp_32x128b_wx2(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
+ // expected-error @below {{Shape 32x128b requires multicast warpx4 for tcgen05.cp Op}}
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
+ multicast = #nvvm.tcgen05_cp_multicast<warpx2_01_23>
+ }
+ llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_tcgen05_cp_64x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
+ // expected-error @below {{Shape 64x128b requires multicast warpx2_01_23 or warpx2_02_13 for tcgen05.cp Op}}
+ nvvm.tcgen05.cp %taddr, %smem_desc {
+ shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
+ multicast = #nvvm.tcgen05_cp_multicast<warpx4>
+ }
+ llvm.return
+}
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
486f9e0
to
c53fbe4
Compare
33ce840
to
87ec02a
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PR looks great, changes are cosmetic. I approve it. Once we address the issues, we can land it.
PR 127669 adds intrinsics for tcgen05.cp/shift. This PR adds NVVM Dialect Ops for the same. lit tests are added to verify the lowering to the intrinsics. Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
87ec02a
to
77d5788
Compare
PR #127669 adds intrinsics for tcgen05.cp/shift.
This PR adds NVVM Dialect Ops for the same.
lit tests are added to verify the lowering
to the intrinsics.