Skip to content

deepseek && qwen tp performance tuning #934

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open

Conversation

shihaobai
Copy link
Collaborator

No description provided.

Copy link

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @shihaobai, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request focuses on significant performance optimizations for Deepseek and Qwen models, particularly within Mixture of Experts (MoE) and Tensor Parallelism (TP) contexts. It introduces new, highly-tuned Triton kernels for grouped GEMM operations, refines weight fusion and data handling for MoE, and enhances FP8 quantization. Additionally, it includes specialized Rotary Embedding kernels and integrates FlashInfer for improved CUDA architecture support, all aimed at boosting inference speed and efficiency.

Highlights

  • MoE Kernel Tuning: New kernel configurations are added for grouped MoE GEMM operations, specifically optimized for NVIDIA H100 and H200 GPUs across various model dimensions and quantization settings.
  • MoE Weight Fusion Optimization: The process of fusing MoE expert weights and their scales is made more memory-efficient by pre-allocating tensors and using direct slicing instead of concatenation and flattening.
  • Grouped MoE GEMM Data Handling: The moe_align2 and grouped_matmul functions are refactored to improve data access patterns and reduce data movement during grouped GEMM computations.
  • SiLU and Mul Kernel Optimization: A new, faster Triton kernel (_silu_and_mul_kernel_fast) is introduced for the SiLU and Mul activation function, specifically benefiting smaller batch sizes in MoE layers.
  • FP8 Quantization Enhancements: The FP8 input quantization process is made more flexible and efficient, particularly concerning the handling and alignment of scales.
  • Rotary Embedding Kernel Split: The Rotary Embedding kernel is split into separate, optimized kernels for Query and Key tensors, allowing for more fine-grained performance tuning.
  • FlashInfer Integration: Conditional import and usage of FlashInfer are added to dynamically set CUDA architecture flags, improving compatibility and performance with FlashInfer-enabled operations.
  • Benchmark and Tuning Updates: New benchmarks are introduced for fused MoE Triton kernels, and existing tuning scripts are updated to support new FP8 quantization parameters and DeepseekV2 large model configurations.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@shihaobai shihaobai force-pushed the fused_moe_improve branch from 52dc131 to 6362c4a Compare June 19, 2025 16:04
Copy link

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request focuses on performance tuning for MoE layers, particularly for DeepSeek and Qwen models, with significant enhancements to FP8 quantization and Triton kernel optimizations. Key changes include refactoring MoE data alignment and GEMM kernels, optimizing weight fusing, specializing rotary embedding kernels, and updating quantization utilities. A new benchmark for fused MoE has also been added. The changes are generally aimed at improving efficiency and memory usage.

Comment on lines +96 to +106
gate_out_dim, gate_in_dim = self.experts_gate_projs[0].shape
up_out_dim, up_in_dim = self.experts_up_projs[0].shape
assert gate_in_dim == up_in_dim
dtype = self.experts_gate_projs[0].dtype
total_expert_num = self.n_routed_experts

w1 = torch.empty((total_expert_num, gate_out_dim + up_out_dim, gate_in_dim), dtype=dtype, device="cpu")

for i_experts in range(self.n_routed_experts):
expert_gate_up_proj = torch.cat(
[self.experts_gate_projs[i_experts], self.experts_up_projs[i_experts]], dim=0
)
expert_gate_up_proj = expert_gate_up_proj
w1_list.append(expert_gate_up_proj)

inter_shape, hidden_size = w1_list[0].shape[0], w1_list[0].shape[1]
w1 = torch._utils._flatten_dense_tensors(w1_list).view(len(w1_list), inter_shape, hidden_size)
w1[i_experts, 0:gate_out_dim:, :] = self.experts_gate_projs[i_experts]
w1[i_experts, gate_out_dim:, :] = self.experts_up_projs[i_experts]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The change to pre-allocate w1 (and similarly w1_scale in _fuse_weight_scale) and then fill it using slicing is a good optimization. This approach avoids the creation of intermediate lists of tensors and the subsequent overhead of torch.cat and torch._utils._flatten_dense_tensors (for w1_list), which should lead to reduced peak memory usage and potentially faster execution.

Comment on lines 380 to +390

if expert_id == -1:
return

tile_m_idx = tl.load(mblocks_to_m_index + pid_m)
tile_n_idx = pid_n

# get the gemm size of the current problem
cur_m = tl.load(expert_to_token_num + expert_id, eviction_policy="evict_last")

