Skip to content

Commit cfabf9c

Browse files
[LLVM][NVPTX] Add support for ldmatrix extensions introduced in PTX 8.6
This commit adds support for the following ldmatrix extensions introduced in PTX 8.6: - Support for m16n16 with b8 type with mandatory transpose - Support for m16n16 and m8n16 with source and destination formats. The above extensions are only supported on sm_100a, sm_101a, sm_120a. Please refer to the PTX ISA for more information: https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-ldmatrix
1 parent 2e43f39 commit cfabf9c

File tree

8 files changed

+176
-15
lines changed

8 files changed

+176
-15
lines changed

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
6262
string frag = Frag;
6363
string ptx_elt_type = PtxEltType;
6464
string gft = Geom#":"#Frag#":"#ptx_elt_type;
65+
string gf = Geom#":"#Frag;
6566
string ft = frag#":"#ptx_elt_type;
6667
list<LLVMType> regs = !cond(
6768
// mma fp ops use smaller fragments than wmma fp ops
@@ -204,9 +205,19 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
204205
!eq(gft,"m16n8k256:d:s32") : !listsplat(llvm_i32_ty, 4),
205206

206207
// ldmatrix b16 -> s32 @ m8n8
207-
!eq(gft,"m8n8:x1:b16") : !listsplat(llvm_i32_ty, 1),
208-
!eq(gft,"m8n8:x2:b16") : !listsplat(llvm_i32_ty, 2),
209-
!eq(gft,"m8n8:x4:b16") : !listsplat(llvm_i32_ty, 4),
208+
!eq(gf,"m8n8:x1") : !listsplat(llvm_i32_ty, 1),
209+
!eq(gf,"m8n8:x2") : !listsplat(llvm_i32_ty, 2),
210+
!eq(gf,"m8n8:x4") : !listsplat(llvm_i32_ty, 4),
211+
212+
// ldmatrix b8, b8x16.b6x16_p32, b8x16.b4x16_p64 -> s32 @ m16n16
213+
!eq(gf,"m16n16:x1") : !listsplat(llvm_i32_ty, 2),
214+
!eq(gf,"m16n16:x2") : !listsplat(llvm_i32_ty, 4),
215+
216+
// ldmatrix b8x16.b6x16_p32, b8x16.b4x16_p64 -> s32 @ m8n16
217+
!eq(gf,"m8n16:x1") : !listsplat(llvm_i32_ty, 1),
218+
!eq(gf,"m8n16:x2") : !listsplat(llvm_i32_ty, 2),
219+
!eq(gf,"m8n16:x4") : !listsplat(llvm_i32_ty, 4),
220+
210221
);
211222
}
212223

