-
-
Notifications
You must be signed in to change notification settings - Fork 5.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support Deepseek-V2 #4650
Support Deepseek-V2 #4650
Changes from 1 commit
5688e58
2609d43
2bcfba8
36425b0
28199d8
434d757
ce3a80a
59b6353
1ce0c2a
bf98862
ca9c0ee
4cf44a5
0746b4f
2443f27
44f087c
df65a69
1d90229
e06d0d2
703e6a3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -322,6 +322,8 @@ def fused_moe( | |
w2_scale: Optional[torch.Tensor] = None, | ||
a1_scale: Optional[torch.Tensor] = None, | ||
a2_scale: Optional[torch.Tensor] = None, | ||
num_expert_group: int = 0, | ||
topk_group: int = 0, | ||
) -> torch.Tensor: | ||
""" | ||
This function computes a Mixture of Experts (MoE) layer using two sets of | ||
|
@@ -362,35 +364,46 @@ def fused_moe( | |
] | ||
M, _ = hidden_states.shape | ||
E, N, _ = w1.shape | ||
|
||
if is_hip(): | ||
# The MoE kernels are not yet supported on ROCm. | ||
routing_weights = torch.softmax(gating_output, | ||
dim=-1, | ||
dtype=torch.float32) | ||
topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1) | ||
else: | ||
import vllm._moe_C as moe_kernels | ||
|
||
topk_weights = torch.empty(M, | ||
topk, | ||
dtype=torch.float32, | ||
device=hidden_states.device) | ||
topk_ids = torch.empty(M, | ||
topk, | ||
dtype=torch.int32, | ||
device=hidden_states.device) | ||
token_expert_indicies = torch.empty(M, | ||
topk, | ||
dtype=torch.int32, | ||
device=hidden_states.device) | ||
moe_kernels.topk_softmax( | ||
topk_weights, | ||
topk_ids, | ||
token_expert_indicies, | ||
gating_output.float(), # TODO(woosuk): Optimize this. | ||
) | ||
del token_expert_indicies # Not used. Will be used in the future. | ||
if num_expert_group == 0: | ||
import vllm._moe_C as moe_kernels | ||
|
||
topk_weights = torch.empty(M, | ||
topk, | ||
dtype=torch.float32, | ||
device=hidden_states.device) | ||
topk_ids = torch.empty(M, | ||
topk, | ||
dtype=torch.int32, | ||
device=hidden_states.device) | ||
token_expert_indicies = torch.empty(M, | ||
topk, | ||
dtype=torch.int32, | ||
device=hidden_states.device) | ||
moe_kernels.topk_softmax( | ||
topk_weights, | ||
topk_ids, | ||
token_expert_indicies, | ||
gating_output.float(), # TODO(woosuk): Optimize this. | ||
) | ||
del token_expert_indicies # Not used. Will be used in the future. | ||
else: | ||
scores = torch.softmax(gating_output, dim = -1) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Since these lines don't really share anything with the rest of the function, and to avoid adding too many `if` conditions here, it would be best to move them into a new function
and then call the |
||
num_token = scores.shape[0] | ||
group_scores = scores.view(num_token, num_expert_group, -1).max(dim=-1).values # [n, n_group] | ||
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1] # [n, top_k_group] | ||
group_mask = torch.zeros_like(group_scores) # [n, n_group] | ||
group_mask.scatter_(1, group_idx, 1) # [n, n_group] | ||
score_mask = group_mask.unsqueeze(-1).expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e] | ||
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] | ||
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False) | ||
|
||
if renormalize: | ||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -452,6 +452,116 @@ def forward( | |
return query.flatten(-2), key.flatten(-2) | ||
|
||
|
||
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the YaRN magnitude-scaling factor for a given context scale.

    Args:
        scale: Context-length scaling factor.
        mscale: Coefficient applied to the logarithmic correction term.

    Returns:
        ``1.0`` when ``scale <= 1``; otherwise
        ``0.1 * mscale * math.log(scale) + 1.0``.
    """
    # No correction is applied unless the context is actually being stretched.
    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0
