Merged
49 changes: 0 additions & 49 deletions src/tl_templates/cuda/common.h
@@ -240,53 +240,4 @@ template <int barrier_id = 0, int thread_count = 0>
TL_DEVICE void __sync_thread_partial() {
asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(thread_count));
}

// Template parameter:
// thread_extent: the logical size (in number of threads) of each "group"
// within which we want to elect exactly ONE representative
// thread.
template <int thread_extent> TL_DEVICE bool tl_shuffle_elect() {

// Special case: thread_extent == 0 means "elect exactly one thread
// in the entire thread block", i.e., the leader of the first warp of the
// block.
if constexpr (thread_extent == 0) {
// cutlass::canonical_warp_idx_sync():
// Returns the warp ID within the thread block in a "canonical" way
// (0 for the first warp, 1 for the second, ...).
// cute::elect_one_sync():
// Elect exactly one lane in the warp to return true (typically lane 0),
// other lanes return false.
// The condition ensures that:
// (1) We are in warp 0 of the block.
// (2) We are the elected lane in this warp.
return cutlass::canonical_warp_idx_sync() == 0 && cute::elect_one_sync();
}

// General case: thread_extent != 0
// (threadIdx.x / 32) is the warp index in the block.
// (thread_extent / 32) is the number of warps in one group of size
// thread_extent. We take warp_id % num_warps_in_group to get the warp's index
// within the group.
// __shfl_sync(mask, value, srcLane): broadcast 'value' from srcLane to all
// lanes in the warp. Here it broadcasts the group-local warp index from lane
// 0. Comparing to 0 selects only the group's warp 0.
return __shfl_sync(0xffffffff, // full warp mask
(threadIdx.x / 32) %
(thread_extent / 32), // warp index within group
0 // take the value from lane 0
) == 0 &&
// Within that group leader warp, elect exactly one lane (typically
// lane 0) to be the single representative for the group.
cute::elect_one_sync();
}
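
// As an aside, a quick worked example of the group-election arithmetic above
// (illustrative numbers, not part of the diff): assume a 1-D block of 256
// threads and thread_extent = 128.
//   warps_per_group = thread_extent / 32 = 4
//   threads   0..127 (group 0): warp_id = 0..3, warp_id % 4 == 0 only for warp 0
//   threads 128..255 (group 1): warp_id = 4..7, warp_id % 4 == 0 only for warp 4
//   cute::elect_one_sync() then picks a single lane inside each leader warp,
//   so exactly one thread per 128-thread group returns true.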

template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_alloc() {
asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(RegCount));
}

template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_dealloc() {
asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount));
}

} // namespace tl
157 changes: 96 additions & 61 deletions src/tl_templates/cuda/gemm_sm90.h
@@ -1,14 +1,15 @@
#pragma once

#include "common.h"
#include "cuda_fp8.h"
#include "intrin.h"
#include <cute/arch/mma_sm80.hpp>
#include <cute/arch/mma_sm90.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cutlass/arch/barrier.h>
#include <cutlass/cutlass.h>
#include <cutlass/gemm/collective/collective_builder.hpp>

#include "common.h"

namespace cute {

using namespace SM90;
@@ -153,6 +154,19 @@ struct DispatchInstruction;

using _X = Underscore;

template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<fp8_e4_t, fp8_e4_t, float, num_warp_m, num_warp_n,
N> {
using MMA = MMA_Atom<SM89_16x8x32_F32E4M3E4M3F32_TN>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<fp8_e5_t, fp8_e5_t, float, num_warp_m, num_warp_n,
N> {
using MMA = MMA_Atom<SM89_16x8x32_F32E5M2E5M2F32_TN>;
using MMA_Group = Tile<_X, Int<std::min(num_warp_n * 16, N)>, _X>;
};

#if (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 800))
template <int num_warp_m, int num_warp_n, int N>
struct DispatchInstruction<half_t, half_t, half_t, num_warp_m, num_warp_n, N> {
@@ -533,55 +547,56 @@ class GemmTensorOp {

} // namespace tl_mma

}
/**
* Execute a tiled GEMM where both A and B tiles are sourced from shared memory.
*
* Dispatches to tl_mma::GemmTensorOp<M,N,K,...>::body to perform the computation.
*
* @param pA Pointer to the A tile region (device memory).
* @param pB Pointer to the B tile region (device memory).
* @param accum Pointer to the accumulator/output tile region (device memory).
*/
} // namespace cute
/**
* Execute a tiled GEMM where A is read from global memory and B is staged in shared memory.
* Execute a tiled GEMM where A is read from global memory and B is staged in
* shared memory.
*
* Dispatches to tl_mma::GemmTensorOp<M,N,K,...>::body_rs to perform the computation.
* Dispatches to tl_mma::GemmTensorOp<M,N,K,...>::body_rs to perform the
* computation.
*
* @param pA Pointer to the A tile region (device memory).
* @param pB Pointer to the B tile region (device memory).
* @param accum Pointer to the accumulator/output tile region (device memory).
*/
/**
* Execute a tiled GEMM where A is staged in shared memory and B is read from global memory.
* Execute a tiled GEMM where A is staged in shared memory and B is read from
* global memory.
*
* Dispatches to tl_mma::GemmTensorOp<M,N,K,...>::body_sr to perform the computation.
* Dispatches to tl_mma::GemmTensorOp<M,N,K,...>::body_sr to perform the
* computation.
*
* @param pA Pointer to the A tile region (device memory).
* @param pB Pointer to the B tile region (device memory).
* @param accum Pointer to the accumulator/output tile region (device memory).
*/
/**
* Perform a tiled GEMM (both operands in shared memory or selected backend) and write to accum.
* Perform a tiled GEMM (both operands in shared memory or selected backend) and
* write to accum.
*
* If use_wgmma is true, validates wgmma constraints (strides and offsets) and dispatches to
* the Hopper wgmma implementation; otherwise dispatches to the tl_mma implementation.
* If use_wgmma is true, validates wgmma constraints (strides and offsets) and
* dispatches to the Hopper wgmma implementation; otherwise dispatches to the
* tl_mma implementation.
*
* @param pA Pointer to the A tile region (device memory).
* @param pB Pointer to the B tile region (device memory).
* @param accum Pointer to the accumulator/output tile region (device memory).
*/
/**
* Perform a tiled GEMM with A in global memory and B in shared memory (or selected backend).
* Perform a tiled GEMM with A in global memory and B in shared memory (or
* selected backend).
*
* If use_wgmma is true, validates wgmma constraints (strides and offsets) and dispatches to
* the Hopper wgmma read-share implementation; otherwise dispatches to the tl_mma read-share.
* If use_wgmma is true, validates wgmma constraints (strides and offsets) and
* dispatches to the Hopper wgmma read-share implementation; otherwise
* dispatches to the tl_mma read-share.
*
* @param pA Pointer to the A tile region (device memory).
* @param pB Pointer to the B tile region (device memory).
* @param accum Pointer to the accumulator/output tile region (device memory).
*/
/**
* Perform a tiled GEMM with A staged in shared memory and B in global memory (tl_mma only).
* Perform a tiled GEMM with A staged in shared memory and B in global memory
* (tl_mma only).
*
* wgmma does not support this variant; caller must set use_wgmma == false.
* Dispatches to tl_mma::GemmTensorOp<M,N,K,...>::body_sr.
@@ -601,16 +616,19 @@ class GemmTensorOp {
* Calls cutlass::arch::NamedBarrier::sync with the canonical warp-group id.
*/
/**
* Arrive at a named barrier for NumMmaThreads MMA threads using architecture-aware mapping.
* Arrive at a named barrier for NumMmaThreads MMA threads using
* architecture-aware mapping.
*
* Supported NumMmaThreads values: 256 or 384. The function issues one or two barrier arrives
* depending on the thread-group topology to ensure proper rendezvous ordering.
* Supported NumMmaThreads values: 256 or 384. The function issues one or two
* barrier arrives depending on the thread-group topology to ensure proper
* rendezvous ordering.
*/
/**
* Initialize named-barrier state for multi-warp MMA execution.
*
* For NumMmaThreads == 256 or 384, performs the required initial barrier arrivals for
* non-zero canonical warp-group indices to set up subsequent barrier synchronization.
* For NumMmaThreads == 256 or 384, performs the required initial barrier
* arrivals for non-zero canonical warp-group indices to set up subsequent
* barrier synchronization.
*/

namespace tl {
@@ -682,22 +700,29 @@ template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
int offset_a = 0, int offset_b = 0, bool use_wgmma = true,
int wg_wait = 0, typename A_type, typename B_type, typename C_type>
TL_DEVICE /**
* Perform a read-share (B in shared memory, A in global) tiled GEMM and accumulate into `accum`.
*
* Dispatches at compile time to either the Hopper wgmma implementation or the fallback MMA implementation
* depending on `use_wgmma`. The selected GemmTensorOp::body_rs performs the region-tiled GEMM loop and
* updates the accumulator in-place.
*
* When `use_wgmma == true`, this function enforces wgmma constraints at compile time:
* - A's leading dimension must equal (trans_A ? M : K)
* - B's leading dimension must equal (trans_B ? K : N)
* - offset_a and offset_b must be zero
*
* @param pA Pointer to operand A (global memory). Layout/stride expectations depend on template parameters.
* @param pB Pointer to operand B (base for shared-memory staging). Layout/stride expectations depend on template parameters.
* @param accum Pointer to the accumulator/output C buffer updated in-place.
*/
void gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
* Perform a read-share (B in shared memory, A in global) tiled GEMM
* and accumulate into `accum`.
*
* Dispatches at compile time to either the Hopper wgmma
* implementation or the fallback MMA implementation depending on
* `use_wgmma`. The selected GemmTensorOp::body_rs performs the
* region-tiled GEMM loop and updates the accumulator in-place.
*
* When `use_wgmma == true`, this function enforces wgmma constraints
* at compile time:
* - A's leading dimension must equal (trans_A ? M : K)
* - B's leading dimension must equal (trans_B ? K : N)
* - offset_a and offset_b must be zero
*
* @param pA Pointer to operand A (global memory). Layout/stride
* expectations depend on template parameters.
* @param pB Pointer to operand B (base for shared-memory staging).
* Layout/stride expectations depend on template parameters.
* @param accum Pointer to the accumulator/output C buffer updated
* in-place.
*/
void
gemm_rs(A_type *pA, B_type *pB, C_type *accum) {
if constexpr (use_wgmma) {
static_assert((trans_A && lda == M) || (!trans_A && lda == K),
"Hopper wgmma doesn't support custom stride for A");
@@ -723,17 +748,23 @@ template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
int offset_a = 0, int offset_b = 0, bool use_wgmma = true,
int wg_wait = 0, typename A_type, typename B_type, typename C_type>
TL_DEVICE /**
* Perform a non-wgmma tiled GEMM where A regions are staged into shared memory
* and B is read directly from global memory, accumulating into `accum`.
*
* This overload dispatches to the tl_mma::GemmTensorOp::body_sr implementation.
* Must be instantiated with `use_wgmma = false` (enforced via static_assert).
*
* @param pA Pointer to the A operand in global memory (source that will be staged to shared memory).
* @param pB Pointer to the B operand in global memory (read directly).
* @param accum Pointer to the output accumulator matrix in global memory.
*/
void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
* Perform a non-wgmma tiled GEMM where A regions are staged into
* shared memory and B is read directly from global memory,
* accumulating into `accum`.
*
* This overload dispatches to the tl_mma::GemmTensorOp::body_sr
* implementation. Must be instantiated with `use_wgmma = false`
* (enforced via static_assert).
*
* @param pA Pointer to the A operand in global memory (source that
* will be staged to shared memory).
* @param pB Pointer to the B operand in global memory (read
* directly).
* @param accum Pointer to the output accumulator matrix in global
* memory.
*/
void
gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
static_assert(!use_wgmma, "wgmma doesn't support gemm_sr");
using MMA =
cute::tl_mma::GemmTensorOp<M, N, K, num_warp_m, num_warp_n, trans_A,
@@ -742,13 +773,17 @@ void gemm_sr(A_type *pA, B_type *pB, C_type *accum) {
MMA::body_sr(pA, pB, accum);
}

template <int num_mma> TL_DEVICE /**
* Wait for all WMMA/MMA warps in the current warp-group to synchronize.
*
* Blocks until the warp-group-wide rendezvous for `num_mma` MMA lanes completes,
* ensuring all participating warps have arrived before proceeding.
*/
void wait_wgmma() {
template <int num_mma>
TL_DEVICE /**
* Wait for all WMMA/MMA warps in the current warp-group to
* synchronize.
*
* Blocks until the warp-group-wide rendezvous for `num_mma` MMA lanes
* completes, ensuring all participating warps have arrived before
* proceeding.
*/
void
wait_wgmma() {
cute::warpgroup_wait<num_mma>();
}

56 changes: 56 additions & 0 deletions src/tl_templates/cuda/intrin.h
@@ -0,0 +1,56 @@
#pragma once

#if __CUDA_ARCH_LIST__ >= 900
#include "cute/arch/cluster_sm90.hpp"
#include "cutlass/cutlass.h"

Comment on lines +1 to +6

🛠️ Refactor suggestion

Make the header self-contained and guard architecture checks safely.

  • If included before common.h, TL_DEVICE is undefined. Provide fallbacks here to avoid order-dependence.
  • __CUDA_ARCH_LIST__ may be undefined. Use a defensive guard that also supports the standard __CUDA_ARCH__.
  • Add <stdint.h> to guarantee uint32_t availability.

Apply this diff at the top of the file:

 #pragma once

-#if __CUDA_ARCH_LIST__ >= 900
+#include <stdint.h>
+
+// Fallbacks if included before common.h
+#ifndef TL_DEVICE
+#define TL_DEVICE __forceinline__ __device__
+#endif
+#ifndef TL_DEVICE_NOINLINE
+#define TL_DEVICE_NOINLINE __noinline__ __device__
+#endif
+
+// Support both common project macro (__CUDA_ARCH_LIST__) and standard NVCC macro (__CUDA_ARCH__)
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) || \
+    (defined(__CUDA_ARCH_LIST__) && (__CUDA_ARCH_LIST__ >= 900))
 #include "cute/arch/cluster_sm90.hpp"
 #include "cutlass/cutlass.h"
🤖 Prompt for AI Agents
In src/tl_templates/cuda/intrin.h around lines 1 to 6, make the header
self-contained by adding defensive includes and guards at the top: include
<stdint.h> to ensure uint32_t is available, define a fallback for TL_DEVICE if
it's not already defined (e.g., a minimal definition or macro alias) so the
header does not depend on common.h, and replace the direct use of
__CUDA_ARCH_LIST__ with a safe check that falls back to testing __CUDA_ARCH__
(e.g., use #if defined(__CUDA_ARCH_LIST__) ? __CUDA_ARCH_LIST__ :
(defined(__CUDA_ARCH__) ? __CUDA_ARCH__ : 0)) before comparing against 900 to
avoid undefined macro errors; ensure these additions appear before any
architecture-specific includes.

namespace tl {
// Template parameter:
// thread_extent: the logical size (in number of threads) of each "group"
// within which we want to elect exactly ONE representative
// thread.
template <int thread_extent> TL_DEVICE bool tl_shuffle_elect() {

// Special case: thread_extent == 0 means "elect exactly one thread
// in the entire thread block", i.e., the leader of the first warp of the
// block.
if constexpr (thread_extent == 0) {
// cutlass::canonical_warp_idx_sync():
// Returns the warp ID within the thread block in a "canonical" way
// (0 for the first warp, 1 for the second, ...).
// cute::elect_one_sync():
// Elect exactly one lane in the warp to return true (typically lane 0),
// other lanes return false.
// The condition ensures that:
// (1) We are in warp 0 of the block.
// (2) We are the elected lane in this warp.
return cutlass::canonical_warp_idx_sync() == 0 && cute::elect_one_sync();
}

// General case: thread_extent != 0
// (threadIdx.x / 32) is the warp index in the block.
// (thread_extent / 32) is the number of warps in one group of size
// thread_extent. We take warp_id % num_warps_in_group to get the warp's index
// within the group.
// __shfl_sync(mask, value, srcLane): broadcast 'value' from srcLane to all
// lanes in the warp. Here it broadcasts the group-local warp index from lane
// 0. Comparing to 0 selects only the group's warp 0.
return __shfl_sync(0xffffffff, // full warp mask
(threadIdx.x / 32) %
(thread_extent / 32), // warp index within group
0 // take the value from lane 0
) == 0 &&
// Within that group leader warp, elect exactly one lane (typically
// lane 0) to be the single representative for the group.
cute::elect_one_sync();
}
Comment on lines +30 to +46

⚠️ Potential issue

Fix correctness: modulo-by-zero risk and 3D block handling in group election.

  • (thread_extent / 32) can be 0 (e.g., 1..31), causing UB in the % expression.
  • Using threadIdx.x / 32 ignores threadIdx.y/z and is incorrect for 3D blocks. Prefer cutlass::canonical_warp_idx_sync() which is block-shape agnostic.
  • __shfl_sync is unnecessary here; the value is already uniform per warp.

The diff in my other comment rewrites this logic safely with a static_assert, uses canonical_warp_idx_sync(), and removes __shfl_sync.

🤖 Prompt for AI Agents
In src/tl_templates/cuda/intrin.h around lines 29 to 45, the current
group-leader election uses (threadIdx.x / 32) % (thread_extent / 32) and
__shfl_sync which is unsafe for thread_extent < 32 (modulo-by-zero) and
incorrect for 3D blocks; replace this with a safe, block-shape-agnostic
approach: add a static_assert to ensure thread_extent is a multiple of 32,
compute the warp index using cutlass::canonical_warp_idx_sync() (or equivalent
canonical_warp index helper) rather than threadIdx.x/32, remove the unnecessary
__shfl_sync call (the warp-local value is already uniform), and then use
cute::elect_one_sync() within the chosen warp to elect a single representative
for the group.
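
For reference, a minimal sketch of the election helper reworked along the lines the reviewer suggests (a hypothetical tl_shuffle_elect_safe, assuming thread_extent is either 0 or a positive multiple of 32; not part of this PR):

// Hypothetical rework following the review suggestion above (not part of this PR).
template <int thread_extent> TL_DEVICE bool tl_shuffle_elect_safe() {
  if constexpr (thread_extent == 0) {
    // One representative for the whole thread block: warp 0, one elected lane.
    return cutlass::canonical_warp_idx_sync() == 0 && cute::elect_one_sync();
  } else {
    static_assert(thread_extent % 32 == 0,
                  "thread_extent must be a positive multiple of the warp size");
    // canonical_warp_idx_sync() is block-shape agnostic (handles 3D blocks),
    // and its value is uniform within a warp, so no __shfl_sync broadcast is needed.
    constexpr int warps_per_group = thread_extent / 32;
    return (cutlass::canonical_warp_idx_sync() % warps_per_group) == 0 &&
           cute::elect_one_sync();
  }
}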


template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_alloc() {
asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(RegCount));
}

template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_dealloc() {
asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount));
}
} // namespace tl
#endif
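
As a usage note, these intrinsics are typically paired in warp-specialized Hopper kernels: the producer warp group gives registers back and a single elected thread drives the async copies, while consumer warp groups claim a larger register budget for the math. A minimal sketch follows; the kernel structure and register counts (24/240) are illustrative assumptions, not taken from this PR.

// Illustrative warp-specialized kernel skeleton (assumed structure, not from this PR).
__global__ void example_kernel(/* ... */) {
  int wg = cutlass::canonical_warp_group_idx();
  if (wg == 0) {
    // Producer warp group: shrink its register budget and let one elected
    // thread issue the TMA / cp.async traffic.
    tl::warpgroup_reg_dealloc<24>();
    if (tl::tl_shuffle_elect<0>()) {
      // issue TMA loads, arrive on mbarriers, etc.
    }
  } else {
    // Consumer warp groups: grow the register budget before the wgmma mainloop.
    tl::warpgroup_reg_alloc<240>();
    // ... MMA mainloop ...
  }
}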