Skip to content

[Bugfix] Fix topk_ids indices_type for CUTLASS w8a8 FP8 MoE #20166

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions csrc/quantization/cutlass_w8a8/moe/moe_data.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

constexpr uint64_t THREADS_PER_EXPERT = 512;

__global__ void compute_problem_sizes(const uint32_t* __restrict__ topk_ids,
__global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
int32_t* problem_sizes1,
int32_t* problem_sizes2,
int32_t* atomic_buffer,
Expand Down Expand Up @@ -62,7 +62,7 @@ __global__ void compute_expert_blockscale_offsets(
}
}

__global__ void compute_arg_sorts(const uint32_t* __restrict__ topk_ids,
__global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
const int32_t* __restrict__ expert_offsets,
int32_t* input_permutation,
int32_t* output_permutation,
Expand Down Expand Up @@ -103,7 +103,7 @@ void get_cutlass_moe_mm_data_caller(

int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
static_cast<const uint32_t*>(topk_ids.data_ptr()),
static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k);
Expand All @@ -120,7 +120,7 @@ void get_cutlass_moe_mm_data_caller(
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
}
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
static_cast<const uint32_t*>(topk_ids.data_ptr()),
static_cast<const int32_t*>(topk_ids.data_ptr()),
static_cast<const int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(input_permutation.data_ptr()),
static_cast<int32_t*>(output_permutation.data_ptr()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -621,8 +621,7 @@ def apply(
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
indices_type=torch.uint32)
e_score_correction_bias=e_score_correction_bias)

return self.fused_experts(
x,
Expand Down