Skip to content
Open
34 changes: 34 additions & 0 deletions src/tl_templates/cuda/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
#include <cutlass/numeric_types.h>
#include <math_constants.h>

#include <cutlass/bfloat16.h>
#include <cutlass/float8.h>

Comment on lines +13 to +15
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Explicitly include cuda_bf16.h to guarantee __nv_bfloat16 availability.

Avoid relying on transitive includes; add the CUDA header so host/RTC builds consistently see __nv_bfloat16.

Apply this diff near the existing cuda_runtime include:

 #ifndef __CUDACC_RTC__
 #include <cuda_runtime.h>
+#include <cuda_bf16.h>
 #endif

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In src/tl_templates/cuda/common.h around lines 13 to 15, the file relies on
transitive includes for the CUDA bfloat16 type (__nv_bfloat16); explicitly add
the CUDA header cuda_bf16.h (near the existing cuda_runtime include) to
guarantee __nv_bfloat16 is available for host and RTC builds, avoiding
transitive-include fragility.

using cutlass::bfloat16_t;
using cutlass::half_t;
using cutlass::tfloat32_t;
Expand Down Expand Up @@ -339,6 +342,37 @@ TL_DEVICE void increase_descriptor_offset(GmmaDescriptor &descriptor,
descriptor.reg32_[0] += (offset >> 4);
}

// Thin wrapper over cute::float_e4m3_t that adds a conversion from
// __nv_bfloat16 (routed through float, so rounding is performed by the
// float -> e4m3 constructor). NOTE(review): the constructor below is marked
// `explicit`, so despite the original comment this is NOT an implicit
// conversion — callers must cast explicitly; confirm that matches the
// intended call sites.
struct float_e4m3_t : public cute::float_e4m3_t {
// Inherit every constructor of the underlying cute fp8 type.
using cute::float_e4m3_t::float_e4m3_t;
CUTLASS_HOST_DEVICE
float_e4m3_t() = default;

// bf16 -> float -> e4m3 two-step conversion; requires __nv_bfloat16
// to be in scope (see review note about including <cuda_bf16.h>).
CUTLASS_HOST_DEVICE
explicit float_e4m3_t(__nv_bfloat16 x)
: float_e4m3_t(static_cast<float>(x)) {}
};

// Thin wrapper over cute::float_e5m2_t mirroring tl::float_e4m3_t above:
// adds an explicit conversion from __nv_bfloat16 via an intermediate float
// (rounding done by cute's float -> e5m2 constructor).
struct float_e5m2_t : public cute::float_e5m2_t {
// Inherit every constructor of the underlying cute fp8 type.
using cute::float_e5m2_t::float_e5m2_t;
CUTLASS_HOST_DEVICE
float_e5m2_t() = default;

// bf16 -> float -> e5m2 two-step conversion; requires __nv_bfloat16
// to be in scope.
CUTLASS_HOST_DEVICE
explicit float_e5m2_t(__nv_bfloat16 x)
: float_e5m2_t(static_cast<float>(x)) {}
};

// Maps a tl:: element type to the cute:: type expected by the MMA atoms.
// Identity for every type except the tl fp8 wrappers above, which are
// unwrapped back to their cute base types so that trait specializations
// keyed on cute::float_e4m3_t / cute::float_e5m2_t still match.
template <typename T> struct to_cute_type {
using type = T;
};
// tl fp8 wrappers -> underlying cute fp8 types.
template <> struct to_cute_type<tl::float_e4m3_t> {
using type = cute::float_e4m3_t;
};
template <> struct to_cute_type<tl::float_e5m2_t> {
using type = cute::float_e5m2_t;
};

} // namespace tl

namespace cutlass {
Expand Down
5 changes: 3 additions & 2 deletions src/tl_templates/cuda/cuda_fp8.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
#pragma once

#include "common.h"
#include <cuda_fp8.h>
#include <cute/numeric/numeric_types.hpp>

using fp8_e4_t = cute::float_e4m3_t;
using fp8_e5_t = cute::float_e5m2_t;
using fp8_e4_t = tl::float_e4m3_t;
using fp8_e5_t = tl::float_e5m2_t;

struct __CUDA_ALIGN__(2) fp8_e4_2_t {
fp8_e4_t x;
Expand Down
10 changes: 6 additions & 4 deletions src/tl_templates/cuda/gemm_mma.h
Original file line number Diff line number Diff line change
Expand Up @@ -263,12 +263,14 @@ template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
typename C_type_raw>
class GemmTensorOp {
public:
using A_type_cute = typename tl::to_cute_type<A_type_raw>::type;
using B_type_cute = typename tl::to_cute_type<B_type_raw>::type;
using A_type =
typename std::conditional<std::is_same<A_type_raw, float>::value,
tfloat32_t, A_type_raw>::type;
typename std::conditional<std::is_same<A_type_cute, float>::value,
tfloat32_t, A_type_cute>::type;
using B_type =
typename std::conditional<std::is_same<B_type_raw, float>::value,
tfloat32_t, A_type_raw>::type;
typename std::conditional<std::is_same<B_type_cute, float>::value,
tfloat32_t, B_type_cute>::type;
using C_type = C_type_raw;

using Instruction =
Expand Down
10 changes: 6 additions & 4 deletions src/tl_templates/cuda/gemm_sm100.h
Original file line number Diff line number Diff line change
Expand Up @@ -289,12 +289,14 @@ template <int M, int N, int K, int AtomM, int AtomN, int AtomK, bool trans_A,
typename C_type_raw>
class GemmTensorOp {
public:
using A_type_cute = typename tl::to_cute_type<A_type_raw>::type;
using B_type_cute = typename tl::to_cute_type<B_type_raw>::type;
using A_type =
typename std::conditional<std::is_same<A_type_raw, float>::value,
tfloat32_t, A_type_raw>::type;
typename std::conditional<std::is_same<A_type_cute, float>::value,
tfloat32_t, A_type_cute>::type;
using B_type =
typename std::conditional<std::is_same<B_type_raw, float>::value,
tfloat32_t, B_type_raw>::type;
typename std::conditional<std::is_same<B_type_cute, float>::value,
tfloat32_t, B_type_cute>::type;
using C_type = C_type_raw;

static_assert(AtomM == 128 || AtomM == 64 || AtomM == 32);
Expand Down
10 changes: 6 additions & 4 deletions src/tl_templates/cuda/gemm_sm90.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@ template <int M, int N, int K, int num_warp_m, int num_warp_n, bool trans_A,
typename B_type_raw, typename C_type_raw>
class GemmTensorOp {
public:
using A_type = conditional_t<std::is_same<A_type_raw, float>::value,
tfloat32_t, A_type_raw>;
using B_type = conditional_t<std::is_same<B_type_raw, float>::value,
tfloat32_t, B_type_raw>;
// Unwrap tl:: fp8 wrapper types back to the cute:: types the MMA atoms
// are specialized for (identity for all other element types).
using A_type_cute = typename tl::to_cute_type<A_type_raw>::type;
using B_type_cute = typename tl::to_cute_type<B_type_raw>::type;
// float operands are lowered to tf32 for the tensor-core instructions.
using A_type = conditional_t<std::is_same<A_type_cute, float>::value,
                             tfloat32_t, A_type_cute>;
// BUGFIX: the false branch previously read `A_type_cute` (copy/paste
// error), so B_type silently took A's element type whenever A and B
// differed. Use B_type_cute, matching the pre-refactor behavior
// (B_type_raw) and the parallel code in gemm_sm100.h / gemm_sp_sm90.h.
using B_type = conditional_t<std::is_same<B_type_cute, float>::value,
                             tfloat32_t, B_type_cute>;
using C_type = C_type_raw;

static constexpr GMMA::Major GmmaMajorA =
Expand Down
10 changes: 6 additions & 4 deletions src/tl_templates/cuda/gemm_sp_sm90.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ class GemmTensorOp {
public:
static_assert(num_warp_m % 4 == 0, "num_warp_m must be a multiple of 4");

using A_type = conditional_t<std::is_same<A_type_raw, float>::value,
tfloat32_t, A_type_raw>;
using B_type = conditional_t<std::is_same<B_type_raw, float>::value,
tfloat32_t, B_type_raw>;
using A_type_cute = typename tl::to_cute_type<A_type_raw>::type;
using B_type_cute = typename tl::to_cute_type<B_type_raw>::type;
using A_type = conditional_t<std::is_same<A_type_cute, float>::value,
tfloat32_t, A_type_cute>;
using B_type = conditional_t<std::is_same<B_type_cute, float>::value,
tfloat32_t, B_type_cute>;
using C_type = C_type_raw;

static constexpr bool need_tfloat32_cast =
Expand Down
Loading