Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Deepseek-V2 #4650

Merged
merged 19 commits into from
Jun 28, 2024
7 changes: 7 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,9 @@ def get_hidden_size(self) -> int:
return self.hf_text_config.hidden_size

def get_head_size(self) -> int:
if hasattr(self.hf_text_config, "model_type") and self.hf_text_config.model_type=='deepseek_v2':
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add the head_dim to the huggingface config instead of hard coding this here?

# FlashAttention suports only 32, 64, 128, 256, we need pading 192 to 256
return 256
if hasattr(self.hf_text_config, "head_dim"):
return self.hf_text_config.head_dim
# FIXME(woosuk): This may not be true for all models.
Expand All @@ -262,6 +265,8 @@ def get_total_num_kv_heads(self) -> int:
# NOTE: for falcon, when new_decoder_architecture is True, the
# multi_query flag is ignored and we use n_head_kv for the number of
# KV heads.
if hasattr(self.hf_text_config, "model_type") and self.hf_text_config.model_type=='deepseek_v2':
zwd003 marked this conversation as resolved.
Show resolved Hide resolved
return self.hf_text_config.num_attention_heads
falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
new_decoder_arch_falcon = (
self.hf_config.model_type in falcon_model_types
Expand Down Expand Up @@ -307,6 +312,8 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:

def get_num_attention_heads(self,
parallel_config: "ParallelConfig") -> int:
if hasattr(self.hf_text_config, "model_type") and self.hf_text_config.model_type=='deepseek_v2':
zwd003 marked this conversation as resolved.
Show resolved Hide resolved
return self.hf_text_config.num_attention_heads
return self.hf_text_config.num_attention_heads // \
parallel_config.tensor_parallel_size

Expand Down
56 changes: 35 additions & 21 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,8 @@ def fused_topk(
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
):
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
Expand All @@ -326,27 +328,39 @@ def fused_topk(
dtype=torch.float32)
topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
else:
import vllm._moe_C as moe_kernels

topk_weights = torch.empty(M,
topk,
dtype=torch.float32,
device=hidden_states.device)
topk_ids = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
token_expert_indicies = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
moe_kernels.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(), # TODO(woosuk): Optimize this.
)
del token_expert_indicies # Not used. Will be used in the future.
if num_expert_group == 0:
import vllm._moe_C as moe_kernels

topk_weights = torch.empty(M,
topk,
dtype=torch.float32,
device=hidden_states.device)
topk_ids = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
token_expert_indicies = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
moe_kernels.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(), # TODO(woosuk): Optimize this.
)
del token_expert_indicies # Not used. Will be used in the future.
else:
scores = torch.softmax(gating_output, dim = -1)
Copy link
Collaborator

@pcmoritz pcmoritz May 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since these lines don't really share anything with the rest of the function, to avoid too many if conditions here, it will be best to make a new function

def fused_grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int  = 0,
)

and then call the fused_grouped_topk and fused_experts functions separately like in https://github.com/vllm-project/vllm/pull/4652/files#diff-7f2897118411db2b0196a83729e1ae3dedb5a450b3abcb14020d9c6cdd716ad3R170

num_token = scores.shape[0]
group_scores = scores.view(num_token, num_expert_group, -1).max(dim=-1).values # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = group_mask.unsqueeze(-1).expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)

if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
Expand Down
125 changes: 125 additions & 0 deletions vllm/model_executor/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,116 @@ def forward(
return query.flatten(-2), key.flatten(-2)


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0


class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is extremely similar to YaRNScalingRotaryEmbedding, can you extend that one instead to support mscale?

"""RotaryEmbedding extended with YaRN method.

Credits to Peng et al. github.com/jquesnelle/yarn
"""

def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
scaling_factor: float,
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
beta_fast: float = 32,
beta_slow: float = 1,
mscale: float = 1,
mscale_all_dim: float = 0,
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
# Get n-d magnitude scaling corrected for interpolation.
self.mscale = float(
yarn_get_mscale(self.scaling_factor, float(mscale))
/ yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) * attn_factor)
super().__init__(head_size, rotary_dim, max_position_embeddings, base,
is_neox_style)

def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
self.rotary_dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow,
self.rotary_dim, self.base,
self.max_position_embeddings)
# Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = (1 - _yarn_linear_ramp_mask(
low, high, self.rotary_dim // 2, dtype=torch.float)) * self.extrapolation_factor
inv_freq = inv_freq_interpolation * (
1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
return inv_freq

def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(self.max_position_embeddings * self.scaling_factor,
device="cuda",
dtype=torch.float32)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = (freqs.cos() * self.mscale)
sin = (freqs.sin() * self.mscale)
cache = torch.cat((cos, sin), dim=-1)
print("Cache shape", cache.shape)
return cache


def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""PyTorch-native implementation equivalent to forward()."""
query_rot = query[..., :self.rotary_dim]
key_rot = key[..., :self.rotary_dim]
if self.rotary_dim < self.head_size:
query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:]

self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
positions.device)
cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
if offsets is not None else positions]
cos, sin = cos_sin.chunk(2, dim=-1)
if self.is_neox_style:
# NOTE(woosuk): Here we assume that the positions tensor has the
# shape [batch_size, seq_len].
cos = cos.repeat(1, 1, 2).unsqueeze(-2)
sin = sin.repeat(1, 1, 2).unsqueeze(-2)
else:
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)

rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
query_rot = query_rot * cos + rotate_fn(query_rot) * sin
key_rot = key_rot * cos + rotate_fn(key_rot) * sin

if self.rotary_dim < self.head_size:
query = torch.cat((query_rot, query_pass), dim=-1)
key = torch.cat((key_rot, key_pass), dim=-1)
else:
query = query_rot
key = key_rot
return query, key

_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}


Expand Down Expand Up @@ -507,6 +617,21 @@ def get_rope(
base, is_neox_style,
scaling_factor,
**extra_kwargs)
elif scaling_type == "deepseek_yarn":
original_max_position = rope_scaling[
"original_max_position_embeddings"]
# assert max_position == original_max_position * scaling_factor
extra_kwargs = {
k: v
for k, v in rope_scaling.items()
if k in ("extrapolation_factor", "attn_factor", "beta_fast",
"beta_slow", "mscale", "mscale_all_dim")
}
rotary_emb = DeepseekScalingRotaryEmbedding(head_size, rotary_dim,
original_max_position,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can be confusing to call the get_rope with config.max_position_embedding, while generating the rope cache with config.rope_scaling['original_max_position_embeddings']. It's better to directly pass the right max_position.

base, is_neox_style,
scaling_factor,
**extra_kwargs)
elif scaling_type == "su":
short_factor = rope_scaling["short_factor"]
long_factor = rope_scaling["long_factor"]
Expand Down
3 changes: 3 additions & 0 deletions vllm/model_executor/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
"CohereForCausalLM": ("commandr", "CohereForCausalLM"),
"DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),

"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
"DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),

"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
Expand Down
Loading
Loading