@@ -62,6 +62,7 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
6262 string frag = Frag;
6363 string ptx_elt_type = PtxEltType;
6464 string gft = Geom#":"#Frag#":"#ptx_elt_type;
65+ string gf = Geom#":"#Frag;
6566 string ft = frag#":"#ptx_elt_type;
6667 list<LLVMType> regs = !cond(
6768 // mma fp ops use smaller fragments than wmma fp ops
@@ -204,9 +205,19 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
204205 !eq(gft,"m16n8k256:d:s32") : !listsplat(llvm_i32_ty, 4),
205206
206207 // ldmatrix b16 -> s32 @ m8n8
207- !eq(gft,"m8n8:x1:b16") : !listsplat(llvm_i32_ty, 1),
208- !eq(gft,"m8n8:x2:b16") : !listsplat(llvm_i32_ty, 2),
209- !eq(gft,"m8n8:x4:b16") : !listsplat(llvm_i32_ty, 4),
208+ !eq(gf,"m8n8:x1") : !listsplat(llvm_i32_ty, 1),
209+ !eq(gf,"m8n8:x2") : !listsplat(llvm_i32_ty, 2),
210+ !eq(gf,"m8n8:x4") : !listsplat(llvm_i32_ty, 4),
211+
212+ // ldmatrix b8, b8x16.b6x16_p32, b8x16.b4x16_p64 -> s32 @ m16n16
213+ !eq(gf,"m16n16:x1") : !listsplat(llvm_i32_ty, 2),
214+ !eq(gf,"m16n16:x2") : !listsplat(llvm_i32_ty, 4),
215+
216+ // ldmatrix b8x16.b6x16_p32, b8x16.b4x16_p64 -> s32 @ m8n16
217+ !eq(gf,"m8n16:x1") : !listsplat(llvm_i32_ty, 1),
218+ !eq(gf,"m8n16:x2") : !listsplat(llvm_i32_ty, 2),
219+ !eq(gf,"m8n16:x4") : !listsplat(llvm_i32_ty, 4),
220+
210221 );
211222}
212223
@@ -411,7 +422,16 @@ class NVVM_MMA_OPS {
411422
412423 list<WMMA_REGS> ldmatrix_b16_ops = LDMATRIX_OPS<
413424 ["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ret;
414- list<WMMA_REGS> all_ldmatrix_ops = ldmatrix_b16_ops;
425+
426+ list<WMMA_REGS> ldmatrix_geom_m16n16_ops = LDMATRIX_OPS<
427+ ["m16n16"], ["x1", "x2"], ["b8", "b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;
428+
429+ list<WMMA_REGS> ldmatrix_geom_m8n16_ops = LDMATRIX_OPS<
430+ ["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;
431+
432+ list<WMMA_REGS> all_ldmatrix_ops = !listconcat(ldmatrix_b16_ops,
433+ ldmatrix_geom_m16n16_ops,
434+ ldmatrix_geom_m8n16_ops);
415435}
416436
417437def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -536,13 +556,18 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
536556// if NVVM_LDMATRIX_SUPPORTED<...>.ret then
537557// def : FOO<>; // The record will only be defined for supported ops.
538558//
539- class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag> {
559+ class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag, bit trans > {
540560 string g = frag.geom;
541561 string t = frag.ptx_elt_type;
542562
543563 bit ret = !cond(
544- // Only currently support m8n8 and b16
545564 !and(!eq(g, "m8n8"), !eq(t, "b16")): true,
565+ !and(!eq(g, "m16n16"), !eq(t, "b8"), !eq(trans, 1)): true,
566+ !and(!eq(g, "m16n16"), !eq(t, "b8x16.b6x16_p32")): true,
567+ !and(!eq(g, "m16n16"), !eq(t, "b8x16.b4x16_p64")): true,
568+ !and(!eq(g, "m8n16"), !eq(t, "b8"), !eq(trans, 0)): true,
569+ !and(!eq(g, "m8n16"), !eq(t, "b8x16.b6x16_p32"), !eq(trans, 0)): true,
570+ !and(!eq(g, "m8n16"), !eq(t, "b8x16.b4x16_p64"), !eq(trans, 0)): true,
546571 true: false
547572 );
548573}
@@ -4932,7 +4957,7 @@ class NVVM_LDMATRIX<WMMA_REGS Frag, int Transposed>
49324957
49334958foreach transposed = [0, 1] in {
49344959 foreach frag = NVVM_MMA_OPS.all_ldmatrix_ops in {
4935- if NVVM_LDMATRIX_SUPPORTED<frag>.ret then {
4960+ if NVVM_LDMATRIX_SUPPORTED<frag, transposed >.ret then {
49364961 def LDMATRIX_NAME<frag, transposed>.record
49374962 : NVVM_LDMATRIX<frag, transposed>;
49384963 }
0 commit comments