|
||
|
||
class DeepseekScalingRotaryEmbedding(RotaryEmbedding): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is extremely similar to |
||
"""RotaryEmbedding extended with YaRN method. | ||
|
||
Credits to Peng et al. github.com/jquesnelle/yarn | ||
""" | ||
|
||
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: int,
    is_neox_style: bool,
    scaling_factor: float,
    *,
    extrapolation_factor: float = 1,
    attn_factor: float = 1,
    beta_fast: float = 32,
    beta_slow: float = 1,
    mscale: float = 1,
    mscale_all_dim: float = 0,
) -> None:
    """Store the YaRN hyper-parameters and initialise the base embedding.

    Args:
        head_size: Forwarded unchanged to the base class initialiser.
        rotary_dim: Forwarded unchanged to the base class initialiser.
        max_position_embeddings: Forwarded unchanged to the base class
            initialiser.
        base: Forwarded unchanged to the base class initialiser.
        is_neox_style: Forwarded unchanged to the base class initialiser.
        scaling_factor: Context-length scaling factor; also the ``scale``
            argument of both ``yarn_get_mscale`` calls below.
        extrapolation_factor: Multiplier stored for the inverse-frequency
            mask computation.
        attn_factor: Extra multiplier folded into ``self.mscale``.
        beta_fast: Stored for the correction-range lookup.
        beta_slow: Stored for the correction-range lookup.
        mscale: Coefficient for the numerator of the magnitude ratio.
        mscale_all_dim: Coefficient for the denominator of the magnitude
            ratio.
    """
    # Every attribute is assigned before the base initialiser runs.
    # NOTE(review): presumably the base class builds the cos/sin cache
    # during __init__ and therefore needs these set first -- confirm.
    self.scaling_factor = scaling_factor
    self.extrapolation_factor = extrapolation_factor
    self.attn_factor = attn_factor
    self.beta_fast = beta_fast
    self.beta_slow = beta_slow

    # Get n-d magnitude scaling corrected for interpolation: the ratio of
    # the two mscale terms, multiplied by the attention factor.
    mscale_numerator = yarn_get_mscale(self.scaling_factor, float(mscale))
    mscale_denominator = yarn_get_mscale(self.scaling_factor,
                                         float(mscale_all_dim))
    self.mscale = float(mscale_numerator / mscale_denominator * attn_factor)

    super().__init__(head_size, rotary_dim, max_position_embeddings, base,
                     is_neox_style)
|
||
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
    """Blend interpolated and extrapolated inverse rotary frequencies.

    Args:
        scaling_factor: Divisor applied to the frequencies on the
            interpolation branch.

    Returns:
        A tensor computed as
        ``interpolation * (1 - mask) + extrapolation * mask``, where
        ``mask`` is one minus the linear ramp returned by
        ``_yarn_linear_ramp_mask``, scaled by ``self.extrapolation_factor``.
    """
    # NOTE(review): the exponent tensor is created on "cuda" explicitly,
    # while no device is passed to _yarn_linear_ramp_mask -- confirm both
    # end up on the same device.
    exponents = torch.arange(
        0, self.rotary_dim, 2, dtype=torch.float,
        device="cuda") / self.rotary_dim
    pos_freqs = self.base**exponents
    extrapolated = 1.0 / pos_freqs
    interpolated = 1.0 / (scaling_factor * pos_freqs)

    low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
                                            self.rotary_dim, self.base,
                                            self.max_position_embeddings)
    # Get n-d rotational scaling corrected for extrapolation
    ramp = _yarn_linear_ramp_mask(low,
                                  high,
                                  self.rotary_dim // 2,
                                  dtype=torch.float)
    extrapolation_weight = (1 - ramp) * self.extrapolation_factor
    return (interpolated * (1 - extrapolation_weight) +
            extrapolated * extrapolation_weight)
|
||
def _compute_cos_sin_cache(self) -> torch.Tensor:
    """Build the concatenated cos/sin lookup table for all positions.

    Positions run over
    ``torch.arange(self.max_position_embeddings * self.scaling_factor)``.
    Each position is multiplied with every inverse frequency from
    ``self._compute_inv_freq`` (an outer product), and the cosine and sine
    of the result are each scaled by ``self.mscale``.

    Returns:
        A 2-D tensor whose last dimension holds the cosine half followed
        by the sine half.
    """
    inv_freq = self._compute_inv_freq(self.scaling_factor)
    t = torch.arange(self.max_position_embeddings * self.scaling_factor,
                     device="cuda",
                     dtype=torch.float32)
    # Outer product: one row per position, one column per frequency.
    freqs = torch.einsum("i,j -> ij", t, inv_freq)
    cos = freqs.cos() * self.mscale
    sin = freqs.sin() * self.mscale
    # The leftover debug print of the cache shape was removed; library code
    # should not write to stdout whenever the cache is built.
    return torch.cat((cos, sin), dim=-1)
|
||
|
||
def forward(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply the rotary embedding to ``query`` and ``key`` in PyTorch.

    Only the first ``self.rotary_dim`` features of the last dimension are
    rotated. When ``self.rotary_dim < self.head_size`` the remaining
    features are passed through and concatenated back unchanged.

    Args:
        positions: Indices into ``self.cos_sin_cache``.
        query: Query tensor; rotated along its last dimension.
        key: Key tensor; rotated along its last dimension.
        offsets: Optional tensor added to ``positions`` before the lookup.

    Returns:
        The ``(query, key)`` pair after rotation.
    """
    rot_dim = self.rotary_dim
    has_pass_through = rot_dim < self.head_size

    q_rot = query[..., :rot_dim]
    k_rot = key[..., :rot_dim]
    if has_pass_through:
        q_pass = query[..., rot_dim:]
        k_pass = key[..., rot_dim:]

    # Keep the cache on the same device as the incoming positions.
    self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
        positions.device)
    if offsets is not None:
        lookup = torch.add(positions, offsets)
    else:
        lookup = positions
    cos, sin = self.cos_sin_cache[lookup].chunk(2, dim=-1)

    if self.is_neox_style:
        # NOTE(woosuk): Here we assume that the positions tensor has the
        # shape [batch_size, seq_len].
        cos = cos.repeat(1, 1, 2).unsqueeze(-2)
        sin = sin.repeat(1, 1, 2).unsqueeze(-2)
        rotate_fn = _rotate_neox
    else:
        cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
        sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
        rotate_fn = _rotate_gptj

    q_rot = q_rot * cos + rotate_fn(q_rot) * sin
    k_rot = k_rot * cos + rotate_fn(k_rot) * sin

    if not has_pass_through:
        return q_rot, k_rot
    # Re-attach the features that were not rotated.
    query = torch.cat((q_rot, q_pass), dim=-1)
    key = torch.cat((k_rot, k_pass), dim=-1)
    return query, key
|
||
_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} | ||
|
||
|
||
|
@@ -506,6 +616,21 @@ def get_rope( | |
base, is_neox_style, | ||
scaling_factor, | ||
**extra_kwargs) | ||
elif scaling_type == "deepseek_yarn": | ||
original_max_position = rope_scaling[ | ||
"original_max_position_embeddings"] | ||
# assert max_position == original_max_position * scaling_factor | ||
extra_kwargs = { | ||
k: v | ||
for k, v in rope_scaling.items() | ||
if k in ("extrapolation_factor", "attn_factor", "beta_fast", | ||
"beta_slow", "mscale", "mscale_all_dim") | ||
} | ||
rotary_emb = DeepseekScalingRotaryEmbedding(head_size, rotary_dim, | ||
original_max_position, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It can be confusing to call the |
||
base, is_neox_style, | ||
scaling_factor, | ||
**extra_kwargs) | ||
elif scaling_type == "su": | ||
short_factor = rope_scaling["short_factor"] | ||
long_factor = rope_scaling["long_factor"] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add `head_dim` to the Hugging Face config instead of hard-coding it here?