Skip to content

Feature/sm100 low latency nvfp4 kernels #1214

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 21 additions & 19 deletions csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -729,13 +729,14 @@ __device__ uint32_t quantizePackedFP4Value(
ComputeElem& post_act_val, float global_scale_val, int64_t num_tokens_before_expert,
int64_t expert_id, int64_t token_id, int64_t elem_idx, int64_t num_cols,
int64_t max_tokens_per_expert, TmaWarpSpecializedGroupedGemmInput::ElementSF* act_sf_flat) {
static constexpr int CVT_FP4_NUM_THREADS_PER_SF = CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD;
static constexpr int CVT_FP4_SF_VEC_SIZE = 16;
static constexpr int CVT_FP4_NUM_THREADS_PER_SF = CVT_FP4_SF_VEC_SIZE / CVT_ELTS_PER_THREAD;
// Quantize the input to FP4
static_assert(std::is_same_v<GemmOutputType, __nv_bfloat16> ||
std::is_same_v<GemmOutputType, half>);
static_assert(ComputeElem::kElements == CVT_FP4_ELTS_PER_THREAD);
static_assert(ComputeElem::kElements == CVT_ELTS_PER_THREAD);
PackedVec<GemmOutputType> packed_vec{};
for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
for (int i = 0; i < CVT_ELTS_PER_THREAD / 2; i++) {
packed_vec.elts[i].x = static_cast<GemmOutputType>(post_act_val[i * 2 + 0]);
packed_vec.elts[i].y = static_cast<GemmOutputType>(post_act_val[i * 2 + 1]);
}
Expand All @@ -746,14 +747,15 @@ __device__ uint32_t quantizePackedFP4Value(

// Use `token - num_tokens_before_expert` because we want this to be relative to the start of this
// expert
auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF,
CVT_FP4_NUM_THREADS_PER_SF>(
auto sf_out = cvt_quant_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF,
CVT_FP4_NUM_THREADS_PER_SF>(
std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx,
std::nullopt /* numRows */, num_cols, act_sf_expert, FP4QuantizationSFLayout::SWIZZLED);
std::nullopt /* numRows */, num_cols, act_sf_expert, QuantizationSFLayout::SWIZZLED);

// Do the conversion and set the output and scaling factor
constexpr bool UE8M0 = false;
auto res = cvt_warp_fp16_to_fp4<GemmOutputType, UE8M0>(packed_vec, global_scale_val, sf_out);
auto res = cvt_warp_fp16_to_fp4<GemmOutputType, CVT_FP4_SF_VEC_SIZE, UE8M0>(
packed_vec, global_scale_val, sf_out);
return res;
}

Expand All @@ -762,25 +764,25 @@ __device__ void writeSF(int64_t num_tokens_before_expert, int64_t expert_id,
int64_t num_cols, int64_t max_tokens_per_expert,
TmaWarpSpecializedGroupedGemmInput::ElementSF* act_sf_flat,
TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf) {
static constexpr int CVT_FP4_NUM_THREADS_PER_SF = CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD;
static constexpr int CVT_FP4_SF_VEC_SIZE = 16;
static constexpr int CVT_FP4_NUM_THREADS_PER_SF = CVT_FP4_SF_VEC_SIZE / CVT_ELTS_PER_THREAD;

// We need to offset into the scaling factors for just this expert
auto act_sf_expert =
act_sf_flat + getOffsetActivationSF(expert_id, num_tokens_before_expert, num_cols);

// Use `token - num_tokens_before_expert` because we want this to be relative to the start of this
// expert
auto sf_out = cvt_quant_to_fp4_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF,
CVT_FP4_NUM_THREADS_PER_SF>(
auto sf_out = cvt_quant_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF,
CVT_FP4_NUM_THREADS_PER_SF>(
std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx,
std::nullopt /* numRows */, num_cols, act_sf_expert, FP4QuantizationSFLayout::SWIZZLED);
std::nullopt /* numRows */, num_cols, act_sf_expert, QuantizationSFLayout::SWIZZLED);
if (sf_out) {
auto const sf_in =
cvt_quant_to_fp4_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF,
CVT_FP4_NUM_THREADS_PER_SF>(
std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */,
num_cols, const_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(input_sf),
FP4QuantizationSFLayout::SWIZZLED);
auto const sf_in = cvt_quant_get_sf_out_offset<TmaWarpSpecializedGroupedGemmInput::ElementSF,
CVT_FP4_NUM_THREADS_PER_SF>(
std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */,
num_cols, const_cast<TmaWarpSpecializedGroupedGemmInput::ElementSF*>(input_sf),
QuantizationSFLayout::SWIZZLED);
*sf_out = *sf_in;
}
}
Expand Down Expand Up @@ -1178,7 +1180,7 @@ __global__ void expandInputRowsKernel(
if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) {
// Load 128-bits per thread
constexpr int64_t ELEM_PER_THREAD =
is_fp4 ? CVT_FP4_ELTS_PER_THREAD : (128 / sizeof_bits<InputActivationsType>::value);
is_fp4 ? CVT_ELTS_PER_THREAD : (128 / sizeof_bits<InputActivationsType>::value);
constexpr int64_t ELEM_PER_BYTE = is_fp4_input ? 2 : 1;
using DataElem = std::conditional_t<is_fp4_input, uint32_t,
cutlass::Array<InputActivationsType, ELEM_PER_THREAD>>;
Expand Down Expand Up @@ -1531,7 +1533,7 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,

// Load 128-bits per thread, according to the smallest data type we read/write
constexpr int64_t ACTIVATION_ELEM_PER_THREAD =
IsFP4 ? CVT_FP4_ELTS_PER_THREAD
IsFP4 ? CVT_ELTS_PER_THREAD
: (128 / std::min(sizeof_bits<T>::value, sizeof_bits<GemmOutputType>::value));

using BiasElem = cutlass::Array<ScaleBiasType, ACTIVATION_ELEM_PER_THREAD>;
Expand Down
Loading