
Commit 0933f4d

[Enhancement] Add DispatchInstruction specialization for fp8 types in gemm_sm90.h
- Introduced specialized DispatchInstruction templates for fp8_e4_t and fp8_e5_t types, enhancing support for new data formats in CUDA GEMM operations.
- Each specialization defines the corresponding MMA and MMA_Group types, optimizing performance for specific configurations.

Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
1 parent 796b3bb commit 0933f4d

File tree: 1 file changed, +93 −59 lines changed

src/tl_templates/cuda/gemm_sm90.h

Lines changed: 93 additions & 59 deletions
@@ -153,6 +153,19 @@ struct DispatchInstruction;
 
 using _X = Underscore;
 
+template <int num_warp_m, int num_warp_n, int N>
+struct DispatchInstruction<fp8_e4_t, fp8_e4_t, float, num_warp_m, num_warp_n,
+                           N> {
+  using MMA = MMA_Atom<SM89_16x8x32_F32E4M3E4M3F32_TN>;
+  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
+};
+template <int num_warp_m, int num_warp_n, int N>
+struct DispatchInstruction<fp8_e5_t, fp8_e5_t, float, num_warp_m, num_warp_n,
+                           N> {
+  using MMA = MMA_Atom<SM89_16x8x32_F32E5M2E5M2F32_TN>;
+  using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
+};
+
 #if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 800))
 template <int num_warp_m, int num_warp_n, int N>
 struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n, N> {
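Not part of the diff: a minimal host-side C++17 sketch of how the Int<std::min(num_warp_n * 16, N)> extent in the new MMA_Group definitions resolves for a couple of illustrative warp configurations. The mma_group_n helper and the chosen parameter values are hypothetical, not names or settings from gemm_sm90.h.

    #include <algorithm>
    #include <cstdio>

    // Mirrors the clamp used by the new fp8 MMA_Group definitions: each N-warp
    // contributes a 16-column MMA group, bounded by the tile's N extent.
    template <int num_warp_n, int N>
    constexpr int mma_group_n = std::min(num_warp_n * 16, N);

    int main() {
      std::printf("%d\n", mma_group_n<4, 128>); // 64: four N-warps, wide tile
      std::printf("%d\n", mma_group_n<4, 32>);  // 32: clamped to the tile's N
      return 0;
    }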
@@ -533,55 +546,56 @@ class GemmTensorOp {
 
 } // namespace tl_mma
 
-} /**
- * Execute a tiled GEMM where both A and B tiles are sourced from shared memory.
- *
- * Dispatches to tl_mma::GemmTensorOp<M,N,K,...>::body to perform the computation.
- *
- * @param pA Pointer to the A tile region (device memory).
- * @param pB Pointer to the B tile region (device memory).
- * @param accum Pointer to the accumulator/output tile region (device memory).
- */
+} // namespace cute
 /**
- * Execute a tiled GEMM where A is read from global memory and B is staged in shared memory.
+ * Execute a tiled GEMM where A is read from global memory and B is staged in
+ * shared memory.
  *
- * Dispatches to tl_mma::GemmTensorOp<M,N,K,...>::body_rs to perform the computation.
+ * Dispatches to tl_mma::GemmTensorOp<M,N,K,...>::body_rs to perform the
+ * computation.
  *
  * @param pA Pointer to the A tile region (device memory).
  * @param pB Pointer to the B tile region (device memory).
  * @param accum Pointer to the accumulator/output tile region (device memory).
  */
 /**
- * Execute a tiled GEMM where A is staged in shared memory and B is read from global memory.
+ * Execute a tiled GEMM where A is staged in shared memory and B is read from
+ * global memory.
  *
- * Dispatches to tl_mma::GemmTensorOp<M,N,K,...>::body_sr to perform the computation.
+ * Dispatches to tl_mma::GemmTensorOp<M,N,K,...>::body_sr to perform the
+ * computation.
  *
  * @param pA Pointer to the A tile region (device memory).
  * @param pB Pointer to the B tile region (device memory).
  * @param accum Pointer to the accumulator/output tile region (device memory).
  */
 /**
- * Perform a tiled GEMM (both operands in shared memory or selected backend) and write to accum.
+ * Perform a tiled GEMM (both operands in shared memory or selected backend) and
+ * write to accum.
  *
- * If use_wgmma is true, validates wgmma constraints (strides and offsets) and dispatches to
- * the Hopper wgmma implementation; otherwise dispatches to the tl_mma implementation.
+ * If use_wgmma is true, validates wgmma constraints (strides and offsets) and
+ * dispatches to the Hopper wgmma implementation; otherwise dispatches to the
+ * tl_mma implementation.
  *
  * @param pA Pointer to the A tile region (device memory).
  * @param pB Pointer to the B tile region (device memory).
  * @param accum Pointer to the accumulator/output tile region (device memory).
  */
 /**
- * Perform a tiled GEMM with A in global memory and B in shared memory (or selected backend).
+ * Perform a tiled GEMM with A in global memory and B in shared memory (or
+ * selected backend).
  *
- * If use_wgmma is true, validates wgmma constraints (strides and offsets) and dispatches to
- * the Hopper wgmma read-share implementation; otherwise dispatches to the tl_mma read-share.
+ * If use_wgmma is true, validates wgmma constraints (strides and offsets) and
+ * dispatches to the Hopper wgmma read-share implementation; otherwise
+ * dispatches to the tl_mma read-share.
  *
  * @param pA Pointer to the A tile region (device memory).
  * @param pB Pointer to the B tile region (device memory).
  * @param accum Pointer to the accumulator/output tile region (device memory).
  */
 /**
- * Perform a tiled GEMM with A staged in shared memory and B in global memory (tl_mma only).
+ * Perform a tiled GEMM with A staged in shared memory and B in global memory
+ * (tl_mma only).
  *
  * wgmma does not support this variant; caller must set use_wgmma == false.
  * Dispatches to tl_mma::GemmTensorOp<M,N,K,...>::body_sr.
@@ -601,16 +615,19 @@ class GemmTensorOp {
  * Calls cutlass::arch::NamedBarrier::sync with the canonical warp-group id.
  */
 /**
- * Arrive at a named barrier for NumMmaThreads MMA threads using architecture-aware mapping.
+ * Arrive at a named barrier for NumMmaThreads MMA threads using
+ * architecture-aware mapping.
  *
- * Supported NumMmaThreads values: 256 or 384. The function issues one or two barrier arrives
- * depending on the thread-group topology to ensure proper rendezvous ordering.
+ * Supported NumMmaThreads values: 256 or 384. The function issues one or two
+ * barrier arrives depending on the thread-group topology to ensure proper
+ * rendezvous ordering.
  */
 /**
  * Initialize named-barrier state for multi-warp MMA execution.
  *
- * For NumMmaThreads == 256 or 384, performs the required initial barrier arrivals for
- * non-zero canonical warp-group indices to set up subsequent barrier synchronization.
+ * For NumMmaThreads == 256 or 384, performs the required initial barrier
+ * arrivals for non-zero canonical warp-group indices to set up subsequent
+ * barrier synchronization.
  */
 
 namespace tl {
@@ -682,22 +699,29 @@ template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
           int offset_a = 0, int offset_b = 0, bool use_wgmma = true,
           int wg_wait = 0, typename A_type, typename B_type, typename C_type>
 TL_DEVICE /**
- * Perform a read-share (B in shared memory, A in global) tiled GEMM and accumulate into `accum`.
- *
- * Dispatches at compile time to either the Hopper wgmma implementation or the fallback MMA implementation
- * depending on `use_wgmma`. The selected GemmTensorOp::body_rs performs the region-tiled GEMM loop and
- * updates the accumulator in-place.
- *
- * When `use_wgmma == true`, this function enforces wgmma constraints at compile time:
- * - A's leading dimension must equal (trans_A ? M : K)
- * - B's leading dimension must equal (trans_B ? K : N)
- * - offset_a and offset_b must be zero
- *
- * @param pA Pointer to operand A (global memory). Layout/stride expectations depend on template parameters.
- * @param pB Pointer to operand B (base for shared-memory staging). Layout/stride expectations depend on template parameters.
- * @param accum Pointer to the accumulator/output C buffer updated in-place.
- */
-void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
+ * Perform a read-share (B in shared memory, A in global) tiled GEMM
+ * and accumulate into `accum`.
+ *
+ * Dispatches at compile time to either the Hopper wgmma
+ * implementation or the fallback MMA implementation depending on
+ * `use_wgmma`. The selected GemmTensorOp::body_rs performs the
+ * region-tiled GEMM loop and updates the accumulator in-place.
+ *
+ * When `use_wgmma == true`, this function enforces wgmma constraints
+ * at compile time:
+ * - A's leading dimension must equal (trans_A ? M : K)
+ * - B's leading dimension must equal (trans_B ? K : N)
+ * - offset_a and offset_b must be zero
+ *
+ * @param pA Pointer to operand A (global memory). Layout/stride
+ * expectations depend on template parameters.
+ * @param pB Pointer to operand B (base for shared-memory staging).
+ * Layout/stride expectations depend on template parameters.
+ * @param accum Pointer to the accumulator/output C buffer updated
+ * in-place.
+ */
+void
+gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
   if constexpr (use_wgmma) {
     static_assert((trans_A && lda == M) || (!trans_A && lda == K),
                   "Hopper wgmma doesn't support custom stride for A");
@@ -723,17 +747,23 @@ template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
           int offset_a = 0, int offset_b = 0, bool use_wgmma = true,
           int wg_wait = 0, typename A_type, typename B_type, typename C_type>
 TL_DEVICE /**
- * Perform a non-wgmma tiled GEMM where A regions are staged into shared memory
- * and B is read directly from global memory, accumulating into `accum`.
- *
- * This overload dispatches to the tl_mma::GemmTensorOp::body_sr implementation.
- * Must be instantiated with `use_wgmma = false` (enforced via static_assert).
- *
- * @param pA Pointer to the A operand in global memory (source that will be staged to shared memory).
- * @param pB Pointer to the B operand in global memory (read directly).
- * @param accum Pointer to the output accumulator matrix in global memory.
- */
-void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
+ * Perform a non-wgmma tiled GEMM where A regions are staged into
+ * shared memory and B is read directly from global memory,
+ * accumulating into `accum`.
+ *
+ * This overload dispatches to the tl_mma::GemmTensorOp::body_sr
+ * implementation. Must be instantiated with `use_wgmma = false`
+ * (enforced via static_assert).
+ *
+ * @param pA Pointer to the A operand in global memory (source that
+ * will be staged to shared memory).
+ * @param pB Pointer to the B operand in global memory (read
+ * directly).
+ * @param accum Pointer to the output accumulator matrix in global
+ * memory.
+ */
+void
+gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
   static_assert(!use_wgmma, "wgmma doesn't support gemm_sr");
   using MMA =
       cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
@@ -742,13 +772,17 @@ void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
   MMA::body_sr(pA, pB, accum);
 }
 
-template <int num_mma> TL_DEVICE /**
- * Wait for all WMMA/MMA warps in the current warp-group to synchronize.
- *
- * Blocks until the warp-group-wide rendezvous for `num_mma` MMA lanes completes,
- * ensuring all participating warps have arrived before proceeding.
- */
-void wait_wgmma() {
+template <int num_mma>
+TL_DEVICE /**
+ * Wait for all WMMA/MMA warps in the current warp-group to
+ * synchronize.
+ *
+ * Blocks until the warp-group-wide rendezvous for `num_mma` MMA lanes
+ * completes, ensuring all participating warps have arrived before
+ * proceeding.
+ */
+void
+wait_wgmma() {
   cute::warpgroup_wait<num_mma>();
 }
 