@@ -331,6 +331,11 @@ class WMMA_REGS<string Geom, string Frag, string PtxEltType> {
331331 !eq(gf,"m8n16:x2") : !listsplat(llvm_i32_ty, 2),
332332 !eq(gf,"m8n16:x4") : !listsplat(llvm_i32_ty, 4),
333333
334+ // stmatrix b8 -> s32 @ m16n8
335+ !eq(gf,"m16n8:x1") : !listsplat(llvm_i32_ty, 1),
336+ !eq(gf,"m16n8:x2") : !listsplat(llvm_i32_ty, 2),
337+ !eq(gf,"m16n8:x4") : !listsplat(llvm_i32_ty, 4),
338+
334339 );
335340}
336341
@@ -403,6 +408,17 @@ class LDMATRIX_NAME<WMMA_REGS Frag, int Trans> {
403408 !subst("llvm.", "int_", intr));
404409}
405410
// Builds the names used for one stmatrix intrinsic variant.
//   Frag:  fragment description; its geom, frag and ptx_elt_type fields are
//          spliced into the name.
//   Trans: when non-zero, a ".trans" component is inserted before the
//          element type.
// Fields:
//   intr   - dotted intrinsic name of the form
//            llvm.nvvm.stmatrix.sync.aligned.<geom>.<frag>[.trans].<type>
//   record - name for the TableGen record, derived from intr.
411+ class STMATRIX_NAME<WMMA_REGS Frag, int Trans> {
412+ string intr = "llvm.nvvm.stmatrix.sync.aligned"
413+ # "." # Frag.geom
414+ # "." # Frag.frag
415+ # !if(Trans, ".trans", "")
416+ # "." # Frag.ptx_elt_type
417+ ;
// The inner !subst turns the "llvm." prefix into "int_"; the outer one then
// replaces every remaining "." with "_".
418+ string record = !subst(".", "_",
419+ !subst("llvm.", "int_", intr));
420+ }
421+
406422// Generates list of 4-tuples of WMMA_REGS representing a valid MMA op.
407423// Geom: list of supported geometries.
408424// TypeN: PTX type of the corresponding fragment's element.
@@ -443,6 +459,16 @@ class LDMATRIX_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
443459 list<string> ops = !foreach(x, ret, x.gft);
444460}
445461
// Generates the list of WMMA_REGS for every combination of the given
// geometries, fragments and element types (a full cross product).
//   Geom:  list of geometry strings.
//   Frags: list of fragment strings.
//   Types: list of PTX element type strings.
// The folds nest geometry -> fragment -> type, so in the resulting list the
// type varies fastest and the geometry slowest.
462+ class STMATRIX_OPS<list<string> Geom, list<string> Frags, list<string> Types> {
463+ list<WMMA_REGS> ret =
464+ !foldl([]<WMMA_REGS>, Geom, t1, geom, !listconcat(t1,
465+ !foldl([]<WMMA_REGS>, Frags, t2, frag, !listconcat(t2,
466+ !foldl([]<WMMA_REGS>, Types, t3, type, !listconcat(t3,
467+ [WMMA_REGS<geom, frag, type>]))))));
468+ // Debugging aid for readable representation of the list above.
469+ list<string> ops = !foreach(x, ret, x.gft);
470+ }
471+
446472// Creates list of valid combinations of fragments. This is the main list that
447473// drives generation of corresponding intrinsics and instructions.
448474class NVVM_MMA_OPS {
@@ -537,9 +563,18 @@ class NVVM_MMA_OPS {
537563 list<WMMA_REGS> ldmatrix_geom_m8n16_ops = LDMATRIX_OPS<
538564 ["m8n16"], ["x1", "x2", "x4"], ["b8x16.b6x16_p32", "b8x16.b4x16_p64"]>.ret;
539565
566+ list<WMMA_REGS> stmatrix_b16_ops = STMATRIX_OPS<
567+ ["m8n8"], ["x1", "x2", "x4"], ["b16"]>.ret;
568+
569+ list<WMMA_REGS> stmatrix_b8_ops = STMATRIX_OPS<
570+ ["m16n8"], ["x1", "x2", "x4"], ["b8"]>.ret;
571+
540572 list<WMMA_REGS> all_ldmatrix_ops = !listconcat(ldmatrix_b16_ops,
541573 ldmatrix_geom_m16n16_ops,
542574 ldmatrix_geom_m8n16_ops);
575+
576+ list<WMMA_REGS> all_stmatrix_ops = !listconcat(stmatrix_b16_ops,
577+ stmatrix_b8_ops);
543578}
544579
545580def NVVM_MMA_OPS : NVVM_MMA_OPS;
@@ -680,6 +715,19 @@ class NVVM_LDMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
680715 );
681716}
682717
718+ // Returns true if the given fragment and transpose setting form a supported
719+ // stmatrix variant; false otherwise.
//   frag:  fragment to check; only its geom and ptx_elt_type are examined.
//   trans: whether the transposed form is being requested.
720+ class NVVM_STMATRIX_SUPPORTED<WMMA_REGS frag, bit trans> {
721+ string g = frag.geom;
722+ string t = frag.ptx_elt_type;
723+
// Accepted combinations:
//   m8n8  + b16 : with or without transpose (trans is not checked).
//   m16n8 + b8  : only when trans is set.
// Anything else falls through to the final "true: false" arm.
724+ bit ret = !cond(
725+ !and(!eq(g, "m8n8"), !eq(t, "b16")): true,
726+ !and(!eq(g, "m16n8"), !eq(t, "b8"), !eq(trans, 1)): true,
727+ true: false
728+ );
729+ }
730+
683731class SHFL_INFO<bit sync, string mode, string type, bit return_pred> {
684732 string Suffix = !if(sync, "sync_", "")
685733 # mode # "_"
@@ -1969,6 +2017,23 @@ foreach transposed = [0, 1] in {
19692017 }
19702018}
19712019
2020+ // STMATRIX
// Intrinsic definition for one stmatrix variant.
//   Frag:       fragment description; Frag.regs supplies the operand types
//               that follow the pointer.
//   Transposed: forwarded to STMATRIX_NAME to pick the ".trans" name variant.
// The intrinsic returns nothing. Operand 0 is an overloaded pointer
// (llvm_anyptr_ty) and is marked WriteOnly and NoCapture; the intrinsic as a
// whole is marked IntrWriteMem, IntrArgMemOnly and IntrNoCallback.
2021+ class NVVM_STMATRIX<WMMA_REGS Frag, int Transposed>
2022+ : Intrinsic<[],
2023+ !listconcat([llvm_anyptr_ty], Frag.regs),
2024+ [IntrWriteMem, IntrArgMemOnly, IntrNoCallback,
2025+ WriteOnly<ArgIndex<0>>, NoCapture<ArgIndex<0>>],
2026+ STMATRIX_NAME<Frag, Transposed>.intr>;
2027+
// Define one intrinsic record per (transposed, fragment) pair taken from
// NVVM_MMA_OPS.all_stmatrix_ops, skipping pairs that
// NVVM_STMATRIX_SUPPORTED rejects. Each record is named by
// STMATRIX_NAME<...>.record and derives from NVVM_STMATRIX.
2028+ foreach transposed = [0, 1] in {
2029+ foreach frag = NVVM_MMA_OPS.all_stmatrix_ops in {
2030+ if NVVM_STMATRIX_SUPPORTED<frag, transposed>.ret then {
2031+ def STMATRIX_NAME<frag, transposed>.record
2032+ : NVVM_STMATRIX<frag, transposed>;
2033+ }
2034+ }
2035+ }
2036+
19722037// MAPA
19732038let IntrProperties = [IntrNoMem, IntrSpeculatable, NoCapture<ArgIndex<0>>] in {
19742039 def int_nvvm_mapa
0 commit comments