-
Notifications
You must be signed in to change notification settings - Fork 430
Remove sm100+ requirment for trtllm allreduce kernels #1249
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b8f716b
5b0b7f0
fae9359
0d7a149
39baa7b
2354963
bd1711b
355f17b
f8e0a45
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
#include <cooperative_groups.h> | ||
#include <cuda_bf16.h> | ||
#include <cuda_fp16.h> | ||
#include <cuda_fp4.h> | ||
|
||
#include <cuda/std/optional> | ||
#include <tuple> | ||
|
@@ -522,7 +523,17 @@ __device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(std::optional<int> batchI | |
return nullptr; | ||
} | ||
|
||
__forceinline__ __device__ uint32_t pack_bytes(uint8_t c0, uint8_t c1, uint8_t c2, uint8_t c3) { | ||
uint32_t val0 = c0; | ||
uint32_t val1 = c1; | ||
uint32_t val2 = c2; | ||
uint32_t val3 = c3; | ||
|
||
return (val3 << 24) | (val2 << 16) | (val1 << 8) | val0; | ||
} | ||
|
||
// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). | ||
// NOTE: bypass sm_100 requirement by __nv_cvt_float2_to_fp4x2 | ||
inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { | ||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) | ||
uint32_t val; | ||
|
@@ -543,8 +554,14 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { | |
"f"(array[6]), "f"(array[7])); | ||
return val; | ||
#else | ||
// static_assert(false, "not supported."); | ||
return 0; | ||
uint32_t val; | ||
__nv_fp4x2_storage_t vals[4]; | ||
#pragma unroll | ||
for (int i = 0; i < 4; i++) { | ||
vals[i] = __nv_cvt_float2_to_fp4x2(*(((float2*)array) + i), __NV_E2M1, cudaRoundNearest); | ||
} | ||
Comment on lines
+560
to
+562
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Using a C-style cast
|
||
val = pack_bytes(vals[0], vals[1], vals[2], vals[3]); | ||
return val; | ||
#endif | ||
} | ||
|
||
|
@@ -569,8 +586,14 @@ inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { | |
"f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); | ||
return val; | ||
#else | ||
// static_assert(false, "not supported."); | ||
return 0; | ||
uint32_t val; | ||
__nv_fp4x2_storage_t vals[4]; | ||
#pragma unroll | ||
for (int i = 0; i < 4; i++) { | ||
vals[i] = __nv_cvt_float2_to_fp4x2(array[i], __NV_E2M1, cudaRoundNearest); | ||
} | ||
val = pack_bytes(vals[0], vals[1], vals[2], vals[3]); | ||
return val; | ||
#endif | ||
} | ||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.