-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[LLVM][NVPTX] Add support for ldmatrix extensions introduced in PTX 8.6 #124899
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[LLVM][NVPTX] Add support for ldmatrix extensions introduced in PTX 8.6 #124899
Conversation
|
@llvm/pr-subscribers-backend-nvptx Author: Pradeep Kumar (schwarzschild-radius) Changes: This commit adds support for the following ldmatrix extensions introduced in PTX 8.6
The above extensions are only supported on sm_100a, sm_101a, sm_120a. Full diff: https://github.com/llvm/llvm-project/pull/124899.diff 8 Files Affected:
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 9a2f38d760e659..f3aac47e4c4033 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -62,6 +62,7 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
string frag = Frag;
string ptx_elt_type = PtxEltType;
string gft = Geom#":"#Frag#":"#ptx_elt_type;
+ string gf = Geom#":"#Frag;
string ft = frag#":"#ptx_elt_type;
list<LLVMType> regs = !cond(
// mma fp ops use smaller fragments than wmma fp ops
@@ -204,9 +205,19 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
!eq(gft,"m16n8k256:d:s32") : !listsplat(llvm_i32_ty, 4),
// ldmatrix b16 -> s32 @ m8n8
- !eq(gft,"m8n8:x1:b16") : !listsplat(llvm_i32_ty, 1),
- !eq(gft,"m8n8:x2:b16") : !listsplat(llvm_i32_ty, 2),
- !eq(gft,"m8n8:x4:b16") : !listsplat(llvm_i32_ty, 4),
+ !eq(gf,"m8n8:x1") : !listsplat(llvm_i32_ty, 1),
+ !eq(gf,"m8n8:x2") : !listsplat(llvm_i32_ty, 2),
+ !eq(gf,"m8n8:x4") : !listsplat(llvm_i32_ty, 4),
+
+ // ldmatrix b8, b8x16.b6x16_p32, b8x16.b4x16_p64 -> s32 @ m16n16
+ !eq(gf,"m16n16:x1") : !listsplat(llvm_i32_ty, 2),
+ !eq(gf,"m16n16:x2") : !listsplat(llvm_i32_ty, 4),
+
+ // ldmatrix b8x16.b6x16_p32, b8x16.b4x16_p64 -> s32 @ m8n16
+ !eq(gf,"m8n16:x1") : !listsplat(llvm_i32_ty, 1),
+ !eq(gf,"m8n16:x2") : !listsplat(llvm_i32_ty, 2),
+ !eq(gf,"m8n16:x4") : !listsplat(llvm_i32_ty, 4),
+
);
}
@@ -411,7 +422,16 @@ class NVVM_MMA_OPS {
list<WMMA_REGS> ldmatrix_b16_ops = LDMATRIX_OPS<
["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ret;
- list<WMMA_REGS> all_ldmatrix_ops = ldmatrix_b16_ops;
+
+ list<WMMA_REGS> ldmatrix_geom_m16n16_ops = LDMATRIX_OPS<
+ ["m16n16"], ["x1", "x2"], ["b8", "b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;
+
+ list<WMMA_REGS> ldmatrix_geom_m8n16_ops = LDMATRIX_OPS<
+ ["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;
+
+ list<WMMA_REGS> all_ldmatrix_ops = !listconcat(ldmatrix_b16_ops,
+ ldmatrix_geom_m16n16_ops,
+ ldmatrix_geom_m8n16_ops);
}
def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -536,13 +556,18 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
// if NVVM_LDMATRIX_SUPPORTED<...>.ret then
// def : FOO<>; // The record will only be defined for supported ops.
//
-class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag> {
+class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
string g = frag.geom;
string t = frag.ptx_elt_type;
bit ret = !cond(
- // Only currently support m8n8 and b16
!and(!eq(g, "m8n8"), !eq(t, "b16")): true,
+ !and(!eq(g, "m16n16"), !eq(t, "b8"), !eq(trans, 1)): true,
+ !and(!eq(g, "m16n16"), !eq(t, "b8x16.b6x16_p32")): true,
+ !and(!eq(g, "m16n16"), !eq(t, "b8x16.b4x16_p64")): true,
+ !and(!eq(g, "m8n16"), !eq(t, "b8"), !eq(trans, 0)): true,
+ !and(!eq(g, "m8n16"), !eq(t, "b8x16.b6x16_p32"), !eq(trans, 0)): true,
+ !and(!eq(g, "m8n16"), !eq(t, "b8x16.b4x16_p64"), !eq(trans, 0)): true,
true: false
);
}
@@ -4932,7 +4957,7 @@ class NVVM_LDMATRIX<WMMA_REGS Frag, int Transposed>
foreach transposed = [0, 1] in {
foreach frag = NVVM_MMA_OPS.all_ldmatrix_ops in {
- if NVVM_LDMATRIX_SUPPORTED<frag>.ret then {
+ if NVVM_LDMATRIX_SUPPORTED<frag, transposed>.ret then {
def LDMATRIX_NAME<frag, transposed>.record
: NVVM_LDMATRIX<frag, transposed>;
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 773c97f7b4dc0f..4c1c5c10bfcc8b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3552,7 +3552,12 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row:
case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride:
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16:
- case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16: {
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v4i32;
Info.ptrVal = I.getArgOperand(0);
@@ -3592,7 +3597,9 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col_stride:
case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col:
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16:
- case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16: {
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::i32;
Info.ptrVal = I.getArgOperand(0);
@@ -3688,7 +3695,12 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row:
case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride:
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16:
- case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16: {
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v2i32;
Info.ptrVal = I.getArgOperand(0);
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 633a99d0fc1be3..d0a625643e2129 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -170,6 +170,7 @@ def False : Predicate<"false">;
class hasPTX<int version>: Predicate<"Subtarget->getPTXVersion() >= " # version>;
class hasSM<int version>: Predicate<"Subtarget->getSmVersion() >= " # version>;
+def hasAAFeatures : Predicate<"Subtarget->hasAAFeatures()">;
// Explicit records for arch-accelerated SM versions
def hasSM90a : Predicate<"Subtarget->getFullSmVersion() == 901">;
def hasSM100a : Predicate<"Subtarget->getFullSmVersion() == 1001">;
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 56d8b734bf01df..b2cf22b255f1d0 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -7107,6 +7107,9 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
!eq(ptx_elt_type, "tf32") : Int32Regs,
!eq(ptx_elt_type, "s32") : Int32Regs,
!eq(ptx_elt_type, "b16") : Int32Regs,
+ !eq(ptx_elt_type, "b8") : Int32Regs,
+ !eq(ptx_elt_type, "b8x16.b6x16_p32") : Int32Regs,
+ !eq(ptx_elt_type, "b8x16.b4x16_p64") : Int32Regs,
!eq(ptx_elt_type, "s8") : Int32Regs,
!eq(ptx_elt_type, "u8") : Int32Regs,
!eq(ptx_elt_type, "s4") : Int32Regs,
@@ -7194,7 +7197,27 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
!and(!eq(op,"ldmatrix"),
!eq(ptx_elt_type,"b16"),
- !eq(geom, "m8n8")) : [hasSM<75>, hasPTX<65>]);
+ !eq(geom, "m8n8")) : [hasSM<75>, hasPTX<65>],
+
+ !and(!eq(op,"ldmatrix"),
+ !eq(ptx_elt_type,"b8"),
+ !eq(geom, "m16n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>],
+
+ !and(!eq(op,"ldmatrix"),
+ !eq(ptx_elt_type,"b8x16.b6x16_p32"),
+ !eq(geom, "m16n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>],
+
+ !and(!eq(op,"ldmatrix"),
+ !eq(ptx_elt_type,"b8x16.b4x16_p64"),
+ !eq(geom, "m16n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>],
+
+ !and(!eq(op,"ldmatrix"),
+ !eq(ptx_elt_type,"b8x16.b6x16_p32"),
+ !eq(geom, "m8n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>],
+
+ !and(!eq(op,"ldmatrix"),
+ !eq(ptx_elt_type,"b8x16.b4x16_p64"),
+ !eq(geom, "m8n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>]);
// template DAGs for instruction inputs/output.
dag Outs = !dag(outs, ptx_regs, reg_names);
@@ -7478,7 +7501,7 @@ defset list<WMMA_INSTR> LDMATRIXs = {
foreach space = [".shared", ""] in {
foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in {
foreach frag = NVVM_MMA_OPS.all_ldmatrix_ops in
- if NVVM_LDMATRIX_SUPPORTED<frag>.ret then
+ if NVVM_LDMATRIX_SUPPORTED<frag, transposed>.ret then
def : LDMATRIX<WMMA_REGINFO<frag, "ldmatrix">, transposed, space,
addr>;
} // addr
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
new file mode 100644
index 00000000000000..6ad0a2a5865c41
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
@@ -0,0 +1,16 @@
+# Check all variants of instructions supported by PTX86 on SM100a
+# RUN: %python %s --ptx=86 --gpu-arch=100 --aa > %t-ptx86-sm_100a.ll
+# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
+# RUN: --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
+# RUN: --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \
+# RUN: | FileCheck %t-ptx86-sm_100a.ll
+# RUN: %if ptxas-12.7 %{ \
+# RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \
+# RUN: | %ptxas-verify -arch=sm_100a \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
new file mode 100644
index 00000000000000..7d9953484da7d0
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
@@ -0,0 +1,16 @@
+# Check all variants of instructions supported by PTX86 on SM101a
+# RUN: %python %s --ptx=86 --gpu-arch=101 --aa > %t-ptx86-sm_101a.ll
+# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
+# RUN: --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
+# RUN: --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \
+# RUN: | FileCheck %t-ptx86-sm_101a.ll
+# RUN: %if ptxas-12.7 %{ \
+# RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \
+# RUN: | %ptxas-verify -arch=sm_101a \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
new file mode 100644
index 00000000000000..7bddf0b6fbb785
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
@@ -0,0 +1,16 @@
+# Check all variants of instructions supported by PTX86 on SM120a
+# RUN: %python %s --ptx=86 --gpu-arch=120 --aa > %t-ptx86-sm_120a.ll
+# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
+# RUN: --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
+# RUN: --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \
+# RUN: | FileCheck %t-ptx86-sm_120a.ll
+# RUN: %if ptxas-12.7 %{ \
+# RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \
+# RUN: | %ptxas-verify -arch=sm_120a \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index e1e46f0b8cab34..c1826fc561834e 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -19,6 +19,9 @@ def __init__(self, ptx_type):
"f64": "double",
"s32": "i32",
"b16": "i32",
+ "b8": "i32",
+ "b8x16.b6x16_p32" : "i32",
+ "b8x16.b4x16_p64" : "i32",
"s8": "i32",
"u8": "i32",
"s4": "i32",
@@ -161,6 +164,18 @@ def __init__(self, geom, frag, ptx_elt_type):
"m8n8:x1:b16": 1,
"m8n8:x2:b16": 2,
"m8n8:x4:b16": 4,
+ "m16n16:x1:b8": 2,
+ "m16n16:x2:b8": 4,
+ "m16n16:x1:b8x16.b6x16_p32": 2,
+ "m16n16:x2:b8x16.b6x16_p32": 4,
+ "m16n16:x1:b8x16.b4x16_p64": 2,
+ "m16n16:x2:b8x16.b4x16_p64": 4,
+ "m8n16:x1:b8x16.b6x16_p32" : 1,
+ "m8n16:x2:b8x16.b6x16_p32" : 2,
+ "m8n16:x4:b8x16.b6x16_p32" : 4,
+ "m8n16:x1:b8x16.b4x16_p64" : 1,
+ "m8n16:x2:b8x16.b4x16_p64" : 2,
+ "m8n16:x4:b8x16.b4x16_p64" : 4,
}.get(
"%s:%s:%s" % (geom, frag, ptx_elt_type),
{
@@ -289,7 +304,9 @@ def get_ldst_ops(kind):
def get_ldmatrix_ops():
- return make_ldmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"])
+ return (make_ldmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"])
+ + make_ldmatrix_ops(["m16n16"], ["x1", "x2"], ["b8", "b8x16.b6x16_p32", "b8x16.b4x16_p64"])
+ + make_ldmatrix_ops(["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]))
def is_wmma_geom_supported(geom):
@@ -330,8 +347,20 @@ def is_mma_geom_supported(geom):
def is_ldmatrix_geom_supported(geom):
if geom in ["m8n8"]:
return ptx_version >= 65 and gpu_arch >= 75
+ elif geom in ["m16n16"]:
+ return ptx_version >= 86 and gpu_arch >= 100 and aa
+ elif geom in ["m8n16"]:
+ return ptx_version >= 86 and gpu_arch >= 100 and aa
assert False # Unexpected geometry.
+def is_ldmatrix_trans_supported(geom, trans):
+ if geom in ["m8n8"]:
+ return True
+ elif geom in ["m16n16"]:
+ return trans == ".trans"
+ elif geom in ["m8n16"]:
+ return trans == ""
+ assert False # Unexpected geometry.
def is_type_supported(ptx_type):
if ptx_type in ["s8", "u8", "s32"]:
@@ -417,10 +446,11 @@ def is_ldst_variant_supported(frag, layout):
return True
-def is_ldmatrix_variant_supported(frag):
+def is_ldmatrix_variant_supported(frag, trans):
if not (
is_type_supported(frag.mma_type.ptx_type)
and is_ldmatrix_geom_supported(frag.geom)
+ and is_ldmatrix_trans_supported(frag.geom, trans)
):
return False
return frag.frag in ["x1", "x2", "x4"]
@@ -653,7 +683,7 @@ def gen_ldmatrix_tests():
["", ".shared"],
["", ".trans"],
):
- if not is_ldmatrix_variant_supported(frag):
+ if not is_ldmatrix_variant_supported(frag, trans):
continue
params = {
@@ -944,6 +974,19 @@ def gen_check_unsupported_ops(items):
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x1.trans.shared.b8
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x2.trans.shared.b8
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b4x16_p64
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b4x16_p64
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x1.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x1.b8x16.b4x16_p64
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x2.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x2.b8x16.b4x16_p64
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64
+
; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
@@ -997,13 +1040,16 @@ def gen_tests():
def main():
global ptx_version
global gpu_arch
+ global aa
parser = argparse.ArgumentParser()
parser.add_argument("--ptx", type=int, default=60)
parser.add_argument("--gpu-arch", type=int, default=70)
+ parser.add_argument("--aa", action='store_true')
args = parser.parse_args()
ptx_version = args.ptx
gpu_arch = args.gpu_arch
+ aa = args.aa
gen_tests()
|
|
@llvm/pr-subscribers-llvm-ir Author: Pradeep Kumar (schwarzschild-radius) Changes: This commit adds support for the following ldmatrix extensions introduced in PTX 8.6
The above extensions are only supported on sm_100a, sm_101a, sm_120a. Full diff: https://github.com/llvm/llvm-project/pull/124899.diff 8 Files Affected:
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 9a2f38d760e659..f3aac47e4c4033 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -62,6 +62,7 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
string frag = Frag;
string ptx_elt_type = PtxEltType;
string gft = Geom#":"#Frag#":"#ptx_elt_type;
+ string gf = Geom#":"#Frag;
string ft = frag#":"#ptx_elt_type;
list<LLVMType> regs = !cond(
// mma fp ops use smaller fragments than wmma fp ops
@@ -204,9 +205,19 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
!eq(gft,"m16n8k256:d:s32") : !listsplat(llvm_i32_ty, 4),
// ldmatrix b16 -> s32 @ m8n8
- !eq(gft,"m8n8:x1:b16") : !listsplat(llvm_i32_ty, 1),
- !eq(gft,"m8n8:x2:b16") : !listsplat(llvm_i32_ty, 2),
- !eq(gft,"m8n8:x4:b16") : !listsplat(llvm_i32_ty, 4),
+ !eq(gf,"m8n8:x1") : !listsplat(llvm_i32_ty, 1),
+ !eq(gf,"m8n8:x2") : !listsplat(llvm_i32_ty, 2),
+ !eq(gf,"m8n8:x4") : !listsplat(llvm_i32_ty, 4),
+
+ // ldmatrix b8, b8x16.b6x16_p32, b8x16.b4x16_p64 -> s32 @ m16n16
+ !eq(gf,"m16n16:x1") : !listsplat(llvm_i32_ty, 2),
+ !eq(gf,"m16n16:x2") : !listsplat(llvm_i32_ty, 4),
+
+ // ldmatrix b8x16.b6x16_p32, b8x16.b4x16_p64 -> s32 @ m8n16
+ !eq(gf,"m8n16:x1") : !listsplat(llvm_i32_ty, 1),
+ !eq(gf,"m8n16:x2") : !listsplat(llvm_i32_ty, 2),
+ !eq(gf,"m8n16:x4") : !listsplat(llvm_i32_ty, 4),
+
);
}
@@ -411,7 +422,16 @@ class NVVM_MMA_OPS {
list<WMMA_REGS> ldmatrix_b16_ops = LDMATRIX_OPS<
["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ret;
- list<WMMA_REGS> all_ldmatrix_ops = ldmatrix_b16_ops;
+
+ list<WMMA_REGS> ldmatrix_geom_m16n16_ops = LDMATRIX_OPS<
+ ["m16n16"], ["x1", "x2"], ["b8", "b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;
+
+ list<WMMA_REGS> ldmatrix_geom_m8n16_ops = LDMATRIX_OPS<
+ ["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;
+
+ list<WMMA_REGS> all_ldmatrix_ops = !listconcat(ldmatrix_b16_ops,
+ ldmatrix_geom_m16n16_ops,
+ ldmatrix_geom_m8n16_ops);
}
def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -536,13 +556,18 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
// if NVVM_LDMATRIX_SUPPORTED<...>.ret then
// def : FOO<>; // The record will only be defined for supported ops.
//
-class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag> {
+class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
string g = frag.geom;
string t = frag.ptx_elt_type;
bit ret = !cond(
- // Only currently support m8n8 and b16
!and(!eq(g, "m8n8"), !eq(t, "b16")): true,
+ !and(!eq(g, "m16n16"), !eq(t, "b8"), !eq(trans, 1)): true,
+ !and(!eq(g, "m16n16"), !eq(t, "b8x16.b6x16_p32")): true,
+ !and(!eq(g, "m16n16"), !eq(t, "b8x16.b4x16_p64")): true,
+ !and(!eq(g, "m8n16"), !eq(t, "b8"), !eq(trans, 0)): true,
+ !and(!eq(g, "m8n16"), !eq(t, "b8x16.b6x16_p32"), !eq(trans, 0)): true,
+ !and(!eq(g, "m8n16"), !eq(t, "b8x16.b4x16_p64"), !eq(trans, 0)): true,
true: false
);
}
@@ -4932,7 +4957,7 @@ class NVVM_LDMATRIX<WMMA_REGS Frag, int Transposed>
foreach transposed = [0, 1] in {
foreach frag = NVVM_MMA_OPS.all_ldmatrix_ops in {
- if NVVM_LDMATRIX_SUPPORTED<frag>.ret then {
+ if NVVM_LDMATRIX_SUPPORTED<frag, transposed>.ret then {
def LDMATRIX_NAME<frag, transposed>.record
: NVVM_LDMATRIX<frag, transposed>;
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index 773c97f7b4dc0f..4c1c5c10bfcc8b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -3552,7 +3552,12 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row:
case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride:
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16:
- case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16: {
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v4i32;
Info.ptrVal = I.getArgOperand(0);
@@ -3592,7 +3597,9 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col_stride:
case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col:
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16:
- case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16: {
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::i32;
Info.ptrVal = I.getArgOperand(0);
@@ -3688,7 +3695,12 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row:
case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride:
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16:
- case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16: {
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64:
+ case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32: {
Info.opc = ISD::INTRINSIC_W_CHAIN;
Info.memVT = MVT::v2i32;
Info.ptrVal = I.getArgOperand(0);
diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
index 633a99d0fc1be3..d0a625643e2129 100644
--- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -170,6 +170,7 @@ def False : Predicate<"false">;
class hasPTX<int version>: Predicate<"Subtarget->getPTXVersion() >= " # version>;
class hasSM<int version>: Predicate<"Subtarget->getSmVersion() >= " # version>;
+def hasAAFeatures : Predicate<"Subtarget->hasAAFeatures()">;
// Explicit records for arch-accelerated SM versions
def hasSM90a : Predicate<"Subtarget->getFullSmVersion() == 901">;
def hasSM100a : Predicate<"Subtarget->getFullSmVersion() == 1001">;
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 56d8b734bf01df..b2cf22b255f1d0 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -7107,6 +7107,9 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
!eq(ptx_elt_type, "tf32") : Int32Regs,
!eq(ptx_elt_type, "s32") : Int32Regs,
!eq(ptx_elt_type, "b16") : Int32Regs,
+ !eq(ptx_elt_type, "b8") : Int32Regs,
+ !eq(ptx_elt_type, "b8x16.b6x16_p32") : Int32Regs,
+ !eq(ptx_elt_type, "b8x16.b4x16_p64") : Int32Regs,
!eq(ptx_elt_type, "s8") : Int32Regs,
!eq(ptx_elt_type, "u8") : Int32Regs,
!eq(ptx_elt_type, "s4") : Int32Regs,
@@ -7194,7 +7197,27 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
!and(!eq(op,"ldmatrix"),
!eq(ptx_elt_type,"b16"),
- !eq(geom, "m8n8")) : [hasSM<75>, hasPTX<65>]);
+ !eq(geom, "m8n8")) : [hasSM<75>, hasPTX<65>],
+
+ !and(!eq(op,"ldmatrix"),
+ !eq(ptx_elt_type,"b8"),
+ !eq(geom, "m16n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>],
+
+ !and(!eq(op,"ldmatrix"),
+ !eq(ptx_elt_type,"b8x16.b6x16_p32"),
+ !eq(geom, "m16n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>],
+
+ !and(!eq(op,"ldmatrix"),
+ !eq(ptx_elt_type,"b8x16.b4x16_p64"),
+ !eq(geom, "m16n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>],
+
+ !and(!eq(op,"ldmatrix"),
+ !eq(ptx_elt_type,"b8x16.b6x16_p32"),
+ !eq(geom, "m8n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>],
+
+ !and(!eq(op,"ldmatrix"),
+ !eq(ptx_elt_type,"b8x16.b4x16_p64"),
+ !eq(geom, "m8n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>]);
// template DAGs for instruction inputs/output.
dag Outs = !dag(outs, ptx_regs, reg_names);
@@ -7478,7 +7501,7 @@ defset list<WMMA_INSTR> LDMATRIXs = {
foreach space = [".shared", ""] in {
foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in {
foreach frag = NVVM_MMA_OPS.all_ldmatrix_ops in
- if NVVM_LDMATRIX_SUPPORTED<frag>.ret then
+ if NVVM_LDMATRIX_SUPPORTED<frag, transposed>.ret then
def : LDMATRIX<WMMA_REGINFO<frag, "ldmatrix">, transposed, space,
addr>;
} // addr
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
new file mode 100644
index 00000000000000..6ad0a2a5865c41
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm100a.py
@@ -0,0 +1,16 @@
+# Check all variants of instructions supported by PTX86 on SM100a
+# RUN: %python %s --ptx=86 --gpu-arch=100 --aa > %t-ptx86-sm_100a.ll
+# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
+# RUN: --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
+# RUN: --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \
+# RUN: | FileCheck %t-ptx86-sm_100a.ll
+# RUN: %if ptxas-12.7 %{ \
+# RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \
+# RUN: | %ptxas-verify -arch=sm_100a \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
new file mode 100644
index 00000000000000..7d9953484da7d0
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm101a.py
@@ -0,0 +1,16 @@
+# Check all variants of instructions supported by PTX86 on SM101a
+# RUN: %python %s --ptx=86 --gpu-arch=101 --aa > %t-ptx86-sm_101a.ll
+# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
+# RUN: --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
+# RUN: --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \
+# RUN: | FileCheck %t-ptx86-sm_101a.ll
+# RUN: %if ptxas-12.7 %{ \
+# RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \
+# RUN: | %ptxas-verify -arch=sm_101a \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
new file mode 100644
index 00000000000000..7bddf0b6fbb785
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/wmma-ptx86-sm120a.py
@@ -0,0 +1,16 @@
+# Check all variants of instructions supported by PTX86 on SM120a
+# RUN: %python %s --ptx=86 --gpu-arch=120 --aa > %t-ptx86-sm_120a.ll
+# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
+# RUN: --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
+# RUN: --check-prefixes=PTX86LDMATRIX-DAG
+# RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \
+# RUN: | FileCheck %t-ptx86-sm_120a.ll
+# RUN: %if ptxas-12.7 %{ \
+# RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \
+# RUN: | %ptxas-verify -arch=sm_120a \
+# RUN: %}
+
+import wmma
+
+wmma.main()
diff --git a/llvm/test/CodeGen/NVPTX/wmma.py b/llvm/test/CodeGen/NVPTX/wmma.py
index e1e46f0b8cab34..c1826fc561834e 100644
--- a/llvm/test/CodeGen/NVPTX/wmma.py
+++ b/llvm/test/CodeGen/NVPTX/wmma.py
@@ -19,6 +19,9 @@ def __init__(self, ptx_type):
"f64": "double",
"s32": "i32",
"b16": "i32",
+ "b8": "i32",
+ "b8x16.b6x16_p32" : "i32",
+ "b8x16.b4x16_p64" : "i32",
"s8": "i32",
"u8": "i32",
"s4": "i32",
@@ -161,6 +164,18 @@ def __init__(self, geom, frag, ptx_elt_type):
"m8n8:x1:b16": 1,
"m8n8:x2:b16": 2,
"m8n8:x4:b16": 4,
+ "m16n16:x1:b8": 2,
+ "m16n16:x2:b8": 4,
+ "m16n16:x1:b8x16.b6x16_p32": 2,
+ "m16n16:x2:b8x16.b6x16_p32": 4,
+ "m16n16:x1:b8x16.b4x16_p64": 2,
+ "m16n16:x2:b8x16.b4x16_p64": 4,
+ "m8n16:x1:b8x16.b6x16_p32" : 1,
+ "m8n16:x2:b8x16.b6x16_p32" : 2,
+ "m8n16:x4:b8x16.b6x16_p32" : 4,
+ "m8n16:x1:b8x16.b4x16_p64" : 1,
+ "m8n16:x2:b8x16.b4x16_p64" : 2,
+ "m8n16:x4:b8x16.b4x16_p64" : 4,
}.get(
"%s:%s:%s" % (geom, frag, ptx_elt_type),
{
@@ -289,7 +304,9 @@ def get_ldst_ops(kind):
def get_ldmatrix_ops():
- return make_ldmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"])
+ return (make_ldmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"])
+ + make_ldmatrix_ops(["m16n16"], ["x1", "x2"], ["b8", "b8x16.b6x16_p32", "b8x16.b4x16_p64"])
+ + make_ldmatrix_ops(["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]))
def is_wmma_geom_supported(geom):
@@ -330,8 +347,20 @@ def is_mma_geom_supported(geom):
def is_ldmatrix_geom_supported(geom):
if geom in ["m8n8"]:
return ptx_version >= 65 and gpu_arch >= 75
+ elif geom in ["m16n16"]:
+ return ptx_version >= 86 and gpu_arch >= 100 and aa
+ elif geom in ["m8n16"]:
+ return ptx_version >= 86 and gpu_arch >= 100 and aa
assert False # Unexpected geometry.
+def is_ldmatrix_trans_supported(geom, trans):
+ if geom in ["m8n8"]:
+ return True
+ elif geom in ["m16n16"]:
+ return trans == ".trans"
+ elif geom in ["m8n16"]:
+ return trans == ""
+ assert False # Unexpected geometry.
def is_type_supported(ptx_type):
if ptx_type in ["s8", "u8", "s32"]:
@@ -417,10 +446,11 @@ def is_ldst_variant_supported(frag, layout):
return True
-def is_ldmatrix_variant_supported(frag):
+def is_ldmatrix_variant_supported(frag, trans):
if not (
is_type_supported(frag.mma_type.ptx_type)
and is_ldmatrix_geom_supported(frag.geom)
+ and is_ldmatrix_trans_supported(frag.geom, trans)
):
return False
return frag.frag in ["x1", "x2", "x4"]
@@ -653,7 +683,7 @@ def gen_ldmatrix_tests():
["", ".shared"],
["", ".trans"],
):
- if not is_ldmatrix_variant_supported(frag):
+ if not is_ldmatrix_variant_supported(frag, trans):
continue
params = {
@@ -944,6 +974,19 @@ def gen_check_unsupported_ops(items):
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x1.trans.shared.b8
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x2.trans.shared.b8
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b4x16_p64
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b4x16_p64
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x1.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x1.b8x16.b4x16_p64
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x2.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x2.b8x16.b4x16_p64
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32
+; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64
+
; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
@@ -997,13 +1040,16 @@ def gen_tests():
def main():
global ptx_version
global gpu_arch
+ global aa
parser = argparse.ArgumentParser()
parser.add_argument("--ptx", type=int, default=60)
parser.add_argument("--gpu-arch", type=int, default=70)
+ parser.add_argument("--aa", action='store_true')
args = parser.parse_args()
ptx_version = args.ptx
gpu_arch = args.gpu_arch
+ aa = args.aa
gen_tests()
|
|
✅ With the latest revision this PR passed the Python code formatter. |
cfabf9c to
52f3619
Compare
Artem-B
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
Sorry about the delay. I apparently archived PR notification out of inbox before reading it.
In case you see no response from me in reasonable time in the future, feel free to ping me to bring it to my attention. Github should have some sort of opt-in auto-nag feature for assigned reviewers.
52f3619 to
6e54665
Compare
No problem, Artem. Sorry for the delay in responding. For some reason, GitHub stopped sending PR update notifications on the mobile app. I'll be more proactive with reviews going forward. |
|
@Artem-B I have addressed all your comments. Please let me know if there's anything else needed or if I can proceed with merging the PR. Thanks! |
Artem-B
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with a small nit.
This commit adds support for the following ldmatrix extensions introduced in PTX 8.6: - Support for m16n16 geometry with b8 type and mandatory transpose - Support for m8n16 geometry with source and destination formats. The above extensions are only supported on sm_100a, sm_101a, sm_120a. Please refer to the PTX ISA for more information: https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-ldmatrix
6e54665 to
6f964b6
Compare
This commit adds support for the following ldmatrix extensions introduced in PTX 8.6
The above extensions are only supported on sm_100a, sm_101a, sm_120a
Please refer to the PTX ISA for more information: https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-ldmatrix