
[MLIR][NVVM] Add Ops for tcgen05 cp and shift #127798


Merged: 1 commit, Feb 21, 2025

Conversation

durga4github (Contributor) commented Feb 19, 2025

PR #127669 adds intrinsics for tcgen05.cp/shift.
This PR adds the corresponding NVVM Dialect Ops.

lit tests verify the lowering to the intrinsics.
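For reference, the two new ops can be used roughly as follows. This is a sketch assembled from the lit tests in this PR; the function and value names (`%taddr`, `%smem_desc`) are illustrative:

```mlir
llvm.func @tcgen05_cp_shift_example(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
  // Asynchronously copy the matrix described by %smem_desc from shared
  // memory into Tensor Memory at %taddr (default group = cta_1).
  nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x256b>}
  // Asynchronously shift the rows of the matrix at %taddr down by one row.
  nvvm.tcgen05.shift %taddr : !llvm.ptr<6>
  llvm.return
}
```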

@durga4github durga4github requested a review from grypp as a code owner February 19, 2025 14:00
@durga4github durga4github removed the request for review from grypp February 19, 2025 14:00
@durga4github durga4github requested a review from grypp February 19, 2025 14:00
llvmbot (Member) commented Feb 19, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Durgadoss R (durga4github)

Changes

PR 127669 adds intrinsics for tcgen05.cp/shift.
This PR adds the corresponding NVVM Dialect Ops.

lit tests verify the lowering to the intrinsics.


Full diff: https://github.com/llvm/llvm-project/pull/127798.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+107)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+77)
  • (added) mlir/test/Target/LLVMIR/nvvm/tcgen05-cp.mlir (+137)
  • (added) mlir/test/Target/LLVMIR/nvvm/tcgen05-shift.mlir (+12)
  • (modified) mlir/test/Target/LLVMIR/nvvmir-invalid.mlir (+30)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 0de5a87e72c3f..0479ecf14a129 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2810,6 +2810,113 @@ def NVVM_Tcgen05CommitOp : NVVM_Op<"tcgen05.commit"> {
   }];
 }
 
+def NVVM_Tcgen05ShiftOp : NVVM_Op<"tcgen05.shift"> {
+  let summary = "Tcgen05 shift operation";
+  let description = [{
+    The `tcgen05.shift` Op initiates an asynchronous shift of 32-byte
+    elements down by one row, across all rows except the last. The
+    operand `taddr` specifies the base address of the matrix in Tensor
+    Memory whose rows are to be shifted down.
+    [For more information refer to the PTX ISA]
+    (https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-instructions-tcgen05-shift)
+  }];
+
+  let arguments = (ins LLVM_PointerTensor:$taddr,
+    DefaultValuedAttr<Tcgen05GroupKindAttr, "Tcgen05GroupKind::CTA_1">:$group);
+
+  let assemblyFormat = "$taddr attr-dict `:` type(operands)";
+
+  string llvmBuilder = [{
+    auto id = ($group == NVVM::Tcgen05GroupKind::CTA_1) ?
+      llvm::Intrinsic::nvvm_tcgen05_shift_down_cg1 :
+      llvm::Intrinsic::nvvm_tcgen05_shift_down_cg2;
+    createIntrinsicCall(builder, id, {$taddr});
+  }];
+}
+
+def Shape128x256b : I32EnumAttrCase<"SHAPE_128x256b", 0, "shape_128x256b">;
+def Shape4x256b   : I32EnumAttrCase<"SHAPE_4x256b",   1, "shape_4x256b">;
+def Shape128x128b : I32EnumAttrCase<"SHAPE_128x128b", 2, "shape_128x128b">;
+def Shape64x128b  : I32EnumAttrCase<"SHAPE_64x128b",  3, "shape_64x128b">;
+def Shape32x128b  : I32EnumAttrCase<"SHAPE_32x128b",  4, "shape_32x128b">;
+
+def Tcgen05CpShape : I32EnumAttr<"Tcgen05CpShape", "tcgen05 cp shapes",
+  [Shape128x256b, Shape4x256b, Shape128x128b, Shape64x128b, Shape32x128b]> {
+    let cppNamespace = "::mlir::NVVM";
+    let genSpecializedAttr = 0;
+}
+def Tcgen05CpShapeAttr : EnumAttr<NVVM_Dialect, Tcgen05CpShape, "tcgen05_cp_shape"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def Tcgen05CpMulticastNone: I32EnumAttrCase<"NONE", 0, "none">;
+def Tcgen05CpMulticastWarpx2_02_13: I32EnumAttrCase<"WARPX2_02_13", 1, "warpx2_02_13">;
+def Tcgen05CpMulticastWarpx2_01_23: I32EnumAttrCase<"WARPX2_01_23", 2, "warpx2_01_23">;
+def Tcgen05CpMulticastWarpx4: I32EnumAttrCase<"WARPX4", 3, "warpx4">;
+
+def Tcgen05CpMulticast : I32EnumAttr<"Tcgen05CpMulticast", "tcgen05 cp multicast",
+  [Tcgen05CpMulticastNone, Tcgen05CpMulticastWarpx2_02_13,
+   Tcgen05CpMulticastWarpx2_01_23, Tcgen05CpMulticastWarpx4]> {
+    let cppNamespace = "::mlir::NVVM";
+    let genSpecializedAttr = 0;
+}
+def Tcgen05CpMulticastAttr : EnumAttr<NVVM_Dialect, Tcgen05CpMulticast, "tcgen05_cp_multicast"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def FormatB6x16_P32: I32EnumAttrCase<"B6x16_P32", 0, "b6x16_p32">;
+def FormatB4x16_P64: I32EnumAttrCase<"B4x16_P64", 1, "b4x16_p64">;
+
+def Tcgen05CpSrcFormat : I32EnumAttr<"Tcgen05CpSrcFormat", "tcgen05 cp source format",
+  [FormatB6x16_P32, FormatB4x16_P64]> {
+    let cppNamespace = "::mlir::NVVM";
+    let genSpecializedAttr = 0;
+}
+def Tcgen05CpSrcFormatAttr : EnumAttr<NVVM_Dialect, Tcgen05CpSrcFormat, "tcgen05_cp_src_fmt"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
+def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp"> {
+  let summary = "Tcgen05 copy operation";
+  let description = [{
+    Instruction tcgen05.cp initiates an asynchronous copy operation from
+    shared memory to the location specified by the address operand `taddr`
+    in the Tensor Memory. The 64-bit register operand `smem_desc` specifies
+    the matrix descriptor representing the source matrix in the shared memory
+    that needs to be copied.
+
+    usage:
+      nvvm.tcgen05.cp %taddr, %smem_desc {
+        group = #nvvm.tcgen05_group<cta_2>,
+        shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
+        multicast = #nvvm.tcgen05_cp_multicast<warpx2_01_23>,
+        srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
+      }
+    [For more information refer to the PTX ISA]
+    (https://docs.nvidia.com/cuda/parallel-thread-execution/#tensorcore-5th-generation-instructions-tcgen05-cp)
+  }];
+
+  let arguments = (ins
+    Tcgen05CpShapeAttr:$shape,
+    DefaultValuedAttr<Tcgen05GroupKindAttr, "Tcgen05GroupKind::CTA_1">:$group,
+    DefaultValuedAttr<Tcgen05CpMulticastAttr, "Tcgen05CpMulticast::NONE">:$multicast,
+    OptionalAttr<Tcgen05CpSrcFormatAttr>:$srcFormat,
+    LLVM_PointerTensor:$taddr,
+    I64:$smem_desc);
+
+  let assemblyFormat = "$taddr`,` $smem_desc attr-dict";
+  let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    static llvm::Intrinsic::ID getIntrinsicID(Operation &op);
+  }];
+
+  string llvmBuilder = [{
+    auto id = NVVM::Tcgen05CpOp::getIntrinsicID(*op);
+    createIntrinsicCall(builder, id, {$taddr, $smem_desc});
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM target attribute.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 62f0c21338111..b145ffde73b29 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -75,6 +75,10 @@ ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
 
 void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
 
+//===----------------------------------------------------------------------===//
+// Verifier methods for NVVMDialect Ops
+//===----------------------------------------------------------------------===//
+
 // This verifier is shared among the following Ops:
 // CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
 // CpAsyncBulkTensorPrefetchOp (TMA Prefetch)
@@ -1107,6 +1111,38 @@ LogicalResult NVVM::BarrierOp::verify() {
   return success();
 }
 
+LogicalResult NVVM::Tcgen05CpOp::verify() {
+  auto mc = getMulticast();
+
+  using SH = Tcgen05CpShape;
+  using MC = Tcgen05CpMulticast;
+  switch (getShape()) {
+  case SH::SHAPE_128x256b:
+  case SH::SHAPE_128x128b:
+  case SH::SHAPE_4x256b:
+    if (mc != MC::NONE)
+      return emitError("Invalid multicast type for tcgen05.cp Op");
+    break;
+  case SH::SHAPE_64x128b:
+    if (mc != MC::WARPX2_01_23 && mc != MC::WARPX2_02_13)
+      return emitError("Shape 64x128b requires multicast warpx2_01_23 or "
+                       "warpx2_02_13 for tcgen05.cp Op");
+    break;
+  case SH::SHAPE_32x128b:
+    if (mc != MC::WARPX4)
+      return emitError(
+          "Shape 32x128b requires multicast warpx4 for tcgen05.cp Op");
+    break;
+  default:
+    return emitError("Invalid shape for tcgen05.cp Op");
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// NVVMDialect: getIntrinsicID/getIntrinsicIDAndArgs methods
+//===----------------------------------------------------------------------===//
+
 #define CP_ASYNC_ID_IMPL(mod, size, suffix)                                    \
   llvm::Intrinsic::nvvm_cp_async_##mod##_shared_global_##size##suffix
 
@@ -1314,6 +1350,47 @@ Tcgen05CommitOp::getIntrinsicIDAndArgs(Operation &op,
   return id;
 }
 
+#define TCGEN05_CP_IMPL(shape_mc, src_fmt, cg)                                 \
+  llvm::Intrinsic::nvvm_tcgen05_cp##shape_mc##src_fmt##cg
+
+#define TCGEN05_CP_2CTA(shape_mc, src_fmt, is_2cta)                            \
+  is_2cta ? TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg2)                           \
+          : TCGEN05_CP_IMPL(shape_mc, src_fmt, _cg1)
+
+#define GET_TCGEN05_CP_ID(shape_mc, src_fmt, is_2cta)                          \
+  [&]() -> auto {                                                              \
+    if (src_fmt == Tcgen05CpSrcFormat::B6x16_P32)                              \
+      return TCGEN05_CP_2CTA(shape_mc, _b6x16_p32, is_2cta);                   \
+    if (src_fmt == Tcgen05CpSrcFormat::B4x16_P64)                              \
+      return TCGEN05_CP_2CTA(shape_mc, _b4x16_p64, is_2cta);                   \
+    return TCGEN05_CP_2CTA(shape_mc, , is_2cta);                               \
+  }                                                                            \
+  ()
+
+llvm::Intrinsic::ID Tcgen05CpOp::getIntrinsicID(Operation &op) {
+  auto curOp = cast<NVVM::Tcgen05CpOp>(op);
+  bool is2CTA = curOp.getGroup() == Tcgen05GroupKind::CTA_2;
+  auto srcFmt = curOp.getSrcFormat();
+  auto mc = curOp.getMulticast();
+
+  switch (curOp.getShape()) {
+  case Tcgen05CpShape::SHAPE_128x256b:
+    return GET_TCGEN05_CP_ID(_128x256b, srcFmt, is2CTA);
+  case Tcgen05CpShape::SHAPE_128x128b:
+    return GET_TCGEN05_CP_ID(_128x128b, srcFmt, is2CTA);
+  case Tcgen05CpShape::SHAPE_4x256b:
+    return GET_TCGEN05_CP_ID(_4x256b, srcFmt, is2CTA);
+  case Tcgen05CpShape::SHAPE_32x128b:
+    return GET_TCGEN05_CP_ID(_32x128b_warpx4, srcFmt, is2CTA);
+  case Tcgen05CpShape::SHAPE_64x128b:
+    return (mc == Tcgen05CpMulticast::WARPX2_01_23)
+               ? GET_TCGEN05_CP_ID(_64x128b_warpx2_01_23, srcFmt, is2CTA)
+               : GET_TCGEN05_CP_ID(_64x128b_warpx2_02_13, srcFmt, is2CTA);
+  default:
+    llvm_unreachable("Invalid shape in tcgen05 cp Op");
+  }
+}
+
 /// Infer the result ranges for the NVVM SpecialRangeableRegisterOp that might
 /// have ConstantRangeAttr.
 static void nvvmInferResultRanges(Operation *op, Value result,
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-cp.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-cp.mlir
new file mode 100644
index 0000000000000..b6fb8c29572a1
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-cp.mlir
@@ -0,0 +1,137 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: @nvvm_tcgen05_cp_128x256b
+llvm.func @nvvm_tcgen05_cp_128x256b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.cg1(ptr addrspace(6) %0, i64 %1)
+  nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x256b>}
+
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.cg2(ptr addrspace(6) %0, i64 %1)
+  nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x256b>, group = #nvvm.tcgen05_group<cta_2>}
+
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.b4x16_p64.cg2(ptr addrspace(6) %0, i64 %1)
+  nvvm.tcgen05.cp %taddr, %smem_desc {
+    shape = #nvvm.tcgen05_cp_shape<shape_128x256b>,
+    group = #nvvm.tcgen05_group<cta_2>,
+    srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
+  }
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.128x256b.b6x16_p32.cg2(ptr addrspace(6) %0, i64 %1)
+  nvvm.tcgen05.cp %taddr, %smem_desc {
+    shape = #nvvm.tcgen05_cp_shape<shape_128x256b>,
+    group = #nvvm.tcgen05_group<cta_2>,
+    srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
+  }
+  llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_cp_4x256b
+llvm.func @nvvm_tcgen05_cp_4x256b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.cg1(ptr addrspace(6) %0, i64 %1)
+  nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_4x256b>}
+
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.cg2(ptr addrspace(6) %0, i64 %1)
+  nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_4x256b>, group = #nvvm.tcgen05_group<cta_2>}
+
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.b4x16_p64.cg2(ptr addrspace(6) %0, i64 %1)
+  nvvm.tcgen05.cp %taddr, %smem_desc {
+    shape = #nvvm.tcgen05_cp_shape<shape_4x256b>,
+    group = #nvvm.tcgen05_group<cta_2>,
+    srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
+  }
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.4x256b.b6x16_p32.cg2(ptr addrspace(6) %0, i64 %1)
+  nvvm.tcgen05.cp %taddr, %smem_desc {
+    shape = #nvvm.tcgen05_cp_shape<shape_4x256b>,
+    group = #nvvm.tcgen05_group<cta_2>,
+    srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
+  }
+  llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_cp_128x128b
+llvm.func @nvvm_tcgen05_cp_128x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.cg1(ptr addrspace(6) %0, i64 %1)
+  nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x128b>}
+
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.cg2(ptr addrspace(6) %0, i64 %1)
+  nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x128b>, group = #nvvm.tcgen05_group<cta_2>}
+
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.b4x16_p64.cg2(ptr addrspace(6) %0, i64 %1)
+  nvvm.tcgen05.cp %taddr, %smem_desc {
+    shape = #nvvm.tcgen05_cp_shape<shape_128x128b>,
+    group = #nvvm.tcgen05_group<cta_2>,
+    srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
+  }
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.128x128b.b6x16_p32.cg2(ptr addrspace(6) %0, i64 %1)
+  nvvm.tcgen05.cp %taddr, %smem_desc {
+    shape = #nvvm.tcgen05_cp_shape<shape_128x128b>,
+    group = #nvvm.tcgen05_group<cta_2>,
+    srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
+  }
+  llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_cp_64x128b
+llvm.func @nvvm_tcgen05_cp_64x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_02_13.cg1(ptr addrspace(6) %0, i64 %1)
+  nvvm.tcgen05.cp %taddr, %smem_desc {
+    shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
+    multicast = #nvvm.tcgen05_cp_multicast<warpx2_02_13>
+  }
+
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_02_13.cg2(ptr addrspace(6) %0, i64 %1)
+  nvvm.tcgen05.cp %taddr, %smem_desc {
+    shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
+    group = #nvvm.tcgen05_group<cta_2>,
+    multicast = #nvvm.tcgen05_cp_multicast<warpx2_02_13>
+  }
+
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_02_13.b4x16_p64.cg1(ptr addrspace(6) %0, i64 %1)
+  nvvm.tcgen05.cp %taddr, %smem_desc {
+    shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
+    group = #nvvm.tcgen05_group<cta_1>,
+    multicast = #nvvm.tcgen05_cp_multicast<warpx2_02_13>,
+    srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
+  }
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.64x128b_warpx2_01_23.b6x16_p32.cg2(ptr addrspace(6) %0, i64 %1)
+  nvvm.tcgen05.cp %taddr, %smem_desc {
+    shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
+    group = #nvvm.tcgen05_group<cta_2>,
+    multicast = #nvvm.tcgen05_cp_multicast<warpx2_01_23>,
+    srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
+  }
+
+  llvm.return
+}
+
+// CHECK-LABEL: @nvvm_tcgen05_cp_32x128b
+llvm.func @nvvm_tcgen05_cp_32x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.cg1(ptr addrspace(6) %0, i64 %1)
+  nvvm.tcgen05.cp %taddr, %smem_desc {
+    shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
+    multicast = #nvvm.tcgen05_cp_multicast<warpx4>
+  }
+
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.cg2(ptr addrspace(6) %0, i64 %1)
+  nvvm.tcgen05.cp %taddr, %smem_desc {
+    shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
+    group = #nvvm.tcgen05_group<cta_2>,
+    multicast = #nvvm.tcgen05_cp_multicast<warpx4>
+  }
+
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.b4x16_p64.cg2(ptr addrspace(6) %0, i64 %1)
+  nvvm.tcgen05.cp %taddr, %smem_desc {
+    shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
+    group = #nvvm.tcgen05_group<cta_2>,
+    multicast = #nvvm.tcgen05_cp_multicast<warpx4>,
+    srcFormat = #nvvm.tcgen05_cp_src_fmt<b4x16_p64>
+  }
+  // CHECK: call void @llvm.nvvm.tcgen05.cp.32x128b_warpx4.b6x16_p32.cg1(ptr addrspace(6) %0, i64 %1)
+  nvvm.tcgen05.cp %taddr, %smem_desc {
+    shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
+    group = #nvvm.tcgen05_group<cta_1>,
+    multicast = #nvvm.tcgen05_cp_multicast<warpx4>,
+    srcFormat = #nvvm.tcgen05_cp_src_fmt<b6x16_p32>
+  }
+
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-shift.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-shift.mlir
new file mode 100644
index 0000000000000..23f45f6866b06
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-shift.mlir
@@ -0,0 +1,12 @@
+// RUN: mlir-opt -split-input-file -verify-diagnostics %s
+// RUN: mlir-translate -mlir-to-llvmir -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: @llvm_nvvm_tcgen05_shift
+llvm.func @llvm_nvvm_tcgen05_shift(%taddr : !llvm.ptr<6>) {
+  // CHECK: call void @llvm.nvvm.tcgen05.shift.down.cg1(ptr addrspace(6) %{{.*}})
+  nvvm.tcgen05.shift %taddr : !llvm.ptr<6>
+
+  // CHECK: call void @llvm.nvvm.tcgen05.shift.down.cg2(ptr addrspace(6) %{{.*}})
+  nvvm.tcgen05.shift %taddr {group = #nvvm.tcgen05_group<cta_2>} : !llvm.ptr<6>
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
index 8957377607dad..4fca7fd801dbe 100644
--- a/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir-invalid.mlir
@@ -122,3 +122,33 @@ llvm.func @convert_float_to_tf32_no_rnd_mode(%src : f32) -> i32 {
   %res = nvvm.cvt.float.to.tf32 %src
   llvm.return %res : i32
 }
+
+// -----
+
+llvm.func @nvvm_tcgen05_cp_128x256b_mc(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
+  // expected-error @below {{Invalid multicast type for tcgen05.cp Op}}
+  nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x256b>, multicast = #nvvm.tcgen05_cp_multicast<warpx2_02_13>}
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_tcgen05_cp_32x128b_wx2(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
+  // expected-error @below {{Shape 32x128b requires multicast warpx4 for tcgen05.cp Op}}
+  nvvm.tcgen05.cp %taddr, %smem_desc {
+    shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
+    multicast = #nvvm.tcgen05_cp_multicast<warpx2_01_23>
+  }
+  llvm.return
+}
+
+// -----
+
+llvm.func @nvvm_tcgen05_cp_64x128b(%taddr : !llvm.ptr<6>, %smem_desc : i64) {
+  // expected-error @below {{Shape 64x128b requires multicast warpx2_01_23 or warpx2_02_13 for tcgen05.cp Op}}
+  nvvm.tcgen05.cp %taddr, %smem_desc {
+    shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
+    multicast = #nvvm.tcgen05_cp_multicast<warpx4>
+  }
+  llvm.return
+}
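The verifier in this diff ties each `shape` to a specific set of `multicast` modes. A quick summary of the legal combinations, derived from `Tcgen05CpOp::verify` above (operand names are illustrative):

```mlir
// shape_128x256b, shape_4x256b, shape_128x128b: multicast must be none (the default).
nvvm.tcgen05.cp %taddr, %smem_desc {shape = #nvvm.tcgen05_cp_shape<shape_128x128b>}

// shape_64x128b: multicast must be warpx2_01_23 or warpx2_02_13.
nvvm.tcgen05.cp %taddr, %smem_desc {
  shape = #nvvm.tcgen05_cp_shape<shape_64x128b>,
  multicast = #nvvm.tcgen05_cp_multicast<warpx2_01_23>
}

// shape_32x128b: multicast must be warpx4.
nvvm.tcgen05.cp %taddr, %smem_desc {
  shape = #nvvm.tcgen05_cp_shape<shape_32x128b>,
  multicast = #nvvm.tcgen05_cp_multicast<warpx4>
}
```

Any other pairing is rejected at verification time, before lowering to an intrinsic is attempted.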

github-actions bot commented Feb 19, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@durga4github durga4github force-pushed the durgadossr/mlir_tcgen05_cp branch from 486f9e0 to c53fbe4 Compare February 19, 2025 16:28
@durga4github durga4github force-pushed the durgadossr/mlir_tcgen05_cp branch 2 times, most recently from 33ce840 to 87ec02a Compare February 20, 2025 11:00
grypp (Member) left a comment

PR looks great, changes are cosmetic. I approve it. Once we address the issues, we can land it.

PR 127669 adds intrinsics for tcgen05.cp/shift.
This PR adds NVVM Dialect Ops for the same.

lit tests are added to verify the lowering
to the intrinsics.

Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
@durga4github durga4github force-pushed the durgadossr/mlir_tcgen05_cp branch from 87ec02a to 77d5788 Compare February 20, 2025 13:35
@durga4github durga4github merged commit 6bd88bb into llvm:main Feb 21, 2025
8 checks passed
@durga4github durga4github deleted the durgadossr/mlir_tcgen05_cp branch February 21, 2025 10:25