# do regular gemm here
offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
token_mask = offs_am < cur_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
# token_mask = offs_am < cur_m
a_m_index = tl.load(
expert_to_token_index + expert_id * expert_to_token_index_stride_0 + offs_am,
mask=token_mask,
other=0,
padded_expert_to_token_index + offs_am,
)
if MUL_ROUTED_WEIGHT:
a_m_scale = tl.load(
expert_to_weights_ptr + expert_id * expert_to_weights_stride0 + offs_am,
mask=token_mask,
other=0.0,
)

token_mask = a_m_index != -1

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The indexing logic for a_m_index and the subsequent token_mask derivation has been updated to use the new padded_expert_to_token_index.

Previously, tile_m_idx was loaded from mblocks_to_m_index, and cur_m (expert-specific token count) was used for masking. Now, offs_am is directly computed from pid_m, and a_m_index is loaded from padded_expert_to_token_index + offs_am. The token_mask is then derived from a_m_index != -1.

This simplification relies on padded_expert_to_token_index being correctly populated by moe_align2 (with -1 for padding). Ensure thorough testing, as this is a critical part of the GEMM computation.

Comment on lines +456 to 461
a_m_scale = tl.load(
padded_expert_to_weights + offs_am,
mask=token_mask,
other=0.0,
)
accumulator *= a_m_scale[:, None]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

When MUL_ROUTED_WEIGHT is true, a_m_scale is now loaded from padded_expert_to_weights + offs_am. This is consistent with the changes in moe_align2 which now prepares padded_expert_to_weights. This ensures that the correct routing weights are applied per token.

Comment on lines +57 to +102
@triton.jit
def _silu_and_mul_kernel_fast(
input_ptr,
output_ptr,
stride_input_m,
stride_input_n,
stride_output_m,
stride_output_n,
size_n,
BLOCK_N: tl.constexpr,
NEED_MASK: tl.constexpr,
):
stride_input_m = tl.cast(stride_input_m, dtype=tl.int64)
stride_output_m = tl.cast(stride_output_m, dtype=tl.int64)

cur_batch = tl.program_id(0)
pid = tl.program_id(1)
n_offsets = pid * BLOCK_N + tl.arange(0, BLOCK_N)

up_offsets = cur_batch * stride_input_m + (n_offsets[None, :] + size_n)
gate_offsets = cur_batch * stride_input_m + n_offsets[None, :]
res_offsets = cur_batch * stride_output_m + n_offsets[None, :]
if NEED_MASK:
mask = n_offsets[None, :] < size_n
else:
mask = True

up = tl.load(
input_ptr + up_offsets,
mask=mask,
other=0.0,
)
gate = tl.load(
input_ptr + gate_offsets,
mask=mask,
other=0.0,
).to(tl.float32)

gate = gate / (1 + tl.exp(-gate))
gate = gate.to(input_ptr.dtype.element_ty)

tl.store(
output_ptr + res_offsets,
up * gate,
mask=mask,
)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The new _silu_and_mul_kernel_fast kernel is a good addition for optimizing SiLU and multiply operations when size_m (number of rows/tokens) is relatively small (<= 4096).

The grid launch (size_m, triton.cdiv(size_n, BLOCK_N)) with cur_batch = tl.program_id(0) effectively assigns each row to a separate Triton program instance. This row-wise processing can be more efficient than a blocked approach for smaller M.

Comment on lines 52 to +57
qinput_tensor = self.cache_manager.alloc_tensor(
(m, k), qweight.dtype, device=qweight.device, is_graph_out=False
)
per_token_group_quant_fp8(input_tensor, self.block_size, qinput_tensor, input_scale)
input_scale = tma_align_input_scale(input_scale)
_, input_scale = per_token_group_quant_fp8(
input_tensor, self.block_size, qinput_tensor, column_major_scales=True, scale_tma_aligned=True
)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The per_token_group_quant_fp8 function is now called with column_major_scales=True and scale_tma_aligned=True. The returned input_scale is used directly, and the previous explicit call to tma_align_input_scale is removed.
This is a good refactoring as it encapsulates the TMA alignment logic for scales within the per_token_group_quant_fp8 function itself.

