[Misc] Fixes and Optimizations for DeepEP + DeepGEMM combination. #19298
@@ -80,11 +80,13 @@ def workspace_shapes(
         topk: int,
         num_experts: int,
     ) -> tuple[int, int, torch.dtype]:
         block_m = self.block_shape[0]
         M_sum = (M * topk) + num_experts * (block_m - 1)
         M_sum = round_up(M_sum, block_m)
         workspace1 = M_sum * max(N * 2, K)
-        workspace2 = M_sum * N
+        workspace2 = M_sum * max(N, K)
         return (workspace1, workspace2, a.dtype)

     def apply(
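The workspace arithmetic above pads M * topk by up to block_m - 1 rows per expert before rounding to the DeepGEMM block size, and workspace2 now has to be large enough to also hold a (M_sum, K) buffer (see mm2_out in the next hunk). A small standalone sketch of the sizing, with illustrative values that are not taken from the PR:

```python
# Standalone sketch of the workspace sizing. M, topk, num_experts, block_m,
# N and K below are assumed example values, not numbers from the PR.
M, topk, num_experts, block_m = 128, 8, 32, 128
N, K = 4096, 7168


def round_up(x: int, y: int) -> int:
    return ((x + y - 1) // y) * y


# Each expert's slice of the permuted activations is padded to a multiple of
# block_m, so the worst case adds (block_m - 1) rows per expert.
M_sum = round_up(M * topk + num_experts * (block_m - 1), block_m)  # 5120

workspace1 = M_sum * max(N * 2, K)
workspace2 = M_sum * max(N, K)  # was M_sum * N; now also fits a (M_sum, K) buffer
print(M_sum, workspace1, workspace2)
```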
@@ -135,26 +137,31 @@ def apply(

         # Note: M_sum is different than the pre-permuted shape of a1q.
         M_sum = a1q.size(0)
-        workspace1 = _resize_cache(workspace13, (M_sum, N))
-        workspace2 = _resize_cache(workspace2, (M_sum, N // 2))
-        workspace3 = _resize_cache(workspace13, (M_sum, K))
+        mm1_out = _resize_cache(workspace13, (M_sum, N))
+        act_out = _resize_cache(workspace2, (M_sum, N // 2))
+        quant_out = _resize_cache(workspace13.view(dtype=torch.float8_e4m3fn),
+                                  (M_sum, N // 2))
+        mm2_out = _resize_cache(workspace2, (M_sum, K))
+        out = _resize_cache(workspace13, (inv_perm.size(0), K))

         dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-            (a1q, a1q_scale), (w1, w1_scale), workspace1, expert_ids)
+            (a1q, a1q_scale), (w1, w1_scale), mm1_out, expert_ids)

-        self.activation(activation, workspace2, workspace1.view(-1, N))
+        self.activation(activation, act_out, mm1_out.view(-1, N))

         a2q_scale: Optional[torch.Tensor] = None
-        a2q, a2q_scale = per_token_group_quant_fp8(workspace2,
+        a2q, a2q_scale = per_token_group_quant_fp8(act_out,
                                                    self.block_shape[1],
-                                                   column_major_scales=True)
+                                                   column_major_scales=True,
+                                                   out_q=quant_out)

         dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-            (a2q, a2q_scale), (w2, w2_scale), workspace3, expert_ids)
+            (a2q, a2q_scale), (w2, w2_scale), mm2_out, expert_ids)

-        workspace3 = workspace3[inv_perm, ...]
+        torch.index_select(mm2_out, 0, inv_perm, out=out)

-        return workspace3
+        return out


 def deep_gemm_moe_fp8(

Comment on the torch.index_select change: memory optimization to prevent inv_perm from making a brand-new tensor.
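The last change in this hunk is what the comment above refers to: torch.index_select with out= writes the un-permuted rows into a slice of the existing workspace, whereas advanced indexing allocates a new tensor. A minimal standalone sketch of the difference, with assumed sizes:

```python
# Minimal standalone sketch: un-permuting rows with and without a fresh
# allocation. Sizes are illustrative only.
import torch

M_sum, K = 5120, 7168
mm2_out = torch.randn(M_sum, K)
inv_perm = torch.randperm(M_sum)
out_buf = torch.empty(M_sum, K)  # e.g. a buffer carved out of a reused workspace

copied = mm2_out[inv_perm, ...]                        # allocates a new (M_sum, K) tensor
torch.index_select(mm2_out, 0, inv_perm, out=out_buf)  # writes into the existing buffer

assert torch.equal(copied, out_buf)
```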
@@ -5,6 +5,7 @@
 import torch

 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm import _custom_ops as ops
 from vllm.model_executor.layers.fused_moe.utils import (
     moe_kernel_quantize_input)
@@ -193,20 +194,23 @@ def _apply_weights_and_reduce(self, num_tokens: int,
                                   apply_router_weight_on_input: bool,
                                   output_dtype: torch.dtype):

+        hidden_dim = fused_expert_output.size(-1)
         if fused_expert_output.ndim == 2:
-            hidden_dim = fused_expert_output.size(-1)
             fused_expert_output = fused_expert_output.view(
                 num_tokens, -1, hidden_dim)

         if not apply_router_weight_on_input:
             # The DeepEP combine kernels don't do the topk weight
             # multiplication. We multiply the weights locally.
-            fused_expert_output = fused_expert_output.to(torch.float32)
-            fused_expert_output = fused_expert_output * topk_weights.view(
-                fused_expert_output.size(0), -1, 1)
-            fused_expert_output = fused_expert_output.to(output_dtype)
+            m_x_topk = fused_expert_output.size(0)
+            fused_expert_output.mul_(topk_weights.view(m_x_topk, -1, 1))

-        return fused_expert_output.sum(dim=1).to(output_dtype)
+        out = torch.empty((num_tokens, hidden_dim),
+                          device=fused_expert_output.device,
+                          dtype=output_dtype)
+        ops.moe_sum(fused_expert_output, out)
+
+        return out

     def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
                  topk_weights: torch.Tensor, topk_ids: torch.Tensor,

Comment: The in-place multiplication …

Comment: topk_weights is float32 and fused_expert_output is bfloat16; the multiplication relies on type promotion/demotion.
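As the exchange above notes, the new in-place multiply leans on PyTorch type promotion: the float32 topk_weights are promoted for the multiply and the result is cast back into the bfloat16 in-place tensor. A small standalone check of that behaviour, with assumed shapes:

```python
# Standalone check of the in-place multiply discussed above. Shapes are
# illustrative only.
import torch

num_tokens, topk, hidden = 4, 8, 16
fused_expert_output = torch.randn(num_tokens, topk, hidden, dtype=torch.bfloat16)
topk_weights = torch.rand(num_tokens, topk, dtype=torch.float32)

# The multiply is computed in the promoted dtype and the result is written
# back into the bfloat16 tensor (a float -> float downcast, which in-place
# ops allow).
fused_expert_output.mul_(topk_weights.view(num_tokens, -1, 1))
assert fused_expert_output.dtype == torch.bfloat16
```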
@@ -234,8 +234,13 @@ def _per_token_group_quant_fp8(
     row = g_id // groups_per_row
     row_g_id = g_id % groups_per_row

-    y_ptr += (row * y_row_stride) + (row_g_id * group_size)
-    y_q_ptr += g_id * group_size
+    # Ensure offset calculations use int64 to prevent overflow
+    y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) *
+                                                        group_size)
+    y_ptr += y_ptr_offset
+
+    y_q_ptr_offset = g_id.to(tl.int64) * group_size
+    y_q_ptr += y_q_ptr_offset
     y_s_ptr += g_id

     cols = tl.arange(0, BLOCK)  # N <= BLOCK
@@ -282,15 +287,23 @@ def _per_token_group_quant_fp8_colmajor(
     row = g_id // groups_per_row
     row_g_id = g_id % groups_per_row

-    y_ptr += (row * y_row_stride) + (row_g_id * group_size)
-    y_q_ptr += g_id * group_size
+    # Ensure offset calculations use int64 to prevent overflow
+    y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) *
+                                                        group_size)
+    y_ptr += y_ptr_offset
+
+    y_q_ptr_offset = g_id.to(tl.int64) * group_size
+    y_q_ptr += y_q_ptr_offset

     # Convert g_id the flattened block coordinate to 2D so we can index
     # into the output y_scales matrix
     blocks_per_row = y_num_columns // group_size
     scale_col = g_id % blocks_per_row
     scale_row = g_id // blocks_per_row
-    y_s_ptr += scale_col * y_s_col_stride + scale_row
+    # Ensure offset calculation uses int64 for y_s_ptr
+    y_s_ptr_offset = (scale_col.to(tl.int64) * y_s_col_stride) + scale_row.to(
+        tl.int64)
+    y_s_ptr += y_s_ptr_offset

     cols = tl.arange(0, BLOCK)  # group_size <= BLOCK
     mask = cols < group_size
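Both kernels get the same fix: the pointer offsets are computed in int64 because the flat element offset row * y_row_stride can exceed the int32 range for large inputs. A quick back-of-the-envelope check with assumed sizes:

```python
# Back-of-the-envelope check of when an int32 offset would overflow.
# num_rows and row_stride are assumed, roughly a large post-permutation token
# count and a large hidden size.
num_rows = 400_000
row_stride = 7_168  # elements per row

flat_offset = (num_rows - 1) * row_stride
print(flat_offset)              # 2867192832
print(flat_offset > 2**31 - 1)  # True: would wrap around in int32 arithmetic
```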
@@ -311,6 +324,7 @@ def per_token_group_quant_fp8(
     eps: float = 1e-10,
     dtype: Optional[torch.dtype] = None,
     column_major_scales: bool = False,
+    out_q: Optional[torch.Tensor] = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
     It converts the tensor values into signed float8 values and returns the

@@ -321,6 +335,8 @@
         eps: The minimum to avoid dividing zero.
         dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn`
             is supported for now.
+        column_major_scales: Outputs scales in column major.
+        out_q: Optional output tensor. If not provided, function will create.
     Returns:
         tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
             scaling factor for quantization.

@@ -335,7 +351,11 @@
     fp8_min = finfo.min
     fp8_max = finfo.max

-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    assert out_q is None or out_q.shape == x.shape
+    x_q = out_q
+    if x_q is None:
+        x_q = torch.empty_like(x, device=x.device, dtype=dtype)

     M = x.numel() // group_size
     N = group_size
     if column_major_scales:

Comment on lines +354 to +357: The addition of the …
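A hedged usage sketch of the new out_q argument: the quantized activations land in a caller-provided fp8 buffer instead of a fresh allocation. The import path and shapes below are assumptions for illustration, not taken from the PR:

```python
# Hedged usage sketch of out_q. The module path and the shapes are assumed.
import torch

from vllm.model_executor.layers.quantization.utils.fp8_utils import (  # assumed path
    per_token_group_quant_fp8)

M_sum, N_half, group_size = 5120, 2048, 128
act_out = torch.randn(M_sum, N_half, device="cuda", dtype=torch.bfloat16)
quant_out = torch.empty(M_sum, N_half, device="cuda",
                        dtype=torch.float8_e4m3fn)  # e.g. a reused workspace slice

a2q, a2q_scale = per_token_group_quant_fp8(act_out,
                                           group_size,
                                           column_major_scales=True,
                                           out_q=quant_out)
assert a2q.data_ptr() == quant_out.data_ptr()  # quantized output reuses the buffer
```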
@@ -5,13 +5,15 @@
 import gc
 import time
 import weakref
+from contextlib import contextmanager
 from typing import TYPE_CHECKING, Any, Optional, Union

 import numpy as np
 import torch
 import torch.distributed
 import torch.nn as nn

+import vllm.envs as envs
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadataBuilder)
@@ -1724,6 +1726,35 @@ def _get_prompt_logprobs_dict(

         return prompt_logprobs_dict

+    @contextmanager
+    def maybe_randomize_inputs(self, input_ids: torch.Tensor):
+        """
+        Randomize input_ids if VLLM_RANDOMIZE_DP_DUMMY_INPUTS is set.
+        This is to help balance expert-selection
+         - during profile_run
+         - during DP rank dummy run
+        """
+        dp_size = self.vllm_config.parallel_config.data_parallel_size
+        randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1
+        if not randomize_inputs:
+            yield
+        else:
+            import functools
+
+            @functools.cache
+            def rand_input_ids() -> torch.Tensor:
+                return torch.randint_like(
+                    self.input_ids,
+                    low=0,
+                    high=self.model_config.get_vocab_size(),
+                    dtype=input_ids.dtype)
+
+            logger.debug("Randomizing dummy data for DP Rank")
+            input_ids.copy_(rand_input_ids()[:input_ids.size(0)],
+                            non_blocking=True)
+            yield
+            input_ids.fill_(0)
+
     @torch.inference_mode()
     def _dummy_run(
         self,

Comment on lines +1753 to +1756: The Restoring …

Comment: This could be optimized. We don't have to fill the input_ids and then set them to zeros every time. For eager-mode runs (i.e. batch size > 512) we could just use the rand tensor in place of input_ids. I plan to do this in a follow-up PR.
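A standalone sketch of the randomize-then-restore pattern that maybe_randomize_inputs implements, detached from the model runner. The buffer size, vocab size, and helper names are assumptions for illustration:

```python
# Standalone sketch of the randomize-then-restore pattern. Names and sizes
# are illustrative, not vLLM internals.
import functools
from contextlib import contextmanager

import torch

input_ids = torch.zeros(8192, dtype=torch.int32)  # persistent dummy-run buffer
vocab_size = 32_000


@functools.cache
def rand_input_ids() -> torch.Tensor:
    # Generated once and cached; later dummy runs reuse the same tensor.
    return torch.randint_like(input_ids, low=0, high=vocab_size)


@contextmanager
def randomize_inputs(ids: torch.Tensor):
    ids.copy_(rand_input_ids()[:ids.size(0)])  # fill the active slice with random ids
    try:
        yield
    finally:
        ids.fill_(0)  # restore the zeroed state after the dummy run


with randomize_inputs(input_ids[:1024]):
    pass  # the dummy forward pass would run here

assert int(input_ids.sum()) == 0
```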
@@ -1804,7 +1835,7 @@ def _dummy_run(
             intermediate_tensors = self.sync_and_slice_intermediate_tensors(
                 num_tokens, None, False)

-        with set_forward_context(
+        with self.maybe_randomize_inputs(input_ids), set_forward_context(
                 attn_metadata,
                 self.vllm_config,
                 num_tokens=num_tokens,

Comment: It is incorrect to do this when we are doing profile runs -- during profile runs, we do want the system to be stress tested (i.e. all tokens reaching the same set of GPU ranks). #19168 should fix the OOM; then we can remove this logic for the profile-run case.
Comment: Is DP important here? I think you would want this for any EP case, so maybe just VLLM_RANDOMIZE_DUMMY_INPUTS.

Comment: Investigated it for a bit and I think it is better to call out DP in the name. It is only in the context of DP that some DP ranks execute dummy runs so we can synchronize with the DP ranks that run the model with actual tokens. The other way we could do expert parallel is with DP=1 and TP > 1; with this, all the ranks run with actual data (the input data is replicated across all ranks). Also, I have this statement in code: randomize_inputs = envs.VLLM_RANDOMIZE_DP_DUMMY_INPUTS and dp_size > 1. But I see what you are saying: we could rename VLLM_RANDOMIZE_DP_DUMMY_INPUTS to VLLM_RANDOMIZE_DUMMY_INPUTS, set randomize_inputs = envs.VLLM_RANDOMIZE_DUMMY_INPUTS, and randomize whenever the env var is set. Let's do that when more use cases for randomizing dummy runs come up? What do you think?