deepseek && qwen tp performance tuning #934
base: main
Conversation
Summary of Changes
Hello @shihaobai, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request focuses on significant performance optimizations for Deepseek and Qwen models, particularly within Mixture of Experts (MoE) and Tensor Parallelism (TP) contexts. It introduces new, highly-tuned Triton kernels for grouped GEMM operations, refines weight fusion and data handling for MoE, and enhances FP8 quantization. Additionally, it includes specialized Rotary Embedding kernels and integrates FlashInfer for improved CUDA architecture support, all aimed at boosting inference speed and efficiency.
Highlights
- MoE Kernel Tuning: New kernel configurations are added for grouped MoE GEMM operations, specifically optimized for NVIDIA H100 and H200 GPUs across various model dimensions and quantization settings.
- MoE Weight Fusion Optimization: The process of fusing MoE expert weights and their scales is made more memory-efficient by pre-allocating tensors and using direct slicing instead of concatenation and flattening.
- Grouped MoE GEMM Data Handling: The `moe_align2` and `grouped_matmul` functions are refactored to improve data access patterns and reduce data movement during grouped GEMM computations.
- SiLU and Mul Kernel Optimization: A new, faster Triton kernel (`_silu_and_mul_kernel_fast`) is introduced for the SiLU and Mul activation function, specifically benefiting smaller batch sizes in MoE layers.
- FP8 Quantization Enhancements: The FP8 input quantization process is made more flexible and efficient, particularly concerning the handling and alignment of scales.
- Rotary Embedding Kernel Split: The Rotary Embedding kernel is split into separate, optimized kernels for Query and Key tensors, allowing for more fine-grained performance tuning (a plain-PyTorch sketch of this split follows the list).
- FlashInfer Integration: Conditional import and usage of FlashInfer are added to dynamically set CUDA architecture flags, improving compatibility and performance with FlashInfer-enabled operations.
- Benchmark and Tuning Updates: New benchmarks are introduced for fused MoE Triton kernels, and existing tuning scripts are updated to support new FP8 quantization parameters and DeepseekV2 large model configurations.
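The Query/Key split can be pictured with a plain PyTorch reference. This is a minimal sketch of rotate-half RoPE applied to Q and K through separate entry points; the function names, tensor shapes, and rotation convention here are illustrative assumptions, not the PR's actual Triton kernels.

```python
import torch

def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the head dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_q(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # q: [num_tokens, num_q_heads, head_dim]; cos/sin: [num_tokens, head_dim]
    return q * cos.unsqueeze(1) + _rotate_half(q) * sin.unsqueeze(1)

def apply_rotary_k(k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Kept as a separate function so the K path can be tuned independently of Q
    # (Q usually has far more heads than K under GQA/MLA-style attention).
    return k * cos.unsqueeze(1) + _rotate_half(k) * sin.unsqueeze(1)
```

Keeping the two paths separate lets each kernel pick launch parameters matched to its own head count instead of compromising on a single configuration.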
Force-pushed from 52dc131 to 6362c4a
Code Review
This pull request focuses on performance tuning for MoE layers, particularly for DeepSeek and Qwen models, with significant enhancements to FP8 quantization and Triton kernel optimizations. Key changes include refactoring MoE data alignment and GEMM kernels, optimizing weight fusing, specializing rotary embedding kernels, and updating quantization utilities. A new benchmark for fused MoE has also been added. The changes are generally aimed at improving efficiency and memory usage.
gate_out_dim, gate_in_dim = self.experts_gate_projs[0].shape
up_out_dim, up_in_dim = self.experts_up_projs[0].shape
assert gate_in_dim == up_in_dim
dtype = self.experts_gate_projs[0].dtype
total_expert_num = self.n_routed_experts

w1 = torch.empty((total_expert_num, gate_out_dim + up_out_dim, gate_in_dim), dtype=dtype, device="cpu")

for i_experts in range(self.n_routed_experts):
    expert_gate_up_proj = torch.cat(
        [self.experts_gate_projs[i_experts], self.experts_up_projs[i_experts]], dim=0
    )
    expert_gate_up_proj = expert_gate_up_proj
    w1_list.append(expert_gate_up_proj)

inter_shape, hidden_size = w1_list[0].shape[0], w1_list[0].shape[1]
w1 = torch._utils._flatten_dense_tensors(w1_list).view(len(w1_list), inter_shape, hidden_size)
    w1[i_experts, 0:gate_out_dim:, :] = self.experts_gate_projs[i_experts]
    w1[i_experts, gate_out_dim:, :] = self.experts_up_projs[i_experts]
The change to pre-allocate `w1` (and similarly `w1_scale` in `_fuse_weight_scale`) and then fill it using slicing is a good optimization. This approach avoids the creation of intermediate lists of tensors and the subsequent overhead of `torch.cat` and `torch._utils._flatten_dense_tensors` (for `w1_list`), which should lead to reduced peak memory usage and potentially faster execution.
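As a standalone illustration of the two fusion strategies, here is a minimal sketch; the expert count and dimensions are made up for the example and do not come from the model config.

```python
import torch

E, G, U, K = 8, 256, 256, 512  # experts, gate rows, up rows, hidden dim (illustrative)
gate_projs = [torch.randn(G, K) for _ in range(E)]
up_projs = [torch.randn(U, K) for _ in range(E)]

# Old pattern: build a list of concatenated per-expert tensors, then flatten and reshape.
# Peak memory briefly holds both the per-expert copies and the flattened buffer.
w1_list = [torch.cat([g, u], dim=0) for g, u in zip(gate_projs, up_projs)]
w1_old = torch._utils._flatten_dense_tensors(w1_list).view(E, G + U, K)

# New pattern: pre-allocate the fused buffer once and copy each expert in place.
w1_new = torch.empty((E, G + U, K), dtype=gate_projs[0].dtype)
for i in range(E):
    w1_new[i, :G, :] = gate_projs[i]
    w1_new[i, G:, :] = up_projs[i]

assert torch.equal(w1_old, w1_new)  # same result, one allocation instead of E + 1
```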
if expert_id == -1:
    return

tile_m_idx = tl.load(mblocks_to_m_index + pid_m)
tile_n_idx = pid_n

# get the gemm size of the current problem
cur_m = tl.load(expert_to_token_num + expert_id, eviction_policy="evict_last")

# do regular gemm here
offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
token_mask = offs_am < cur_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
# token_mask = offs_am < cur_m
a_m_index = tl.load(
    expert_to_token_index + expert_id * expert_to_token_index_stride_0 + offs_am,
    mask=token_mask,
    other=0,
    padded_expert_to_token_index + offs_am,
)
if MUL_ROUTED_WEIGHT:
    a_m_scale = tl.load(
        expert_to_weights_ptr + expert_id * expert_to_weights_stride0 + offs_am,
        mask=token_mask,
        other=0.0,
    )

token_mask = a_m_index != -1
The indexing logic for `a_m_index` and the subsequent `token_mask` derivation has been updated to use the new `padded_expert_to_token_index`.

Previously, `tile_m_idx` was loaded from `mblocks_to_m_index`, and `cur_m` (the expert-specific token count) was used for masking. Now, `offs_am` is computed directly from `pid_m`, and `a_m_index` is loaded from `padded_expert_to_token_index + offs_am`. The `token_mask` is then derived from `a_m_index != -1`.

This simplification relies on `padded_expert_to_token_index` being correctly populated by `moe_align2` (with -1 for padding). Ensure thorough testing, as this is a critical part of the GEMM computation.
a_m_scale = tl.load(
    padded_expert_to_weights + offs_am,
    mask=token_mask,
    other=0.0,
)
accumulator *= a_m_scale[:, None]
@triton.jit
def _silu_and_mul_kernel_fast(
    input_ptr,
    output_ptr,
    stride_input_m,
    stride_input_n,
    stride_output_m,
    stride_output_n,
    size_n,
    BLOCK_N: tl.constexpr,
    NEED_MASK: tl.constexpr,
):
    stride_input_m = tl.cast(stride_input_m, dtype=tl.int64)
    stride_output_m = tl.cast(stride_output_m, dtype=tl.int64)

    cur_batch = tl.program_id(0)
    pid = tl.program_id(1)
    n_offsets = pid * BLOCK_N + tl.arange(0, BLOCK_N)

    up_offsets = cur_batch * stride_input_m + (n_offsets[None, :] + size_n)
    gate_offsets = cur_batch * stride_input_m + n_offsets[None, :]
    res_offsets = cur_batch * stride_output_m + n_offsets[None, :]
    if NEED_MASK:
        mask = n_offsets[None, :] < size_n
    else:
        mask = True

    up = tl.load(
        input_ptr + up_offsets,
        mask=mask,
        other=0.0,
    )
    gate = tl.load(
        input_ptr + gate_offsets,
        mask=mask,
        other=0.0,
    ).to(tl.float32)

    gate = gate / (1 + tl.exp(-gate))
    gate = gate.to(input_ptr.dtype.element_ty)

    tl.store(
        output_ptr + res_offsets,
        up * gate,
        mask=mask,
    )
The new `_silu_and_mul_kernel_fast` kernel is a good addition for optimizing SiLU-and-multiply operations when `size_m` (the number of rows/tokens) is relatively small (<= 4096).

The grid launch `(size_m, triton.cdiv(size_n, BLOCK_N))` with `cur_batch = tl.program_id(0)` effectively assigns each row to a separate Triton program instance. This row-wise processing can be more efficient than a blocked approach for smaller `M`.
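A hedged sketch of how such a kernel would typically be launched from the host, matching the signature shown above; the wrapper name, the `[gate | up]` packing assumption, and the fixed `BLOCK_N` are illustrative, not necessarily what the PR does.

```python
import torch
import triton

def silu_and_mul_fast(input: torch.Tensor) -> torch.Tensor:
    # Assumes the input packs [gate | up] along the last dim: shape [size_m, 2 * size_n].
    size_m, size_n2 = input.shape
    size_n = size_n2 // 2
    output = torch.empty((size_m, size_n), dtype=input.dtype, device=input.device)

    BLOCK_N = 1024  # illustrative; real code would pick this based on size_n
    need_mask = size_n % BLOCK_N != 0
    grid = (size_m, triton.cdiv(size_n, BLOCK_N))  # one program per (row, column block)
    _silu_and_mul_kernel_fast[grid](
        input,
        output,
        input.stride(0),
        input.stride(1),
        output.stride(0),
        output.stride(1),
        size_n,
        BLOCK_N=BLOCK_N,
        NEED_MASK=need_mask,
    )
    return output
```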
qinput_tensor = self.cache_manager.alloc_tensor(
    (m, k), qweight.dtype, device=qweight.device, is_graph_out=False
)
per_token_group_quant_fp8(input_tensor, self.block_size, qinput_tensor, input_scale)
input_scale = tma_align_input_scale(input_scale)
_, input_scale = per_token_group_quant_fp8(
    input_tensor, self.block_size, qinput_tensor, column_major_scales=True, scale_tma_aligned=True
)
The `per_token_group_quant_fp8` function is now called with `column_major_scales=True` and `scale_tma_aligned=True`. The returned `input_scale` is used directly, and the previous explicit call to `tma_align_input_scale` is removed.

This is a good refactoring, as it encapsulates the TMA alignment logic for scales within the `per_token_group_quant_fp8` function itself.
    x: torch.Tensor,
    group_size: int,
    x_q: torch.Tensor,
    x_s: torch.Tensor,
    x_s: torch.Tensor = None,
    eps: float = 1e-10,
    dtype: torch.dtype = torch.float8_e4m3fn,
    column_major_scales: bool = False,
    scale_tma_aligned: bool = False,
    alloc_func: Callable = torch.empty,
):
    # Adapted from
    # https://github.com/sgl-project/sglang/blob/7e257cd666c0d639626487987ea8e590da1e9395/python/sglang/srt/layers/quantization/fp8_kernel.py#L290
    if HAS_SGL_KERNEL:
        finfo = torch.finfo(dtype)
        fp8_max, fp8_min = finfo.max, finfo.min
        if column_major_scales:
            if scale_tma_aligned:
                # aligned to 4 * sizeof(float)
                aligned_size = (x.shape[-2] + 3) // 4 * 4
                x_s = alloc_func(
                    x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
                    device=x.device,
                    dtype=torch.float32,
                ).permute(-1, -2)[: x.shape[-2], :]
            else:
                x_s = alloc_func(
                    (x.shape[-1] // group_size,) + x.shape[:-1],
                    device=x.device,
                    dtype=torch.float32,
                ).permute(-1, -2)
        else:
            if x_s is None:
                x_s = alloc_func(
                    x.shape[:-1] + (x.shape[-1] // group_size,),
                    device=x.device,
                    dtype=torch.float32,
                )
        sgl_ops.sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, 1e-10, fp8_min, fp8_max)
    else:
        lightllm_per_token_group_quant_fp8(x, group_size, x_q, x_s, eps=1e-10, dtype=torch.float8_e4m3fn)

    return x_q, x_s
The `per_token_group_quant_fp8` function has been significantly enhanced:

- `x_s` (the scales tensor) is now optional and can be allocated internally.
- New parameters `column_major_scales` and `scale_tma_aligned` allow for more control over the scale tensor's layout and alignment.
- The `alloc_func` parameter adds flexibility for tensor allocation.
- The function now returns `(x_q, x_s)`.

These changes make the function more versatile and self-contained.
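A sketch of the resulting call pattern and scale shape, assuming the SGL kernel path is available and using made-up sizes (M = 10, K = 512, group size 128); this just traces the allocation logic shown in the diff above.

```python
import torch

M, K, group_size = 10, 512, 128
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
x_q = torch.empty(M, K, dtype=torch.float8_e4m3fn, device="cuda")

# x_s is allocated internally; with column_major_scales + scale_tma_aligned the
# scale buffer's row count is padded to a multiple of 4 floats (here 10 -> 12),
# then sliced back to a [M, K // group_size] view.
x_q, x_s = per_token_group_quant_fp8(
    x, group_size, x_q, column_major_scales=True, scale_tma_aligned=True
)
print(x_s.shape)  # torch.Size([10, 4]), viewed over a TMA-aligned 12-row buffer
```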
test/kernel/fuse_moe_tuning_fp8.py (Outdated)
w1_scale = w2_scale = None

if use_fp8_w8a8:
    init_dtype = dtype
    w1 = torch.randn(expert_num, 2 * n, k, dtype=init_dtype).cuda()
    w2 = torch.randn(expert_num, k, 2 * n // 2, dtype=init_dtype).cuda()
    w1 = w1.to(torch.float8_e4m3fn)
    w2 = w2.to(torch.float8_e4m3fn)

    if block_shape is None:
        w1_scale = torch.randn(expert_num, dtype=torch.float32).cuda()
        w2_scale = torch.randn(expert_num, dtype=torch.float32).cuda()
    else:
        block_n, block_k = block_shape[0], block_shape[1]
        n_tiles_w1 = (2 * n + block_n - 1) // block_n
        n_tiles_w2 = (k + block_n - 1) // block_n
        k_tiles_w1 = (k + block_k - 1) // block_k
        k_tiles_w2 = (2 * n // 2 + block_k - 1) // block_k
        w1_scale = torch.rand((expert_num, n_tiles_w1, k_tiles_w1), dtype=torch.float32).cuda()
        w2_scale = torch.rand((expert_num, n_tiles_w2, k_tiles_w2), dtype=torch.float32).cuda()
else:
    w1 = torch.randn(expert_num, 2 * n, k, dtype=dtype).cuda()
    w2 = torch.randn(expert_num, k, 2 * n // 2, dtype=dtype).cuda()
The weight and scale initialization logic for FP8 (`use_fp8_w8a8=True`) has been significantly revised. Weights (`w1`, `w2`) are now directly converted to `torch.float8_e4m3fn`, and their scales (`w1_scale`, `w2_scale`) are initialized based on `block_shape` (supporting per-expert or block-wise scales).

The explicit call to `quantize_moe` has been removed. This implies that the `grouped_matmul` kernel is now expected to consume FP8 weights and their corresponding scales directly.
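For concreteness, here is the block-wise scale tile arithmetic from the snippet above worked through with assumed sizes; the dimensions are invented purely for illustration.

```python
# Assumed example sizes: 2n = 4096 (fused gate+up rows), k = 7168, block_shape = [128, 128]
two_n, k = 4096, 7168
block_n, block_k = 128, 128

n_tiles_w1 = (two_n + block_n - 1) // block_n        # ceil(4096 / 128) = 32
k_tiles_w1 = (k + block_k - 1) // block_k            # ceil(7168 / 128) = 56
n_tiles_w2 = (k + block_n - 1) // block_n            # ceil(7168 / 128) = 56
k_tiles_w2 = (two_n // 2 + block_k - 1) // block_k   # ceil(2048 / 128) = 16

# So w1_scale has shape [expert_num, 32, 56] and w2_scale has shape [expert_num, 56, 16].
```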
block_shape = getattr(model_config, "block_shape", None)
block_shape = [128, 128]
In the `benchmark` function, `block_shape` is first potentially retrieved from `model_config`, but it is then immediately overwritten by `block_shape = [128, 128]`.

If the intention is to use a dynamic `block_shape` from the model configuration, the overwrite should be removed or made conditional. If `[128, 128]` is a specific override for this benchmark, it might be clearer to document this or use a different variable name for the override.
block_shape = getattr(model_config, "block_shape", None)
block_shape = [128, 128]
dtype = model_config["dtype"]
# block_shape = model_config.get("block_shape", None)  # Get from config if needed
# For this specific benchmark, we might be testing a fixed block_shape:
block_shape_override = [128, 128]  # Or use model_config["block_shape"] if available and desired
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
Force-pushed from 47a243a to 587e0a4
No description provided.