Comment on lines 109 to +150
x: torch.Tensor,
group_size: int,
x_q: torch.Tensor,
x_s: torch.Tensor,
x_s: torch.Tensor = None,
eps: float = 1e-10,
dtype: torch.dtype = torch.float8_e4m3fn,
column_major_scales: bool = False,
scale_tma_aligned: bool = False,
alloc_func: Callable = torch.empty,
):
# Adapted from
# https://github.com/sgl-project/sglang/blob/7e257cd666c0d639626487987ea8e590da1e9395/python/sglang/srt/layers/quantization/fp8_kernel.py#L290
if HAS_SGL_KERNEL:
finfo = torch.finfo(dtype)
fp8_max, fp8_min = finfo.max, finfo.min
if column_major_scales:
if scale_tma_aligned:
# aligned to 4 * sizeof(float)
aligned_size = (x.shape[-2] + 3) // 4 * 4
x_s = alloc_func(
x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
device=x.device,
dtype=torch.float32,
).permute(-1, -2)[: x.shape[-2], :]
else:
x_s = alloc_func(
(x.shape[-1] // group_size,) + x.shape[:-1],
device=x.device,
dtype=torch.float32,
).permute(-1, -2)
else:
if x_s is None:
x_s = alloc_func(
x.shape[:-1] + (x.shape[-1] // group_size,),
device=x.device,
dtype=torch.float32,
)
sgl_ops.sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, 1e-10, fp8_min, fp8_max)
else:
lightllm_per_token_group_quant_fp8(x, group_size, x_q, x_s, eps=1e-10, dtype=torch.float8_e4m3fn)

return x_q, x_s

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The per_token_group_quant_fp8 function has been significantly enhanced:

  1. x_s (scales tensor) is now optional and can be allocated internally.
  2. New parameters column_major_scales and scale_tma_aligned allow for more control over the scale tensor's layout and alignment.
  3. The alloc_func parameter adds flexibility for tensor allocation.
  4. The function now returns (x_q, x_s).

These changes make the function more versatile and self-contained.

Comment on lines 68 to 90
w1_scale = w2_scale = None

if use_fp8_w8a8:
init_dtype = dtype
w1 = torch.randn(expert_num, 2 * n, k, dtype=init_dtype).cuda()
w2 = torch.randn(expert_num, k, 2 * n // 2, dtype=init_dtype).cuda()
w1 = w1.to(torch.float8_e4m3fn)
w2 = w2.to(torch.float8_e4m3fn)

if block_shape is None:
w1_scale = torch.randn(expert_num, dtype=torch.float32).cuda()
w2_scale = torch.randn(expert_num, dtype=torch.float32).cuda()
else:
block_n, block_k = block_shape[0], block_shape[1]
n_tiles_w1 = (2 * n + block_n - 1) // block_n
n_tiles_w2 = (k + block_n - 1) // block_n
k_tiles_w1 = (k + block_k - 1) // block_k
k_tiles_w2 = (2 * n // 2 + block_k - 1) // block_k
w1_scale = torch.rand((expert_num, n_tiles_w1, k_tiles_w1), dtype=torch.float32).cuda()
w2_scale = torch.rand((expert_num, n_tiles_w2, k_tiles_w2), dtype=torch.float32).cuda()
else:
w1 = torch.randn(expert_num, 2 * n, k, dtype=dtype).cuda()
w2 = torch.randn(expert_num, k, 2 * n // 2, dtype=dtype).cuda()

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The weight and scale initialization logic for FP8 (use_fp8_w8a8=True) has been significantly revised. Weights (w1, w2) are now directly converted to torch.float8_e4m3fn, and their scales (w1_scale, w2_scale) are initialized based on block_shape (supporting per-expert or block-wise scales).

The explicit call to quantize_moe has been removed. This implies that the grouped_matmul kernel is now expected to consume FP8 weights and their corresponding scales directly.

Comment on lines 232 to 233
block_shape = getattr(model_config, "block_shape", None)
block_shape = [128, 128]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In the benchmark function, block_shape is first potentially retrieved from model_config. However, it's immediately overwritten by block_shape = [128, 128].

If the intention is to use a dynamic block_shape from the model configuration, the overwrite should be removed or made conditional. If [128, 128] is a specific override for this benchmark, it might be clearer to document this or use a different variable name for the override.

Suggested change
block_shape = getattr(model_config, "block_shape", None)
block_shape = [128, 128]
dtype = model_config["dtype"]
# block_shape = model_config.get("block_shape", None) # Get from config if needed
# For this specific benchmark, we might be testing a fixed block_shape:
block_shape_override = [128, 128] # Or use model_config["block_shape"] if available and desired
x = torch.randn(num_tokens, hidden_size, dtype=dtype)

@shihaobai shihaobai changed the title deepseek && qwen tp performance tuning deepseek && qwen tp performance tuning and doc updated Jun 23, 2025
@shihaobai shihaobai changed the title deepseek && qwen tp performance tuning and doc updated deepseek && qwen tp performance tuning && doc updated Jun 23, 2025
@shihaobai shihaobai force-pushed the fused_moe_improve branch from 47a243a to 587e0a4 Compare June 24, 2025 11:59
@shihaobai shihaobai changed the title deepseek && qwen tp performance tuning && doc updated deepseek && qwen tp performance tuning Jun 24, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants