
Remove sm100+ requirement for trtllm allreduce kernels #1249


Merged: 9 commits, Jul 14, 2025
3 changes: 2 additions & 1 deletion flashinfer/comm/trtllm_ar.py
@@ -96,14 +96,15 @@ class FP4QuantizationSFLayout:


def gen_trtllm_comm_module() -> JitSpec:
+   major, minor = torch.cuda.get_device_capability()
    return gen_jit_spec(
        "trtllm_comm",
        [
            jit_env.FLASHINFER_CSRC_DIR / "trtllm_allreduce.cu",
            jit_env.FLASHINFER_CSRC_DIR / "trtllm_allreduce_fusion.cu",
            jit_env.FLASHINFER_CSRC_DIR / "trtllm_moe_allreduce_fusion.cu",
        ],
-       extra_cuda_cflags=sm100a_nvcc_flags,
+       extra_cuda_cflags=sm100a_nvcc_flags if major >= 10 and minor >= 0 else [],
    )
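
For reference, here is a minimal host-side sketch (not part of this PR) of the same capability gate, written against the CUDA runtime API rather than torch.cuda.get_device_capability(). The -gencode value is an assumed stand-in for sm100a_nvcc_flags, and because minor is never negative the check reduces to major >= 10, i.e. only compute capability 10.x (Blackwell-class) devices get the sm_100a flags.

    // Sketch only: mirrors the gate added above using the CUDA runtime API.
    #include <cuda_runtime.h>
    #include <cstdio>
    #include <string>
    #include <vector>

    std::vector<std::string> select_extra_cuda_cflags(int device = 0) {
      int major = 0, minor = 0;
      cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
      cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device);
      // `major >= 10 && minor >= 0` is equivalent to `major >= 10`.
      if (major >= 10 && minor >= 0) {
        return {"-gencode=arch=compute_100a,code=sm_100a"};  // assumed flag value
      }
      return {};  // older architectures: compile without the sm_100a flags
    }

    int main() {
      for (const auto& flag : select_extra_cuda_cflags()) {
        std::printf("%s\n", flag.c_str());
      }
      return 0;
    }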


31 changes: 27 additions & 4 deletions include/flashinfer/comm/trtllm_allreduce_fusion.cuh
@@ -1,6 +1,7 @@
#include <cooperative_groups.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
+#include <cuda_fp4.h>

#include <cuda/std/optional>
#include <tuple>
@@ -522,7 +523,17 @@ __device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(std::optional<int> batchI
return nullptr;
}

+__forceinline__ __device__ uint32_t pack_bytes(uint8_t c0, uint8_t c1, uint8_t c2, uint8_t c3) {
+  uint32_t val0 = c0;
+  uint32_t val1 = c1;
+  uint32_t val2 = c2;
+  uint32_t val3 = c3;
+
+  return (val3 << 24) | (val2 << 16) | (val1 << 8) | val0;
+}

// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
+// NOTE: bypass sm_100 requirement by __nv_cvt_float2_to_fp4x2
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t val;
@@ -543,8 +554,14 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
"f"(array[6]), "f"(array[7]));
return val;
#else
-  // static_assert(false, "not supported.");
-  return 0;
+  uint32_t val;
+  __nv_fp4x2_storage_t vals[4];
+#pragma unroll
+  for (int i = 0; i < 4; i++) {
+    vals[i] = __nv_cvt_float2_to_fp4x2(*(((float2*)array) + i), __NV_E2M1, cudaRoundNearest);
+  }
Comment on lines +560 to +562

Contributor (medium):

Using a C-style cast (float2*) is generally discouraged in C++. It's better to use reinterpret_cast for this kind of type punning to make the intent clearer and the code safer. Additionally, for performance-critical code like this, consider adding #pragma unroll before the for loop to encourage the compiler to unroll it, which can improve performance.

Suggested change:

    vals[i] = __nv_cvt_float2_to_fp4x2(*(reinterpret_cast<float2*>(array) + i), __NV_E2M1, cudaRoundNearest);

+  val = pack_bytes(vals[0], vals[1], vals[2], vals[3]);
+  return val;
#endif
}
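
To make the new fallback concrete, here is a standalone sketch (not part of the PR) of the same idea: eight floats converted pairwise with __nv_cvt_float2_to_fp4x2 and the four resulting fp4x2 bytes packed into one uint32_t. It assumes CUDA 12.8+ for <cuda_fp4.h>; the kernel and helper names are hypothetical, and it builds float2 values with make_float2 rather than casting the float array, which also sidesteps the cast concern raised in the review comment above.

    // Sketch only: eight fp32 values -> eight packed e2m1 nibbles (one uint32_t).
    #include <cuda_fp4.h>
    #include <cuda_runtime.h>
    #include <cstdint>
    #include <cstdio>

    __device__ uint32_t pack_bytes_sketch(uint8_t c0, uint8_t c1, uint8_t c2, uint8_t c3) {
      return (uint32_t(c3) << 24) | (uint32_t(c2) << 16) | (uint32_t(c1) << 8) | uint32_t(c0);
    }

    __global__ void fp32x8_to_e2m1_kernel(const float* in, uint32_t* out) {
      __nv_fp4x2_storage_t vals[4];
    #pragma unroll
      for (int i = 0; i < 4; i++) {
        // Each call converts one pair of floats into two 4-bit e2m1 codes (one byte).
        vals[i] = __nv_cvt_float2_to_fp4x2(make_float2(in[2 * i], in[2 * i + 1]), __NV_E2M1,
                                           cudaRoundNearest);
      }
      *out = pack_bytes_sketch(vals[0], vals[1], vals[2], vals[3]);
    }

    int main() {
      const float h_in[8] = {0.f, 0.5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f};
      float* d_in = nullptr;
      uint32_t* d_out = nullptr;
      uint32_t h_out = 0;
      cudaMalloc(&d_in, sizeof(h_in));
      cudaMalloc(&d_out, sizeof(uint32_t));
      cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
      fp32x8_to_e2m1_kernel<<<1, 1>>>(d_in, d_out);
      cudaMemcpy(&h_out, d_out, sizeof(uint32_t), cudaMemcpyDeviceToHost);
      std::printf("packed e2m1 word: 0x%08x\n", (unsigned)h_out);
      cudaFree(d_in);
      cudaFree(d_out);
      return 0;
    }

On a compute-capability-10.x GPU the same inputs can be run through the inline-PTX branch to confirm the two paths produce the same packed word.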

@@ -569,8 +586,14 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
"f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
return val;
#else
-  // static_assert(false, "not supported.");
-  return 0;
+  uint32_t val;
+  __nv_fp4x2_storage_t vals[4];
+#pragma unroll
+  for (int i = 0; i < 4; i++) {
+    vals[i] = __nv_cvt_float2_to_fp4x2(array[i], __NV_E2M1, cudaRoundNearest);
+  }
+  val = pack_bytes(vals[0], vals[1], vals[2], vals[3]);
+  return val;
#endif
}

31 changes: 27 additions & 4 deletions include/flashinfer/comm/trtllm_moe_allreduce_fusion.cuh
@@ -1,6 +1,7 @@
#include <cooperative_groups.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
+#include <cuda_fp4.h>

#include <cuda/std/optional>
#include <tuple>
@@ -509,7 +510,17 @@ __device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(std::optional<int> batchI
return nullptr;
}

+__forceinline__ __device__ uint32_t pack_bytes(uint8_t c0, uint8_t c1, uint8_t c2, uint8_t c3) {
+  uint32_t val0 = c0;
+  uint32_t val1 = c1;
+  uint32_t val2 = c2;
+  uint32_t val3 = c3;
+
+  return (val3 << 24) | (val2 << 16) | (val1 << 8) | val0;
+}

// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
+// NOTE:bypass sm_100 requirement by __nv_cvt_float2_to_fp4x2
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
uint32_t val;
@@ -530,8 +541,14 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
"f"(array[6]), "f"(array[7]));
return val;
#else
-  // static_assert(false, "not supported.");
-  return 0;
+  uint32_t val;
+  __nv_fp4x2_storage_t vals[4];
+#pragma unroll
+  for (int i = 0; i < 4; i++) {
+    vals[i] = __nv_cvt_float2_to_fp4x2(*(((float2*)array) + i), __NV_E2M1, cudaRoundNearest);
+  }
+  val = pack_bytes(vals[0], vals[1], vals[2], vals[3]);
+  return val;
#endif
}

@@ -556,8 +573,14 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
"f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
return val;
#else
-  // static_assert(false, "not supported.");
-  return 0;
+  uint32_t val;
+  __nv_fp4x2_storage_t vals[4];
+#pragma unroll
+  for (int i = 0; i < 4; i++) {
+    vals[i] = __nv_cvt_float2_to_fp4x2(array[i], __NV_E2M1, cudaRoundNearest);
+  }
+  val = pack_bytes(vals[0], vals[1], vals[2], vals[3]);
+  return val;
#endif
}
