@@ -153,6 +153,19 @@ struct DispatchInstruction;
 
 using _X = Underscore;
 
+template <int num_warp_m, int num_warp_n, int N>
+struct DispatchInstruction<fp8_e4_t, fp8_e4_t, float, num_warp_m, num_warp_n,
+                           N> {
+  using MMA = MMA_Atom<SM89_16x8x32_F32E4M3E4M3F32_TN>;
+  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
+};
+template <int num_warp_m, int num_warp_n, int N>
+struct DispatchInstruction<fp8_e5_t, fp8_e5_t, float, num_warp_m, num_warp_n,
+                           N> {
+  using MMA = MMA_Atom<SM89_16x8x32_F32E5M2E5M2F32_TN>;
+  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
+};
+
 #if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 800))
 template <int num_warp_m, int num_warp_n, int N>
 struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n, N> {
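For orientation, here is a minimal sketch of how the new FP8 specializations are consumed; it assumes the surrounding tilelang/cute headers (DispatchInstruction, fp8_e4_t, MMA_Atom, the SM89 atoms) are in scope and is not part of this commit:

// Sketch only: the type names come from the diff above; the wrapper struct is illustrative.
template <int num_warp_m, int num_warp_n, int N>
struct Fp8E4M3DispatchExample {
  // Resolves to the SM89 e4m3 x e4m3 -> f32 tensor-core instruction added above.
  using Dispatch =
      DispatchInstruction<fp8_e4_t, fp8_e4_t, float, num_warp_m, num_warp_n, N>;
  using Atom = typename Dispatch::MMA;        // MMA_Atom<SM89_16x8x32_F32E4M3E4M3F32_TN>
  using Group = typename Dispatch::MMA_Group; // N-mode tiling capped at min(num_warp_n * 16, N)
};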
@@ -533,55 +546,56 @@ class GemmTensorOp {
 
 } // namespace tl_mma
 
-} /**
- * Execute a tiled GEMM where both A and B tiles are sourced from shared memory.
- *
- * Dispatches to tl_mma::GemmTensorOp<M,N,K,...>::body to perform the computation.
- *
- * @param pA Pointer to the A tile region (device memory).
- * @param pB Pointer to the B tile region (device memory).
- * @param accum Pointer to the accumulator/output tile region (device memory).
- */
+} // namespace cute
 /**
- * Execute a tiled GEMM where A is read from global memory and B is staged in shared memory.
+ * Execute a tiled GEMM where A is read from global memory and B is staged in
+ * shared memory.
  *
- * Dispatches to tl_mma::GemmTensorOp<M,N,K,...>::body_rs to perform the computation.
+ * Dispatches to tl_mma::GemmTensorOp<M,N,K,...>::body_rs to perform the
+ * computation.
  *
  * @param pA Pointer to the A tile region (device memory).
  * @param pB Pointer to the B tile region (device memory).
  * @param accum Pointer to the accumulator/output tile region (device memory).
  */
 /**
- * Execute a tiled GEMM where A is staged in shared memory and B is read from global memory.
+ * Execute a tiled GEMM where A is staged in shared memory and B is read from
+ * global memory.
  *
- * Dispatches to tl_mma::GemmTensorOp<M,N,K,...>::body_sr to perform the computation.
+ * Dispatches to tl_mma::GemmTensorOp<M,N,K,...>::body_sr to perform the
+ * computation.
  *
  * @param pA Pointer to the A tile region (device memory).
  * @param pB Pointer to the B tile region (device memory).
  * @param accum Pointer to the accumulator/output tile region (device memory).
  */
 /**
- * Perform a tiled GEMM (both operands in shared memory or selected backend) and write to accum.
+ * Perform a tiled GEMM (both operands in shared memory or selected backend) and
+ * write to accum.
  *
- * If use_wgmma is true, validates wgmma constraints (strides and offsets) and dispatches to
- * the Hopper wgmma implementation; otherwise dispatches to the tl_mma implementation.
+ * If use_wgmma is true, validates wgmma constraints (strides and offsets) and
+ * dispatches to the Hopper wgmma implementation; otherwise dispatches to the
+ * tl_mma implementation.
  *
  * @param pA Pointer to the A tile region (device memory).
  * @param pB Pointer to the B tile region (device memory).
  * @param accum Pointer to the accumulator/output tile region (device memory).
  */
 /**
- * Perform a tiled GEMM with A in global memory and B in shared memory (or selected backend).
+ * Perform a tiled GEMM with A in global memory and B in shared memory (or
+ * selected backend).
  *
- * If use_wgmma is true, validates wgmma constraints (strides and offsets) and dispatches to
- * the Hopper wgmma read-share implementation; otherwise dispatches to the tl_mma read-share.
+ * If use_wgmma is true, validates wgmma constraints (strides and offsets) and
+ * dispatches to the Hopper wgmma read-share implementation; otherwise
+ * dispatches to the tl_mma read-share.
  *
  * @param pA Pointer to the A tile region (device memory).
  * @param pB Pointer to the B tile region (device memory).
  * @param accum Pointer to the accumulator/output tile region (device memory).
  */
 /**
- * Perform a tiled GEMM with A staged in shared memory and B in global memory (tl_mma only).
+ * Perform a tiled GEMM with A staged in shared memory and B in global memory
+ * (tl_mma only).
  *
  * wgmma does not support this variant; caller must set use_wgmma == false.
  * Dispatches to tl_mma::GemmTensorOp<M,N,K,...>::body_sr.
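The stride requirements mentioned in these comments amount to two compile-time checks; a self-contained illustration with made-up tile sizes (not code from this file):

// Illustration: with non-transposed operands, the wgmma path expects the
// default leading dimensions for A and B.
constexpr int M = 64, N = 128, K = 32;
constexpr bool trans_A = false, trans_B = false;
constexpr int lda = K; // must equal (trans_A ? M : K)
constexpr int ldb = N; // must equal (trans_B ? K : N)
static_assert(lda == (trans_A ? M : K), "wgmma requires the default A stride");
static_assert(ldb == (trans_B ? K : N), "wgmma requires the default B stride");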
@@ -601,16 +615,19 @@ class GemmTensorOp {
  * Calls cutlass::arch::NamedBarrier::sync with the canonical warp-group id.
  */
 /**
- * Arrive at a named barrier for NumMmaThreads MMA threads using architecture-aware mapping.
+ * Arrive at a named barrier for NumMmaThreads MMA threads using
+ * architecture-aware mapping.
  *
- * Supported NumMmaThreads values: 256 or 384. The function issues one or two barrier arrives
- * depending on the thread-group topology to ensure proper rendezvous ordering.
+ * Supported NumMmaThreads values: 256 or 384. The function issues one or two
+ * barrier arrives depending on the thread-group topology to ensure proper
+ * rendezvous ordering.
  */
 /**
  * Initialize named-barrier state for multi-warp MMA execution.
  *
- * For NumMmaThreads == 256 or 384, performs the required initial barrier arrivals for
- * non-zero canonical warp-group indices to set up subsequent barrier synchronization.
+ * For NumMmaThreads == 256 or 384, performs the required initial barrier
+ * arrivals for non-zero canonical warp-group indices to set up subsequent
+ * barrier synchronization.
  */
 
 namespace tl {
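A hedged sketch of the arrive/sync pairing described above, written against the CUTLASS named-barrier API; the barrier id and thread count are illustrative, not the values used in this file:

#include <cutlass/arch/barrier.h>

__device__ void mma_rendezvous_sketch() {
  constexpr uint32_t kNumMmaThreads = 256; // two warp-groups (illustrative)
  constexpr uint32_t kBarrierId = 0;       // illustrative barrier id
  // Signal arrival without blocking (producer side of the rendezvous).
  cutlass::arch::NamedBarrier::arrive(kNumMmaThreads, kBarrierId);
  // Block until all kNumMmaThreads threads have arrived at the barrier.
  cutlass::arch::NamedBarrier::sync(kNumMmaThreads, kBarrierId);
}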
@@ -682,22 +699,29 @@ template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
           int offset_a = 0, int offset_b = 0, bool use_wgmma = true,
           int wg_wait = 0, typename A_type, typename B_type, typename C_type>
 TL_DEVICE /**
- * Perform a read-share (B in shared memory, A in global) tiled GEMM and accumulate into `accum`.
- *
- * Dispatches at compile time to either the Hopper wgmma implementation or the fallback MMA implementation
- * depending on `use_wgmma`. The selected GemmTensorOp::body_rs performs the region-tiled GEMM loop and
- * updates the accumulator in-place.
- *
- * When `use_wgmma == true`, this function enforces wgmma constraints at compile time:
- * - A's leading dimension must equal (trans_A ? M : K)
- * - B's leading dimension must equal (trans_B ? K : N)
- * - offset_a and offset_b must be zero
- *
- * @param pA Pointer to operand A (global memory). Layout/stride expectations depend on template parameters.
- * @param pB Pointer to operand B (base for shared-memory staging). Layout/stride expectations depend on template parameters.
- * @param accum Pointer to the accumulator/output C buffer updated in-place.
- */
-void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
+           * Perform a read-share (B in shared memory, A in global) tiled GEMM
+           * and accumulate into `accum`.
+           *
+           * Dispatches at compile time to either the Hopper wgmma
+           * implementation or the fallback MMA implementation depending on
+           * `use_wgmma`. The selected GemmTensorOp::body_rs performs the
+           * region-tiled GEMM loop and updates the accumulator in-place.
+           *
+           * When `use_wgmma == true`, this function enforces wgmma constraints
+           * at compile time:
+           * - A's leading dimension must equal (trans_A ? M : K)
+           * - B's leading dimension must equal (trans_B ? K : N)
+           * - offset_a and offset_b must be zero
+           *
+           * @param pA Pointer to operand A (global memory). Layout/stride
+           * expectations depend on template parameters.
+           * @param pB Pointer to operand B (base for shared-memory staging).
+           * Layout/stride expectations depend on template parameters.
+           * @param accum Pointer to the accumulator/output C buffer updated
+           * in-place.
+           */
+    void
+    gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
   if constexpr (use_wgmma) {
     static_assert((trans_A && lda == M) || (!trans_A && lda == K),
                   "Hopper wgmma doesn't support custom stride for A");
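A hypothetical call site for the wgmma path of gemm_rs; the template parameters elided by this hunk (trans_B, clear_accum, lda, ldb) and the buffer names are assumptions, not code from this file:

// Sketch of a call that satisfies the static_asserts above.
// tl::gemm_rs<M, N, K, num_warp_m, num_warp_n,
//             /*trans_A=*/false, /*trans_B=*/false, /*clear_accum=*/false,
//             /*lda=*/K, /*ldb=*/N, /*offset_a=*/0, /*offset_b=*/0,
//             /*use_wgmma=*/true>(pA, pB, accum);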
@@ -723,17 +747,23 @@ template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
           int offset_a = 0, int offset_b = 0, bool use_wgmma = true,
           int wg_wait = 0, typename A_type, typename B_type, typename C_type>
 TL_DEVICE /**
- * Perform a non-wgmma tiled GEMM where A regions are staged into shared memory
- * and B is read directly from global memory, accumulating into `accum`.
- *
- * This overload dispatches to the tl_mma::GemmTensorOp::body_sr implementation.
- * Must be instantiated with `use_wgmma = false` (enforced via static_assert).
- *
- * @param pA Pointer to the A operand in global memory (source that will be staged to shared memory).
- * @param pB Pointer to the B operand in global memory (read directly).
- * @param accum Pointer to the output accumulator matrix in global memory.
- */
-void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
+           * Perform a non-wgmma tiled GEMM where A regions are staged into
+           * shared memory and B is read directly from global memory,
+           * accumulating into `accum`.
+           *
+           * This overload dispatches to the tl_mma::GemmTensorOp::body_sr
+           * implementation. Must be instantiated with `use_wgmma = false`
+           * (enforced via static_assert).
+           *
+           * @param pA Pointer to the A operand in global memory (source that
+           * will be staged to shared memory).
+           * @param pB Pointer to the B operand in global memory (read
+           * directly).
+           * @param accum Pointer to the output accumulator matrix in global
+           * memory.
+           */
+    void
+    gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
   static_assert(!use_wgmma, "wgmma doesn't support gemm_sr");
   using MMA =
       cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
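Since gemm_sr has no wgmma path, any call site must pass use_wgmma = false; a hypothetical sketch (the elided parameter list is illustrative, not from this file):

// tl::gemm_sr<M, N, K, num_warp_m, num_warp_n, /*...*/,
//             /*use_wgmma=*/false>(pA, pB, accum);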
@@ -742,13 +772,17 @@ void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
   MMA::body_sr(pA, pB, accum);
 }
 
-template <int num_mma> TL_DEVICE /**
- * Wait for all WMMA/MMA warps in the current warp-group to synchronize.
- *
- * Blocks until the warp-group-wide rendezvous for `num_mma` MMA lanes completes,
- * ensuring all participating warps have arrived before proceeding.
- */
-void wait_wgmma() {
+template <int num_mma>
+TL_DEVICE /**
+           * Wait for all WMMA/MMA warps in the current warp-group to
+           * synchronize.
+           *
+           * Blocks until the warp-group-wide rendezvous for `num_mma` MMA lanes
+           * completes, ensuring all participating warps have arrived before
+           * proceeding.
+           */
+    void
+    wait_wgmma() {
   cute::warpgroup_wait<num_mma>();
 }
 
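A hedged sketch of the ordering this helper gives a consumer warp-group; the kernel body and names are illustrative, not code from this file:

// Issue the MMAs, then drain the wgmma pipeline before the epilogue reads accum.
// tl::gemm_rs<...>(pA, pB, accum); // launches asynchronous wgmma operations
// tl::wait_wgmma<0>();             // cute::warpgroup_wait<0>(): wait for all pending groups
// ... epilogue may now safely read accum ...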