Skip to content

Conversation

@schwarzschild-radius
Copy link
Contributor

@schwarzschild-radius schwarzschild-radius commented Jan 29, 2025

This commit adds support for the following ldmatrix extensions introduced in PTX 8.6

  • Support for m16n16 with b8 type with mandatory transpose
  • Support for m8n16 with source and destination formats

The above extensions are only supported on sm_100a, sm_101a, sm_120a

Please refer to the PTX ISA for more information: https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-ldmatrix

@llvmbot
Copy link
Member

llvmbot commented Jan 29, 2025

@llvm/pr-subscribers-backend-nvptx

Author: Pradeep Kumar (schwarzschild-radius)

Changes

This commit adds support for the following ldmatrix extensions introduced in PTX 8.6

  • Support for m16n16 with b8 type with mandatory transpose
  • Support for m8n16 with source and destination formats

The above extensions are only supported on sm_100a, sm_101a, sm_120a


Full diff: https://github.com/llvm/llvm-project/pull/124899.diff

8 Files Affected:

  • (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+32-7)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+15-3)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+1)
  • (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+25-2)
  • (added) llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py (+16)
  • (added) llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py (+16)
  • (added) llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py (+16)
  • (modified) llvm/test/CodeGen/NVPTX/wmma.py (+49-3)
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 9a2f38d760e659..f3aac47e4c4033 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -62,6 +62,7 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
   string frag = Frag;
   string ptx_elt_type = PtxEltType;
   string gft = Geom#":"#Frag#":"#ptx_elt_type;
+  string gf = Geom#":"#Frag;
   string ft = frag#":"#ptx_elt_type;
   list<LLVMType> regs = !cond(
     // mma fp ops use smaller fragments than wmma fp ops
@@ -204,9 +205,19 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
     !eq(gft,"m16n8k256:d:s32") : !listsplat(llvm_i32_ty, 4),
 
     // ldmatrix b16 -> s32 @ m8n8
-    !eq(gft,"m8n8:x1:b16") : !listsplat(llvm_i32_ty, 1),
-    !eq(gft,"m8n8:x2:b16") : !listsplat(llvm_i32_ty, 2),
-    !eq(gft,"m8n8:x4:b16") : !listsplat(llvm_i32_ty, 4),
+    !eq(gf,"m8n8:x1") : !listsplat(llvm_i32_ty, 1),
+    !eq(gf,"m8n8:x2") : !listsplat(llvm_i32_ty, 2),
+    !eq(gf,"m8n8:x4") : !listsplat(llvm_i32_ty, 4),
+
+    // ldmatrix b8, b8x16.b6x16_p32, b8x16.b4x16_p64 -> s32 @ m16n16
+    !eq(gf,"m16n16:x1") : !listsplat(llvm_i32_ty, 2),
+    !eq(gf,"m16n16:x2") : !listsplat(llvm_i32_ty, 4),
+
+    // ldmatrix b8x16.b6x16_p32, b8x16.b4x16_p64 -> s32 @ m8n16
+    !eq(gf,"m8n16:x1") : !listsplat(llvm_i32_ty, 1),
+    !eq(gf,"m8n16:x2") : !listsplat(llvm_i32_ty, 2),
+    !eq(gf,"m8n16:x4") : !listsplat(llvm_i32_ty, 4),
+
   );
 }
 
