@@ -72,6 +72,7 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
7272 string frag = Frag;
7373 string ptx_elt_type = PtxEltType;
7474 string gft = Geom#":"#Frag#":"#ptx_elt_type;
75+ string gf = Geom#":"#Frag;
7576 string ft = frag#":"#ptx_elt_type;
7677 list<LLVMType> regs = !cond(
7778 // mma fp ops use smaller fragments than wmma fp ops
@@ -214,9 +215,19 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
214215 !eq(gft,"m16n8k256:d:s32") : !listsplat(llvm_i32_ty, 4),
215216
216217 // ldmatrix b16 -> s32 @ m8n8
217- !eq(gft,"m8n8:x1:b16") : !listsplat(llvm_i32_ty, 1),
218- !eq(gft,"m8n8:x2:b16") : !listsplat(llvm_i32_ty, 2),
219- !eq(gft,"m8n8:x4:b16") : !listsplat(llvm_i32_ty, 4),
218+ !eq(gf,"m8n8:x1") : !listsplat(llvm_i32_ty, 1),
219+ !eq(gf,"m8n8:x2") : !listsplat(llvm_i32_ty, 2),
220+ !eq(gf,"m8n8:x4") : !listsplat(llvm_i32_ty, 4),
221+
222+ // ldmatrix b8, b8x16.b6x16_p32, b8x16.b4x16_p64 -> s32 @ m16n16
223+ !eq(gf,"m16n16:x1") : !listsplat(llvm_i32_ty, 2),
224+ !eq(gf,"m16n16:x2") : !listsplat(llvm_i32_ty, 4),
225+
226+ // ldmatrix b8x16.b6x16_p32, b8x16.b4x16_p64 -> s32 @ m8n16
227+ !eq(gf,"m8n16:x1") : !listsplat(llvm_i32_ty, 1),
228+ !eq(gf,"m8n16:x2") : !listsplat(llvm_i32_ty, 2),
229+ !eq(gf,"m8n16:x4") : !listsplat(llvm_i32_ty, 4),
230+
220231 );
221232}
222233
@@ -421,7 +432,16 @@ class NVVM_MMA_OPS {
421432
422433 list<WMMA_REGS> ldmatrix_b16_ops = LDMATRIX_OPS<
423434 ["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ret;
424- list<WMMA_REGS> all_ldmatrix_ops = ldmatrix_b16_ops;
435+
436+ list<WMMA_REGS> ldmatrix_geom_m16n16_ops = LDMATRIX_OPS<
437+ ["m16n16"], ["x1", "x2"], ["b8", "b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;
438+
439+ list<WMMA_REGS> ldmatrix_geom_m8n16_ops = LDMATRIX_OPS<
440+ ["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;
441+
442+ list<WMMA_REGS> all_ldmatrix_ops = !listconcat(ldmatrix_b16_ops,
443+ ldmatrix_geom_m16n16_ops,
444+ ldmatrix_geom_m8n16_ops);
425445}
426446
// Instantiates the NVVM_MMA_OPS class as a record of the same name, so its
// list fields can be read as NVVM_MMA_OPS.<field> (for example
// NVVM_MMA_OPS.all_ldmatrix_ops).
427447def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -546,13 +566,18 @@ class NVVM_MMA_SUPPORTED<list<WMMA_REGS> frags, string layout_a, string layout_b
546566// if NVVM_LDMATRIX_SUPPORTED<...>.ret then
547567// def : FOO<>; // The record will only be defined for supported ops.
548568//
// Predicate class: `ret` is true when the ldmatrix variant described by `frag`
// together with `trans` matches one of the combinations listed in the !cond
// below, and false otherwise.
//   frag  - fragment record; only frag.geom and frag.ptx_elt_type are read.
//   trans - compared against 1 by the m16n16 entries and against 0 by the
//           m8n16 entries; the m8n8/b16 entry does not look at it.
549- class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag> {
569+ class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag, bit trans > {
550570 string g = frag.geom;
551571 string t = frag.ptx_elt_type;
552572
553573 bit ret = !cond(
554- // Only currently support m8n8 and b16
// m8n8 with b16 is accepted for either value of trans.
555574 !and(!eq(g, "m8n8"), !eq(t, "b16")): true,
// m16n16 entries are accepted only when trans equals 1.
575+ !and(!eq(g, "m16n16"), !eq(t, "b8"), !eq(trans, 1)): true,
576+ !and(!eq(g, "m16n16"), !eq(t, "b8x16.b6x16_p32"), !eq(trans, 1)): true,
577+ !and(!eq(g, "m16n16"), !eq(t, "b8x16.b4x16_p64"), !eq(trans, 1)): true,
// m8n16 entries are accepted only when trans equals 0.
// NOTE(review): the m8n16 op list built in NVVM_MMA_OPS carries only the
// b8x16.b6x16_p32 and b8x16.b4x16_p64 element types, so the "b8" entry just
// below looks unreachable from that list -- confirm whether it is intended.
578+ !and(!eq(g, "m8n16"), !eq(t, "b8"), !eq(trans, 0)): true,
579+ !and(!eq(g, "m8n16"), !eq(t, "b8x16.b6x16_p32"), !eq(trans, 0)): true,
580+ !and(!eq(g, "m8n16"), !eq(t, "b8x16.b4x16_p64"), !eq(trans, 0)): true,
// Fallback: any combination not matched above is reported as unsupported.
556581 true: false
557582 );
558583}
@@ -4983,7 +5008,7 @@ class NVVM_LDMATRIX<WMMA_REGS Frag, int Transposed>
49835008
49845009foreach transposed = [0, 1] in {
49855010 foreach frag = NVVM_MMA_OPS.all_ldmatrix_ops in {
4986- if NVVM_LDMATRIX_SUPPORTED<frag>.ret then {
5011+ if NVVM_LDMATRIX_SUPPORTED<frag, transposed >.ret then {
49875012 def LDMATRIX_NAME<frag, transposed>.record
49885013 : NVVM_LDMATRIX<frag, transposed>;
49895014 }
0 commit comments