@@ -411,7 +422,16 @@ class NVVM_MMA_OPS {
411422

412423
list<WMMA_REGS> ldmatrix_b16_ops = LDMATRIX_OPS<
413424
["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ret;
414-
list<WMMA_REGS> all_ldmatrix_ops = ldmatrix_b16_ops;
425+
426+
list<WMMA_REGS> ldmatrix_geom_m16n16_ops = LDMATRIX_OPS<
427+
["m16n16"], ["x1", "x2"], ["b8", "b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;
428+
429+
list<WMMA_REGS> ldmatrix_geom_m8n16_ops = LDMATRIX_OPS<
430+
["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;
431+
432+
list<WMMA_REGS> all_ldmatrix_ops = !listconcat(ldmatrix_b16_ops,
433+
ldmatrix_geom_m16n16_ops,
434+
ldmatrix_geom_m8n16_ops);
415435
}
416436

417437
def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -536,13 +556,18 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
536556
// if NVVM_LDMATRIX_SUPPORTED<...>.ret then
537557
// def : FOO<>; // The record will only be defined for supported ops.
538558
//
539-
class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag> {
559+
class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
540560
string g = frag.geom;
541561
string t = frag.ptx_elt_type;
542562

543563
bit ret = !cond(
544-
// Only currently support m8n8 and b16
545564
!and(!eq(g, "m8n8"), !eq(t, "b16")): true,
565+
!and(!eq(g, "m16n16"), !eq(t, "b8"), !eq(trans, 1)): true,
566+
!and(!eq(g, "m16n16"), !eq(t, "b8x16.b6x16_p32")): true,
567+
!and(!eq(g, "m16n16"), !eq(t, "b8x16.b4x16_p64")): true,
568+
!and(!eq(g, "m8n16"), !eq(t, "b8"), !eq(trans, 0)): true,
569+
!and(!eq(g, "m8n16"), !eq(t, "b8x16.b6x16_p32"), !eq(trans, 0)): true,
570+
!and(!eq(g, "m8n16"), !eq(t, "b8x16.b4x16_p64"), !eq(trans, 0)): true,
546571
true: false
547572
);
548573
}
@@ -4932,7 +4957,7 @@ class NVVM_LDMATRIX<WMMA_REGS Frag, int Transposed>
49324957

49334958
foreach transposed = [0, 1] in {
49344959
foreach frag = NVVM_MMA_OPS.all_ldmatrix_ops in {
4935-
if NVVM_LDMATRIX_SUPPORTED<frag>.ret then {
4960+
if NVVM_LDMATRIX_SUPPORTED<frag, transposed>.ret then {
49364961
def LDMATRIX_NAME<frag, transposed>.record
49374962
: NVVM_LDMATRIX<frag, transposed>;
49384963
}

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3552,7 +3552,12 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
35523552
case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row:
35533553
case Intrinsic::nvvm_wmma_m16n16k8_load_b_tf32_row_stride:
35543554
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_b16:
3555-
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16: {
3555+
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x4_trans_b16:
3556+
case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8:
3557+
case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b4x16_p64:
3558+
case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x2_trans_b8x16_b6x16_p32:
3559+
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b4x16_p64:
3560+
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x4_b8x16_b6x16_p32: {
35563561
Info.opc = ISD::INTRINSIC_W_CHAIN;
35573562
Info.memVT = MVT::v4i32;
35583563
Info.ptrVal = I.getArgOperand(0);
@@ -3592,7 +3597,9 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
35923597
case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col_stride:
35933598
case Intrinsic::nvvm_wmma_m8n8k32_load_b_u4_col:
35943599
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_b16:
3595-
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16: {
3600+
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x1_trans_b16:
3601+
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b4x16_p64:
3602+
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x1_b8x16_b6x16_p32: {
35963603
Info.opc = ISD::INTRINSIC_W_CHAIN;
35973604
Info.memVT = MVT::i32;
35983605
Info.ptrVal = I.getArgOperand(0);
@@ -3688,7 +3695,12 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
36883695
case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row:
36893696
case Intrinsic::nvvm_wmma_m8n8k32_load_c_s32_row_stride:
36903697
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_b16:
3691-
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16: {
3698+
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n8_x2_trans_b16:
3699+
case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8:
3700+
case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b4x16_p64:
3701+
case Intrinsic::nvvm_ldmatrix_sync_aligned_m16n16_x1_trans_b8x16_b6x16_p32:
3702+
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b4x16_p64:
3703+
case Intrinsic::nvvm_ldmatrix_sync_aligned_m8n16_x2_b8x16_b6x16_p32: {
36923704
Info.opc = ISD::INTRINSIC_W_CHAIN;
36933705
Info.memVT = MVT::v2i32;
36943706
Info.ptrVal = I.getArgOperand(0);

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ def False : Predicate<"false">;
170170
class hasPTX<int version>: Predicate<"Subtarget->getPTXVersion() >= " # version>;
171171
class hasSM<int version>: Predicate<"Subtarget->getSmVersion() >= " # version>;
172172

173+
def hasAAFeatures : Predicate<"Subtarget->hasAAFeatures()">;
173174
// Explicit records for arch-accelerated SM versions
174175
def hasSM90a : Predicate<"Subtarget->getFullSmVersion() == 901">;
175176
def hasSM100a : Predicate<"Subtarget->getFullSmVersion() == 1001">;

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7107,6 +7107,9 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
71077107
!eq(ptx_elt_type, "tf32") : Int32Regs,
71087108
!eq(ptx_elt_type, "s32") : Int32Regs,
71097109
!eq(ptx_elt_type, "b16") : Int32Regs,
7110+
!eq(ptx_elt_type, "b8") : Int32Regs,
7111+
!eq(ptx_elt_type, "b8x16.b6x16_p32") : Int32Regs,
7112+
!eq(ptx_elt_type, "b8x16.b4x16_p64") : Int32Regs,
71107113
!eq(ptx_elt_type, "s8") : Int32Regs,
71117114
!eq(ptx_elt_type, "u8") : Int32Regs,
71127115
!eq(ptx_elt_type, "s4") : Int32Regs,
@@ -7194,7 +7197,27 @@ class WMMA_REGINFO<WMMA_REGS r, string op>
71947197

71957198
!and(!eq(op,"ldmatrix"),
71967199
!eq(ptx_elt_type,"b16"),
7197-
!eq(geom, "m8n8")) : [hasSM<75>, hasPTX<65>]);
7200+
!eq(geom, "m8n8")) : [hasSM<75>, hasPTX<65>],
7201+
7202+
!and(!eq(op,"ldmatrix"),
7203+
!eq(ptx_elt_type,"b8"),
7204+
!eq(geom, "m16n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>],
7205+
7206+
!and(!eq(op,"ldmatrix"),
7207+
!eq(ptx_elt_type,"b8x16.b6x16_p32"),
7208+
!eq(geom, "m16n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>],
7209+
7210+
!and(!eq(op,"ldmatrix"),
7211+
!eq(ptx_elt_type,"b8x16.b4x16_p64"),
7212+
!eq(geom, "m16n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>],
7213+
7214+
!and(!eq(op,"ldmatrix"),
7215+
!eq(ptx_elt_type,"b8x16.b6x16_p32"),
7216+
!eq(geom, "m8n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>],
7217+
7218+
!and(!eq(op,"ldmatrix"),
7219+
!eq(ptx_elt_type,"b8x16.b4x16_p64"),
7220+
!eq(geom, "m8n16")) : [hasSM<100>, hasAAFeatures, hasPTX<86>]);
71987221

71997222
// template DAGs for instruction inputs/output.
72007223
dag Outs = !dag(outs, ptx_regs, reg_names);
@@ -7478,7 +7501,7 @@ defset list<WMMA_INSTR> LDMATRIXs = {
74787501
foreach space = [".shared", ""] in {
74797502
foreach addr = [imem, Int32Regs, Int64Regs, MEMri, MEMri64] in {
74807503
foreach frag = NVVM_MMA_OPS.all_ldmatrix_ops in
7481-
if NVVM_LDMATRIX_SUPPORTED<frag>.ret then
7504+
if NVVM_LDMATRIX_SUPPORTED<frag, transposed>.ret then
74827505
def : LDMATRIX<WMMA_REGINFO<frag, "ldmatrix">, transposed, space,
74837506
addr>;
74847507
} // addr
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Check all variants of instructions supported by PTX86 on SM100a
2+
# RUN: %python %s --ptx=86 --gpu-arch=100 --aa > %t-ptx86-sm_100a.ll
3+
# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
4+
# RUN: --check-prefixes=PTX86LDMATRIX-DAG
5+
# RUN: FileCheck %t-ptx86-sm_100a.ll < %t-ptx86-sm_100a.ll \
6+
# RUN: --check-prefixes=PTX86LDMATRIX-DAG
7+
# RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \
8+
# RUN: | FileCheck %t-ptx86-sm_100a.ll
9+
# RUN: %if ptxas-12.7 %{ \
10+
# RUN: llc < %t-ptx86-sm_100a.ll -mtriple=nvptx64 -mcpu=sm_100a -mattr=+ptx86 \
11+
# RUN: | %ptxas-verify -arch=sm_100a \
12+
# RUN: %}
13+
14+
import wmma
15+
16+
wmma.main()
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Check all variants of instructions supported by PTX86 on SM101a
2+
# RUN: %python %s --ptx=86 --gpu-arch=101 --aa > %t-ptx86-sm_101a.ll
3+
# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
4+
# RUN: --check-prefixes=PTX86LDMATRIX-DAG
5+
# RUN: FileCheck %t-ptx86-sm_101a.ll < %t-ptx86-sm_101a.ll \
6+
# RUN: --check-prefixes=PTX86LDMATRIX-DAG
7+
# RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \
8+
# RUN: | FileCheck %t-ptx86-sm_101a.ll
9+
# RUN: %if ptxas-12.7 %{ \
10+
# RUN: llc < %t-ptx86-sm_101a.ll -mtriple=nvptx64 -mcpu=sm_101a -mattr=+ptx86 \
11+
# RUN: | %ptxas-verify -arch=sm_101a \
12+
# RUN: %}
13+
14+
import wmma
15+
16+
wmma.main()
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Check all variants of instructions supported by PTX86 on SM120a
2+
# RUN: %python %s --ptx=86 --gpu-arch=120 --aa > %t-ptx86-sm_120a.ll
3+
# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
4+
# RUN: --check-prefixes=PTX86LDMATRIX-DAG
5+
# RUN: FileCheck %t-ptx86-sm_120a.ll < %t-ptx86-sm_120a.ll \
6+
# RUN: --check-prefixes=PTX86LDMATRIX-DAG
7+
# RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \
8+
# RUN: | FileCheck %t-ptx86-sm_120a.ll
9+
# RUN: %if ptxas-12.7 %{ \
10+
# RUN: llc < %t-ptx86-sm_120a.ll -mtriple=nvptx64 -mcpu=sm_120a -mattr=+ptx86 \
11+
# RUN: | %ptxas-verify -arch=sm_120a \
12+
# RUN: %}
13+
14+
import wmma
15+
16+
wmma.main()

llvm/test/CodeGen/NVPTX/wmma.py

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ def __init__(self, ptx_type):
1919
"f64": "double",
2020
"s32": "i32",
2121
"b16": "i32",
22+
"b8": "i32",
23+
"b8x16.b6x16_p32": "i32",
24+
"b8x16.b4x16_p64": "i32",
2225
"s8": "i32",
2326
"u8": "i32",
2427
"s4": "i32",
@@ -161,6 +164,18 @@ def __init__(self, geom, frag, ptx_elt_type):
161164
"m8n8:x1:b16": 1,
162165
"m8n8:x2:b16": 2,
163166
"m8n8:x4:b16": 4,
167+
"m16n16:x1:b8": 2,
168+
"m16n16:x2:b8": 4,
169+
"m16n16:x1:b8x16.b6x16_p32": 2,
170+
"m16n16:x2:b8x16.b6x16_p32": 4,
171+
"m16n16:x1:b8x16.b4x16_p64": 2,
172+
"m16n16:x2:b8x16.b4x16_p64": 4,
173+
"m8n16:x1:b8x16.b6x16_p32": 1,
174+
"m8n16:x2:b8x16.b6x16_p32": 2,
175+
"m8n16:x4:b8x16.b6x16_p32": 4,
176+
"m8n16:x1:b8x16.b4x16_p64": 1,
177+
"m8n16:x2:b8x16.b4x16_p64": 2,
178+
"m8n16:x4:b8x16.b4x16_p64": 4,
164179
}.get(
165180
"%s:%s:%s" % (geom, frag, ptx_elt_type),
166181
{
@@ -289,7 +304,15 @@ def get_ldst_ops(kind):
289304

290305

291306
def get_ldmatrix_ops():
292-
return make_ldmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"])
307+
return (
308+
make_ldmatrix_ops(["m8n8"], ["x1", "x2", "x4"], ["b16"])
309+
+ make_ldmatrix_ops(
310+
["m16n16"], ["x1", "x2"], ["b8", "b8x16.b6x16_p32", "b8x16.b4x16_p64"]
311+
)
312+
+ make_ldmatrix_ops(
313+
["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]
314+
)
315+
)
293316

294317

295318
def is_wmma_geom_supported(geom):
@@ -330,8 +353,20 @@ def is_mma_geom_supported(geom):
330353
def is_ldmatrix_geom_supported(geom):
331354
if geom in ["m8n8"]:
332355
return ptx_version >= 65 and gpu_arch >= 75
356+
elif geom in ["m16n16"]:
357+
return ptx_version >= 86 and gpu_arch >= 100 and aa
358+
elif geom in ["m8n16"]:
359+
return ptx_version >= 86 and gpu_arch >= 100 and aa
333360
assert False # Unexpected geometry.
334361

362+
def is_ldmatrix_trans_supported(geom, trans):
363+
if geom in ["m8n8"]:
364+
return True
365+
elif geom in ["m16n16"]:
366+
return trans == ".trans"
367+
elif geom in ["m8n16"]:
368+
return trans == ""
369+
assert False # Unexpected geometry.
335370

336371
def is_type_supported(ptx_type):
337372
if ptx_type in ["s8", "u8", "s32"]:
@@ -417,10 +452,11 @@ def is_ldst_variant_supported(frag, layout):
417452
return True
418453

419454

420-
def is_ldmatrix_variant_supported(frag):
455+
def is_ldmatrix_variant_supported(frag, trans):
421456
if not (
422457
is_type_supported(frag.mma_type.ptx_type)
423458
and is_ldmatrix_geom_supported(frag.geom)
459+
and is_ldmatrix_trans_supported(frag.geom, trans)
424460
):
425461
return False
426462
return frag.frag in ["x1", "x2", "x4"]
@@ -653,7 +689,7 @@ def gen_ldmatrix_tests():
653689
["", ".shared"],
654690
["", ".trans"],
655691
):
656-
if not is_ldmatrix_variant_supported(frag):
692+
if not is_ldmatrix_variant_supported(frag, trans):
657693
continue
658694

659695
params = {
@@ -944,6 +980,19 @@ def gen_check_unsupported_ops(items):
944980
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16
945981
; PTX65LDMATRIX-DAG: ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16
946982
983+
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x1.trans.shared.b8
984+
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x2.trans.shared.b8
985+
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b6x16_p32
986+
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x1.trans.b8x16.b4x16_p64
987+
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b6x16_p32
988+
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m16n16.x2.trans.b8x16.b4x16_p64
989+
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x1.b8x16.b6x16_p32
990+
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x1.b8x16.b4x16_p64
991+
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x2.b8x16.b6x16_p32
992+
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x2.b8x16.b4x16_p64
993+
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b6x16_p32
994+
; PTX86LDMATRIX-DAG: ldmatrix.sync.aligned.m8n16.x4.b8x16.b4x16_p64
995+
947996
; PTX71MMA-DAG: mma.m8n8k4.row.col.f64
948997
; PTX71MMA-DAG: mma.m16n8k4.row.col.tf32
949998
; PTX71MMA-DAG: mma.m16n8k8.row.col.tf32
@@ -997,13 +1046,16 @@ def gen_tests():
9971046
def main():
9981047
global ptx_version
9991048
global gpu_arch
1049+
global aa
10001050
parser = argparse.ArgumentParser()
10011051
parser.add_argument("--ptx", type=int, default=60)
10021052
parser.add_argument("--gpu-arch", type=int, default=70)
1053+
parser.add_argument("--aa", action="store_true")
10031054
args = parser.parse_args()
10041055

10051056
ptx_version = args.ptx
10061057
gpu_arch = args.gpu_arch
1058+
aa = args.aa
10071059

10081060
gen_tests()
10091061

0 commit comments

Comments
 (0)