@@ -411,7 +422,16 @@ class NVVM_MMA_OPS {
 
   list<WMMA_REGS> ldmatrix_b16_ops = LDMATRIX_OPS<
     ["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ret;
-  list<WMMA_REGS> all_ldmatrix_ops = ldmatrix_b16_ops;
+
+  list<WMMA_REGS> ldmatrix_geom_m16n16_ops = LDMATRIX_OPS<
+    ["m16n16"], ["x1", "x2"], ["b8", "b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;
+
+  list<WMMA_REGS> ldmatrix_geom_m8n16_ops = LDMATRIX_OPS<
+    ["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;
+
+  list<WMMA_REGS> all_ldmatrix_ops = !listconcat(ldmatrix_b16_ops,
+                                                 ldmatrix_geom_m16n16_ops,
+                                                 ldmatrix_geom_m8n16_ops);
 }
 
 def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -536,13 +556,18 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
 // if NVVM_LDMATRIX_SUPPORTED<...>.ret then
 //   def : FOO<>; // The record will only be defined for supported ops.
 //
-class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag> {
+class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
   string g = frag.geom;
   string t = frag.ptx_elt_type;
 
   bit ret = !cond(
-    // Only currently support m8n8 and b16
     !and(!eq(g, "m8n8"), !eq(t, "b16")): true,
+    !and(!eq(g, "m16n16"), !eq(t, "b8"), !eq(trans, 1)): true,
+    !and(!eq(g, "m16n16"), !eq(t, "b8x16.b6x16_p32")): true,
+    !and(!eq(g, "m16n16"), !eq(t, "b8x16.b4x16_p64")): true,
+    !and(!eq(g, "m8n16"), !eq(t, "b8"), !eq(trans, 0)): true,
+    !and(!eq(g, "m8n16"), !eq(t, "b8x16.b6x16_p32"), !eq(trans, 0)): true,
+    !and(!eq(g, "m8n16"), !eq(t, "b8x16.b4x16_p64"), !eq(trans, 0)): true,
     true: false
   );
 }
@@ -4932,7 +4957,7 @@ class NVVM_LDMATRIX<WMMA_REGS Frag, int Transposed>
 
 foreach transposed = [0, 1] in {
   foreach frag = NVVM_MMA_OPS.all_ldmatrix_ops in {
-    if NVVM_LDMATRIX_SUPPORTED<frag>.ret then {
+    if NVVM_LDMATRIX_SUPPORTED<frag, transposed>.ret then {
       def LDMATRIX_NAME<frag, transposed>.record
         : NVVM_LDMATRIX<frag, transposed>;
     }
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 773c97f7b4dc0f..4c1c5c10bfcc8b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3552,7 +3552,12 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row:
   case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride:
   case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16:
-  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16: {
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::v4i32;
     Info.ptrVal = I.getArgOperand(0);
@@ -3592,7 +3597,9 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col_stride:
   case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col:
   case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16:
-  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16: {
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::i32;
     Info.ptrVal = I.getArgOperand(0);
@@ -3688,7 +3695,12 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row:
   case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride:
   case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16:
-  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16: {
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::v2i32;
     Info.ptrVal = I.getArgOperand(0);
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 633a99d0fc1be3..d0a625643e2129 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -170,6 +170,7 @@ def False : Predicate<"false">;
 class hasPTX<int version>: Predicate<"Subtarget->getPTXVersion() >= " # version>;
 class hasSM<int version>: Predicate<"Subtarget->getSmVersion() >= " # version>;
 
+def hasAAFeatures : Predicate<"Subtarget->hasAAFeatures()">;
 // Explicit records for arch-accelerated SM versions
 def hasSM90a : Predicate<"Subtarget->getFullSmVersion() == 901">;
 def hasSM100a : Predicate<"Subtarget->getFullSmVersion() == 1001">;
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 56d8b734bf01df..b2cf22b255f1d0 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -7107,6 +7107,9 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
     !eq(ptx_elt_type, "tf32") : Int32Regs,
     !eq(ptx_elt_type, "s32") : Int32Regs,
     !eq(ptx_elt_type, "b16") : Int32Regs,
+    !eq(ptx_elt_type, "b8") : Int32Regs,
+    !eq(ptx_elt_type, "b8x16.b6x16_p32") : Int32Regs,
+    !eq(ptx_elt_type, "b8x16.b4x16_p64") : Int32Regs,
     !eq(ptx_elt_type, "s8") : Int32Regs,
     !eq(ptx_elt_type, "u8") : Int32Regs,
     !eq(ptx_elt_type, "s4") : Int32Regs,
@@ -7194,7 +7197,27 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
 
     !and(!eq(op,"ldmatrix"),
          !eq(ptx_elt_type,"b16"),
-         !eq(geom, "m8n8")) : [hasSM<75>, hasPTX<65>]);
+         !eq(geom, "m8n8")) : [hasSM<75>, hasPTX<65>],
+
+    !and(!eq(op,"ldmatrix"),
+         !eq(ptx_elt_type,"b8"),
+         !eq(geom, "m16n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>],
+
+    !and(!eq(op,"ldmatrix"),
+         !eq(ptx_elt_type,"b8x16.b6x16_p32"),
+         !eq(geom, "m16n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>],
+
+    !and(!eq(op,"ldmatrix"),
+         !eq(ptx_elt_type,"b8x16.b4x16_p64"),
+         !eq(geom, "m16n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>],
+
+    !and(!eq(op,"ldmatrix"),
+         !eq(ptx_elt_type,"b8x16.b6x16_p32"),
+         !eq(geom, "m8n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>],
+
+    !and(!eq(op,"ldmatrix"),
+         !eq(ptx_elt_type,"b8x16.b4x16_p64"),
+         !eq(geom, "m8n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>]);
 
   // template DAGs for instruction inputs/output.
   dag Outs = !dag(outs, ptx_regs, reg_names);
@@ -7478,7 +7501,7 @@ defset list<WMMA_INSTR> LDMATRIXs  = {
     foreach space = [".shared", ""] in {
       foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in {
         foreach frag = NVVM_MMA_OPS.all_ldmatrix_ops in
-          if NVVM_LDMATRIX_SUPPORTED<frag>.ret then
+          if NVVM_LDMATRIX_SUPPORTED<frag, transposed>.ret then
             def : LDMATRIX<WMMA_REGINFO<frag, "ldmatrix">, transposed, space,
                             addr>;
       } // addr
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
new file mode 100644
index 00000000000000..6ad0a2a5865c41
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
@@ -0,0 +1,16 @@
+# Check all variants of instructions supported by PTX86 on SM100a
+# RUN: %python %s --ptx=86 --gpu-arch=100 --aa > %t-ptx86-sm_100a.ll
+# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \
+# RUN:           | FileCheck %t-ptx86-sm_100a.ll
+# RUN: %if ptxas-12.7 %{                                                  \
+# RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \
+# RUN:           | %ptxas-verify -arch=sm_100a                              \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
new file mode 100644
index 00000000000000..7d9953484da7d0
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
@@ -0,0 +1,16 @@
+# Check all variants of instructions supported by PTX86 on SM101a
+# RUN: %python %s --ptx=86 --gpu-arch=101 --aa > %t-ptx86-sm_101a.ll
+# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \
+# RUN:           | FileCheck %t-ptx86-sm_101a.ll
+# RUN: %if ptxas-12.7 %{                                                  \
+# RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \
+# RUN:           | %ptxas-verify -arch=sm_101a                              \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
new file mode 100644
index 00000000000000..7bddf0b6fbb785
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
@@ -0,0 +1,16 @@
+# Check all variants of instructions supported by PTX86 on SM120a
+# RUN: %python %s --ptx=86 --gpu-arch=120 --aa > %t-ptx86-sm_120a.ll
+# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \
+# RUN:           | FileCheck %t-ptx86-sm_120a.ll
+# RUN: %if ptxas-12.7 %{                                                  \
+# RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \
+# RUN:           | %ptxas-verify -arch=sm_120a                              \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index e1e46f0b8cab34..c1826fc561834e 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -19,6 +19,9 @@ def __init__(self, ptx_type):
             "f64": "double",
             "s32": "i32",
             "b16": "i32",
+            "b8": "i32",
+            "b8x16.b6x16_p32" : "i32",
+            "b8x16.b4x16_p64" : "i32",
             "s8": "i32",
             "u8": "i32",
             "s4": "i32",
@@ -161,6 +164,18 @@ def __init__(self, geom, frag, ptx_elt_type):
             "m8n8:x1:b16": 1,
             "m8n8:x2:b16": 2,
             "m8n8:x4:b16": 4,
+            "m16n16:x1:b8": 2,
+            "m16n16:x2:b8": 4,
+            "m16n16:x1:b8x16.b6x16_p32": 2,
+            "m16n16:x2:b8x16.b6x16_p32": 4,
+            "m16n16:x1:b8x16.b4x16_p64": 2,
+            "m16n16:x2:b8x16.b4x16_p64": 4,
+            "m8n16:x1:b8x16.b6x16_p32" : 1,
+            "m8n16:x2:b8x16.b6x16_p32" : 2,
+            "m8n16:x4:b8x16.b6x16_p32" : 4,
+            "m8n16:x1:b8x16.b4x16_p64" : 1,
+            "m8n16:x2:b8x16.b4x16_p64" : 2,
+            "m8n16:x4:b8x16.b4x16_p64" : 4,
         }.get(
             "%s:%s:%s" % (geom, frag, ptx_elt_type),
             {
@@ -289,7 +304,9 @@ def get_ldst_ops(kind):
 
 
 def get_ldmatrix_ops():
-    return make_ldmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"])
+    return (make_ldmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"])
+            + make_ldmatrix_ops(["m16n16"], ["x1", "x2"], ["b8", "b8x16.b6x16_p32", "b8x16.b4x16_p64"])
+            + make_ldmatrix_ops(["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]))
 
 
 def is_wmma_geom_supported(geom):
@@ -330,8 +347,20 @@ def is_mma_geom_supported(geom):
 def is_ldmatrix_geom_supported(geom):
     if geom in ["m8n8"]:
         return ptx_version >= 65 and gpu_arch >= 75
+    elif geom in ["m16n16"]:
+        return ptx_version >= 86 and gpu_arch >= 100 and aa
+    elif geom in ["m8n16"]:
+        return ptx_version >= 86 and gpu_arch >= 100 and aa
     assert False  # Unexpected geometry.
 
+def is_ldmatrix_trans_supported(geom, trans):
+    if geom in ["m8n8"]:
+        return True
+    elif geom in ["m16n16"]:
+        return trans == ".trans"
+    elif geom in ["m8n16"]:
+        return trans == ""
+    assert False  # Unexpected geometry.
 
 def is_type_supported(ptx_type):
     if ptx_type in ["s8", "u8", "s32"]:
@@ -417,10 +446,11 @@ def is_ldst_variant_supported(frag, layout):
     return True
 
 
-def is_ldmatrix_variant_supported(frag):
+def is_ldmatrix_variant_supported(frag, trans):
     if not (
         is_type_supported(frag.mma_type.ptx_type)
         and is_ldmatrix_geom_supported(frag.geom)
+        and is_ldmatrix_trans_supported(frag.geom, trans)
     ):
         return False
     return frag.frag in ["x1", "x2", "x4"]
@@ -653,7 +683,7 @@ def gen_ldmatrix_tests():
         ["", ".shared"],
         ["", ".trans"],
     ):
-        if not is_ldmatrix_variant_supported(frag):
+        if not is_ldmatrix_variant_supported(frag, trans):
             continue
 
         params = {
@@ -944,6 +974,19 @@ def gen_check_unsupported_ops(items):
 ; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16
 ; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16
 
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x1.trans.shared.b8
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x2.trans.shared.b8
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b4x16_p64
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b4x16_p64
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x1.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x1.b8x16.b4x16_p64
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x2.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x2.b8x16.b4x16_p64
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64
+
 ; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
 ; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
 ; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
@@ -997,13 +1040,16 @@ def gen_tests():
 def main():
     global ptx_version
     global gpu_arch
+    global aa
     parser = argparse.ArgumentParser()
     parser.add_argument("--ptx", type=int, default=60)
     parser.add_argument("--gpu-arch", type=int, default=70)
+    parser.add_argument("--aa", action='store_true')
     args = parser.parse_args()
 
     ptx_version = args.ptx
     gpu_arch = args.gpu_arch
+    aa = args.aa
 
     gen_tests()
 

@llvmbot
Copy link
Member

llvmbot commented Jan 29, 2025

@llvm/pr-subscribers-llvm-ir

Author: Pradeep Kumar (schwarzschild-radius)

Changes

This commit adds support for the following ldmatrix extensions introduced in PTX 8.6

  • Support for m16n16 with b8 type with mandatory transpose
  • Support for m8n16 with source and destination formats

The above extensions are only supported on sm_100a, sm_101a, sm_120a


Full diff: https://github.com/llvm/llvm-project/pull/124899.diff

8 Files Affected:

  • (modified) llvm/include/llvm/IR/IntrinsicsNVVM.td (+32-7)
  • (modified) llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp (+15-3)
  • (modified) llvm/lib/Target/NVPTX/NVPTXInstrInfo.td (+1)
  • (modified) llvm/lib/Target/NVPTX/NVPTXIntrinsics.td (+25-2)
  • (added) llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py (+16)
  • (added) llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py (+16)
  • (added) llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py (+16)
  • (modified) llvm/test/CodeGen/NVPTX/wmma.py (+49-3)
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 9a2f38d760e659..f3aac47e4c4033 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -62,6 +62,7 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
   string frag = Frag;
   string ptx_elt_type = PtxEltType;
   string gft = Geom#":"#Frag#":"#ptx_elt_type;
+  string gf = Geom#":"#Frag;
   string ft = frag#":"#ptx_elt_type;
   list<LLVMType> regs = !cond(
     // mma fp ops use smaller fragments than wmma fp ops
@@ -204,9 +205,19 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
     !eq(gft,"m16n8k256:d:s32") : !listsplat(llvm_i32_ty, 4),
 
     // ldmatrix b16 -> s32 @ m8n8
-    !eq(gft,"m8n8:x1:b16") : !listsplat(llvm_i32_ty, 1),
-    !eq(gft,"m8n8:x2:b16") : !listsplat(llvm_i32_ty, 2),
-    !eq(gft,"m8n8:x4:b16") : !listsplat(llvm_i32_ty, 4),
+    !eq(gf,"m8n8:x1") : !listsplat(llvm_i32_ty, 1),
+    !eq(gf,"m8n8:x2") : !listsplat(llvm_i32_ty, 2),
+    !eq(gf,"m8n8:x4") : !listsplat(llvm_i32_ty, 4),
+
+    // ldmatrix b8, b8x16.b6x16_p32, b8x16.b4x16_p64 -> s32 @ m16n16
+    !eq(gf,"m16n16:x1") : !listsplat(llvm_i32_ty, 2),
+    !eq(gf,"m16n16:x2") : !listsplat(llvm_i32_ty, 4),
+
+    // ldmatrix b8x16.b6x16_p32, b8x16.b4x16_p64 -> s32 @ m8n16
+    !eq(gf,"m8n16:x1") : !listsplat(llvm_i32_ty, 1),
+    !eq(gf,"m8n16:x2") : !listsplat(llvm_i32_ty, 2),
+    !eq(gf,"m8n16:x4") : !listsplat(llvm_i32_ty, 4),
+
   );
 }
 
@@ -411,7 +422,16 @@ class NVVM_MMA_OPS {
 
   list<WMMA_REGS> ldmatrix_b16_ops = LDMATRIX_OPS<
     ["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ret;
-  list<WMMA_REGS> all_ldmatrix_ops = ldmatrix_b16_ops;
+
+  list<WMMA_REGS> ldmatrix_geom_m16n16_ops = LDMATRIX_OPS<
+    ["m16n16"], ["x1", "x2"], ["b8", "b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;
+
+  list<WMMA_REGS> ldmatrix_geom_m8n16_ops = LDMATRIX_OPS<
+    ["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;
+
+  list<WMMA_REGS> all_ldmatrix_ops = !listconcat(ldmatrix_b16_ops,
+                                                 ldmatrix_geom_m16n16_ops,
+                                                 ldmatrix_geom_m8n16_ops);
 }
 
 def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -536,13 +556,18 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
 // if NVVM_LDMATRIX_SUPPORTED<...>.ret then
 //   def : FOO<>; // The record will only be defined for supported ops.
 //
-class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag> {
+class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
   string g = frag.geom;
   string t = frag.ptx_elt_type;
 
   bit ret = !cond(
-    // Only currently support m8n8 and b16
     !and(!eq(g, "m8n8"), !eq(t, "b16")): true,
+    !and(!eq(g, "m16n16"), !eq(t, "b8"), !eq(trans, 1)): true,
+    !and(!eq(g, "m16n16"), !eq(t, "b8x16.b6x16_p32")): true,
+    !and(!eq(g, "m16n16"), !eq(t, "b8x16.b4x16_p64")): true,
+    !and(!eq(g, "m8n16"), !eq(t, "b8"), !eq(trans, 0)): true,
+    !and(!eq(g, "m8n16"), !eq(t, "b8x16.b6x16_p32"), !eq(trans, 0)): true,
+    !and(!eq(g, "m8n16"), !eq(t, "b8x16.b4x16_p64"), !eq(trans, 0)): true,
     true: false
   );
 }
@@ -4932,7 +4957,7 @@ class NVVM_LDMATRIX<WMMA_REGS Frag, int Transposed>
 
 foreach transposed = [0, 1] in {
   foreach frag = NVVM_MMA_OPS.all_ldmatrix_ops in {
-    if NVVM_LDMATRIX_SUPPORTED<frag>.ret then {
+    if NVVM_LDMATRIX_SUPPORTED<frag, transposed>.ret then {
       def LDMATRIX_NAME<frag, transposed>.record
         : NVVM_LDMATRIX<frag, transposed>;
     }
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 773c97f7b4dc0f..4c1c5c10bfcc8b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3552,7 +3552,12 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row:
   case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride:
   case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16:
-  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16: {
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::v4i32;
     Info.ptrVal = I.getArgOperand(0);
@@ -3592,7 +3597,9 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col_stride:
   case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col:
   case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16:
-  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16: {
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::i32;
     Info.ptrVal = I.getArgOperand(0);
@@ -3688,7 +3695,12 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
   case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row:
   case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride:
   case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16:
-  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16: {
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64:
+  case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32: {
     Info.opc = ISD::INTRINSIC_W_CHAIN;
     Info.memVT = MVT::v2i32;
     Info.ptrVal = I.getArgOperand(0);
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 633a99d0fc1be3..d0a625643e2129 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -170,6 +170,7 @@ def False : Predicate<"false">;
 class hasPTX<int version>: Predicate<"Subtarget->getPTXVersion() >= " # version>;
 class hasSM<int version>: Predicate<"Subtarget->getSmVersion() >= " # version>;
 
+def hasAAFeatures : Predicate<"Subtarget->hasAAFeatures()">;
 // Explicit records for arch-accelerated SM versions
 def hasSM90a : Predicate<"Subtarget->getFullSmVersion() == 901">;
 def hasSM100a : Predicate<"Subtarget->getFullSmVersion() == 1001">;
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 56d8b734bf01df..b2cf22b255f1d0 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -7107,6 +7107,9 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
     !eq(ptx_elt_type, "tf32") : Int32Regs,
     !eq(ptx_elt_type, "s32") : Int32Regs,
     !eq(ptx_elt_type, "b16") : Int32Regs,
+    !eq(ptx_elt_type, "b8") : Int32Regs,
+    !eq(ptx_elt_type, "b8x16.b6x16_p32") : Int32Regs,
+    !eq(ptx_elt_type, "b8x16.b4x16_p64") : Int32Regs,
     !eq(ptx_elt_type, "s8") : Int32Regs,
     !eq(ptx_elt_type, "u8") : Int32Regs,
     !eq(ptx_elt_type, "s4") : Int32Regs,
@@ -7194,7 +7197,27 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
 
     !and(!eq(op,"ldmatrix"),
          !eq(ptx_elt_type,"b16"),
-         !eq(geom, "m8n8")) : [hasSM<75>, hasPTX<65>]);
+         !eq(geom, "m8n8")) : [hasSM<75>, hasPTX<65>],
+
+    !and(!eq(op,"ldmatrix"),
+         !eq(ptx_elt_type,"b8"),
+         !eq(geom, "m16n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>],
+
+    !and(!eq(op,"ldmatrix"),
+         !eq(ptx_elt_type,"b8x16.b6x16_p32"),
+         !eq(geom, "m16n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>],
+
+    !and(!eq(op,"ldmatrix"),
+         !eq(ptx_elt_type,"b8x16.b4x16_p64"),
+         !eq(geom, "m16n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>],
+
+    !and(!eq(op,"ldmatrix"),
+         !eq(ptx_elt_type,"b8x16.b6x16_p32"),
+         !eq(geom, "m8n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>],
+
+    !and(!eq(op,"ldmatrix"),
+         !eq(ptx_elt_type,"b8x16.b4x16_p64"),
+         !eq(geom, "m8n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>]);
 
   // template DAGs for instruction inputs/output.
   dag Outs = !dag(outs, ptx_regs, reg_names);
@@ -7478,7 +7501,7 @@ defset list<WMMA_INSTR> LDMATRIXs  = {
     foreach space = [".shared", ""] in {
       foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in {
         foreach frag = NVVM_MMA_OPS.all_ldmatrix_ops in
-          if NVVM_LDMATRIX_SUPPORTED<frag>.ret then
+          if NVVM_LDMATRIX_SUPPORTED<frag, transposed>.ret then
             def : LDMATRIX<WMMA_REGINFO<frag, "ldmatrix">, transposed, space,
                             addr>;
       } // addr
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
new file mode 100644
index 00000000000000..6ad0a2a5865c41
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
@@ -0,0 +1,16 @@
+# Check all variants of instructions supported by PTX86 on SM100a
+# RUN: %python %s --ptx=86 --gpu-arch=100 --aa > %t-ptx86-sm_100a.ll
+# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \
+# RUN:           | FileCheck %t-ptx86-sm_100a.ll
+# RUN: %if ptxas-12.7 %{                                                  \
+# RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \
+# RUN:           | %ptxas-verify -arch=sm_100a                              \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
new file mode 100644
index 00000000000000..7d9953484da7d0
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
@@ -0,0 +1,16 @@
+# Check all variants of instructions supported by PTX86 on SM101a
+# RUN: %python %s --ptx=86 --gpu-arch=101 --aa > %t-ptx86-sm_101a.ll
+# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \
+# RUN:           | FileCheck %t-ptx86-sm_101a.ll
+# RUN: %if ptxas-12.7 %{                                                  \
+# RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \
+# RUN:           | %ptxas-verify -arch=sm_101a                              \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
new file mode 100644
index 00000000000000..7bddf0b6fbb785
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
@@ -0,0 +1,16 @@
+# Check all variants of instructions supported by PTX86 on SM120a
+# RUN: %python %s --ptx=86 --gpu-arch=120 --aa > %t-ptx86-sm_120a.ll
+# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
+# RUN:           --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \
+# RUN:           | FileCheck %t-ptx86-sm_120a.ll
+# RUN: %if ptxas-12.7 %{                                                  \
+# RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \
+# RUN:           | %ptxas-verify -arch=sm_120a                              \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index e1e46f0b8cab34..c1826fc561834e 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -19,6 +19,9 @@ def __init__(self, ptx_type):
             "f64": "double",
             "s32": "i32",
             "b16": "i32",
+            "b8": "i32",
+            "b8x16.b6x16_p32" : "i32",
+            "b8x16.b4x16_p64" : "i32",
             "s8": "i32",
             "u8": "i32",
             "s4": "i32",
@@ -161,6 +164,18 @@ def __init__(self, geom, frag, ptx_elt_type):
             "m8n8:x1:b16": 1,
             "m8n8:x2:b16": 2,
             "m8n8:x4:b16": 4,
+            "m16n16:x1:b8": 2,
+            "m16n16:x2:b8": 4,
+            "m16n16:x1:b8x16.b6x16_p32": 2,
+            "m16n16:x2:b8x16.b6x16_p32": 4,
+            "m16n16:x1:b8x16.b4x16_p64": 2,
+            "m16n16:x2:b8x16.b4x16_p64": 4,
+            "m8n16:x1:b8x16.b6x16_p32" : 1,
+            "m8n16:x2:b8x16.b6x16_p32" : 2,
+            "m8n16:x4:b8x16.b6x16_p32" : 4,
+            "m8n16:x1:b8x16.b4x16_p64" : 1,
+            "m8n16:x2:b8x16.b4x16_p64" : 2,
+            "m8n16:x4:b8x16.b4x16_p64" : 4,
         }.get(
             "%s:%s:%s" % (geom, frag, ptx_elt_type),
             {
@@ -289,7 +304,9 @@ def get_ldst_ops(kind):
 
 
 def get_ldmatrix_ops():
-    return make_ldmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"])
+    return (make_ldmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"])
+            + make_ldmatrix_ops(["m16n16"], ["x1", "x2"], ["b8", "b8x16.b6x16_p32", "b8x16.b4x16_p64"])
+            + make_ldmatrix_ops(["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]))
 
 
 def is_wmma_geom_supported(geom):
@@ -330,8 +347,20 @@ def is_mma_geom_supported(geom):
 def is_ldmatrix_geom_supported(geom):
     if geom in ["m8n8"]:
         return ptx_version >= 65 and gpu_arch >= 75
+    elif geom in ["m16n16"]:
+        return ptx_version >= 86 and gpu_arch >= 100 and aa
+    elif geom in ["m8n16"]:
+        return ptx_version >= 86 and gpu_arch >= 100 and aa
     assert False  # Unexpected geometry.
 
+def is_ldmatrix_trans_supported(geom, trans):
+    if geom in ["m8n8"]:
+        return True
+    elif geom in ["m16n16"]:
+        return trans == ".trans"
+    elif geom in ["m8n16"]:
+        return trans == ""
+    assert False  # Unexpected geometry.
 
 def is_type_supported(ptx_type):
     if ptx_type in ["s8", "u8", "s32"]:
@@ -417,10 +446,11 @@ def is_ldst_variant_supported(frag, layout):
     return True
 
 
-def is_ldmatrix_variant_supported(frag):
+def is_ldmatrix_variant_supported(frag, trans):
     if not (
         is_type_supported(frag.mma_type.ptx_type)
         and is_ldmatrix_geom_supported(frag.geom)
+        and is_ldmatrix_trans_supported(frag.geom, trans)
     ):
         return False
     return frag.frag in ["x1", "x2", "x4"]
@@ -653,7 +683,7 @@ def gen_ldmatrix_tests():
         ["", ".shared"],
         ["", ".trans"],
     ):
-        if not is_ldmatrix_variant_supported(frag):
+        if not is_ldmatrix_variant_supported(frag, trans):
             continue
 
         params = {
@@ -944,6 +974,19 @@ def gen_check_unsupported_ops(items):
 ; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16
 ; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16
 
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x1.trans.shared.b8
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x2.trans.shared.b8
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b4x16_p64
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b4x16_p64
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x1.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x1.b8x16.b4x16_p64
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x2.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x2.b8x16.b4x16_p64
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64
+
 ; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
 ; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
 ; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
@@ -997,13 +1040,16 @@ def gen_tests():
 def main():
     global ptx_version
     global gpu_arch
+    global aa
     parser = argparse.ArgumentParser()
     parser.add_argument("--ptx", type=int, default=60)
     parser.add_argument("--gpu-arch", type=int, default=70)
+    parser.add_argument("--aa", action='store_true')
     args = parser.parse_args()
 
     ptx_version = args.ptx
     gpu_arch = args.gpu_arch
+    aa = args.aa
 
     gen_tests()
 

@github-actions
Copy link

github-actions bot commented Jan 29, 2025

✅ With the latest revision this PR passed the Python code formatter.

@schwarzschild-radius schwarzschild-radius force-pushed the ldmatrix_sm100a_nvptx branch 2 times, most recently from cfabf9c to 52f3619 Compare January 29, 2025 14:53
Copy link
Member

@Artem-B Artem-B left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

Sorry about the delay. I apparently archived PR notification out of inbox before reading it.
In case you see no response from me in reasonable time in the future, feel free to ping me to bring it to my attention. Github should have some sort of opt-in auto-nag feature for assigned reviewers.

@schwarzschild-radius
Copy link
Contributor Author

LGTM.

Sorry about the delay. I apparently archived PR notification out of inbox before reading it. In case you see no response from me in reasonable time in the future, feel free to ping me to bring it to my attention. Github should have some sort of opt-in auto-nag feature for assigned reviewers.

No Problem, Artem. Sorry for delay in responding. For some reason, GitHub stopped sending PR update notifications on the mobile app. I'll be more proactive with reviews going forward.

@schwarzschild-radius
Copy link
Contributor Author

@Artem-B I have addressed all your comments. Please let me know if there's anything else needed or if I can proceed with merging the PR. Thanks!

Copy link
Member

@Artem-B Artem-B left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with a small nit.

This commit adds support for the following ldmatrix extensions introduced in PTX 8.6
- Support for m16n16 with b8 type with mandatory transpose
- Support for m16n16 and m8n16 with source and destination formats

The above extensions are only supported on sm_100a, sm_101a, sm_120a

Please refer to the PTX ISA for more information:
https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-ldmatrix
@schwarzschild-radius schwarzschild-radius merged commit 52e7ca9 into llvm:main Mar 17, 2025
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants