From f49777ba62b4926d0f8c100ab06edb03c5c10098 Mon Sep 17 00:00:00 2001 From: Simon Mo Date: Thu, 26 Dec 2024 16:09:44 -0800 Subject: [PATCH] Deepseek v3 (#11502) Signed-off-by: mgoin Co-authored-by: mgoin Co-authored-by: robertgshaw2-neuralmagic --- csrc/moe/moe_align_sum_kernels.cu | 160 ++++- vllm/config.py | 11 +- .../layers/fused_moe/fused_moe.py | 17 +- vllm/model_executor/layers/fused_moe/layer.py | 102 ++- .../model_executor/layers/quantization/fp8.py | 7 +- vllm/model_executor/models/deepseek_v3.py | 650 ++++++++++++++++++ vllm/model_executor/models/registry.py | 1 + 7 files changed, 887 insertions(+), 61 deletions(-) create mode 100644 vllm/model_executor/models/deepseek_v3.py diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index fff7ce34c838a..24341d63fb1f8 100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -113,6 +113,92 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, } } +// TODO(simon): this is temporarily adapted from +// https://github.com/sgl-project/sglang/commit/31548116a8dc8c6df7e146e0587335a59fc5b9d7 +// we did this to unblock Deepseek V3 but there should be a better +// implementation to manage shared memory. +template +__global__ void moe_align_block_size_global_mem_kernel( + scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids, + int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts, + int32_t block_size, size_t numel, int32_t* tokens_cnts, int32_t* cumsum) { + const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); + const size_t start_idx = threadIdx.x * tokens_per_thread; + + for (int i = 0; i < num_experts; ++i) { + tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; + } + + /** + * In the first step we compute token_cnts[thread_index + 1][expert_index], + * which counts how many tokens in the token shard of thread_index are + * assigned to expert expert_index. + */ + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; + } + + __syncthreads(); + + // For each expert we accumulate the token counts from the different threads. + if (threadIdx.x < num_experts) { + tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; + for (int i = 1; i <= blockDim.x; ++i) { + tokens_cnts[index(num_experts, i, threadIdx.x)] += + tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; + } + } + + __syncthreads(); + + // We accumulate the token counts of all experts in thread 0. + if (threadIdx.x == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + cumsum[i] = cumsum[i - 1] + + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], + block_size) * + block_size; + } + *total_tokens_post_pad = cumsum[num_experts]; + } + + __syncthreads(); + + /** + * For each expert, each thread processes the tokens of the corresponding + * blocks and stores the corresponding expert_id for each block. + */ + if (threadIdx.x < num_experts) { + for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; + i += block_size) { + expert_ids[i / block_size] = threadIdx.x; + } + } + + /** + * Each thread processes a token shard, calculating the index of each token + * after sorting by expert number. Given the example topk_ids = + * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *, + * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a + * padding value(preset in python). 
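+   * (Worked through for that example: the per-expert token counts are
+   * [2, 2, 2, 2, 1]; padding each expert's span up to a multiple of
+   * block_size = 4 gives cumsum = [0, 4, 8, 12, 16, 20], so
+   * *total_tokens_post_pad is 20, matching the 20 slots shown above.)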
+ */ + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + int32_t expert_id = topk_ids[i]; + /** The cumsum[expert_id] stores the starting index of the tokens that the + * expert with expert_id needs to process, and + * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens + * processed by the expert with expert_id within the current thread's token + * shard. + */ + int32_t rank_post_pad = + tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + + cumsum[expert_id]; + sorted_token_ids[rank_post_pad] = i; + ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; + } +} + template __global__ void moe_sum_kernel( scalar_t* __restrict__ out, // [..., d] @@ -137,25 +223,61 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - VLLM_DISPATCH_INTEGRAL_TYPES( - topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { - // calc needed amount of shared mem for `tokens_cnts` and `cumsum` - // tensors - const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); - const int32_t shared_mem = - ((num_thread + 1) * num_experts + (num_experts + 1)) * - sizeof(int32_t); - - // set dynamic shared mem - auto kernel = vllm::moe::moe_align_block_size_kernel; - AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( - (void*)kernel, shared_mem)); - kernel<<<1, num_thread, shared_mem, stream>>>( - topk_ids.data_ptr(), sorted_token_ids.data_ptr(), - experts_ids.data_ptr(), - num_tokens_post_pad.data_ptr(), num_experts, block_size, - topk_ids.numel()); - }); + + // If we have very large number of experts, we can no longer use shared + // memory. + // TODO(simon): the right solution should be calculating the exact right + // amount of shared memory and use that. The num_experts >= 256 is just a + // temporary solution to unblock Deepseek V3. 
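  // One possible shape for the TODO above (illustrative sketch only, variable
  // names are placeholders): the exact shared-memory requirement is the same
  // formula the shared-memory path computes below, so the fallback could be
  // chosen from the device's opt-in shared-memory limit rather than a fixed
  // expert-count threshold:
  //
  //   const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE);
  //   const int32_t needed_shared_mem =
  //       ((num_thread + 1) * num_experts + (num_experts + 1)) *
  //       sizeof(int32_t);
  //   int max_shared_mem = 0;
  //   cudaDeviceGetAttribute(&max_shared_mem,
  //                          cudaDevAttrMaxSharedMemoryPerBlockOptin,
  //                          topk_ids.get_device());
  //   const bool use_global_mem = needed_shared_mem > max_shared_mem;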
+ if (num_experts >= 256) { + VLLM_DISPATCH_INTEGRAL_TYPES( + topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] { + // calc needed amount of shared mem for `tokens_cnts` and `cumsum` + // tensors + const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); + + const int32_t mem_tokens_cnts = + ((num_experts + 1) * num_experts) * sizeof(int32_t); + const int32_t mem_cumsum = (num_experts + 1) * sizeof(int32_t); + // allocate global memory + int32_t* tokens_cnts; + int32_t* cumsum; + cudaMalloc(&tokens_cnts, mem_tokens_cnts); + cudaMalloc(&cumsum, mem_cumsum); + + auto kernel = + vllm::moe::moe_align_block_size_global_mem_kernel; + kernel<<<1, num_thread, 0, stream>>>( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), num_experts, block_size, + topk_ids.numel(), tokens_cnts, cumsum); + cudaFree(tokens_cnts); + cudaFree(cumsum); + }); + } else { + VLLM_DISPATCH_INTEGRAL_TYPES( + topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { + // calc needed amount of shared mem for `tokens_cnts` and `cumsum` + // tensors + const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); + const int32_t shared_mem = + ((num_thread + 1) * num_experts + (num_experts + 1)) * + sizeof(int32_t); + + // set dynamic shared mem + auto kernel = vllm::moe::moe_align_block_size_kernel; + AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( + (void*)kernel, shared_mem)); + kernel<<<1, num_thread, shared_mem, stream>>>( + topk_ids.data_ptr(), + sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + num_tokens_post_pad.data_ptr(), num_experts, block_size, + topk_ids.numel()); + }); + } } void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size] diff --git a/vllm/config.py b/vllm/config.py index 58649236b4225..ac767bbe14be4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -596,6 +596,12 @@ def _verify_cuda_graph(self) -> None: self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, self.max_model_len) + if (self.hf_config.model_type == 'deepseek_v3' + and not self.enforce_eager): + logger.warning("CUDA graph is not supported for Deepseek V3 yet, " + "fallback to the eager mode.") + self.enforce_eager = True + def _verify_bnb_config(self) -> None: """ The current version of bitsandbytes (0.44.0) with 8-bit models does not @@ -712,8 +718,9 @@ def get_hidden_size(self) -> int: def get_head_size(self) -> int: # TODO remove hard code - if hasattr(self.hf_text_config, "model_type" - ) and self.hf_text_config.model_type == 'deepseek_v2': + if hasattr(self.hf_text_config, + "model_type") and (self.hf_text_config.model_type + in ('deepseek_v2', 'deepseek_v3')): # FlashAttention supports only head_size 32, 64, 128, 256, # we need to pad head_size 192 to 256 return 256 diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 92e9ba3c9cebd..4101facbe7874 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -476,18 +476,29 @@ def fused_topk( return topk_weights, topk_ids -# This is used by the Deepseek-V2 model +# This is used by the Deepseek-V2 and Deepseek-V3 model def grouped_topk(hidden_states: torch.Tensor, gating_output: torch.Tensor, topk: int, renormalize: bool, num_expert_group: int = 0, - topk_group: int = 0): + topk_group: int = 0, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None): assert hidden_states.shape[0] == 
gating_output.shape[0], ( "Number of tokens mismatch") - scores = torch.softmax(gating_output, dim=-1) + if scoring_func == "softmax": + scores = torch.softmax(gating_output, dim=-1) + elif scoring_func == "sigmoid": + scores = gating_output.sigmoid() + else: + raise ValueError(f"Unsupported scoring function: {scoring_func}") + + if e_score_correction_bias is not None: + scores.add_(e_score_correction_bias.unsqueeze(0)) + num_token = scores.shape[0] group_scores = scores.view(num_token, num_expert_group, -1).max(dim=-1).values # [n, n_group] diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 55c0a202920ff..01ffac4550f28 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -73,16 +73,18 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, set_weight_attrs(w2_weight, extra_weight_attrs) def apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None ) -> torch.Tensor: return self.forward(x=x, layer=layer, @@ -92,19 +94,23 @@ def apply( use_grouped_topk=use_grouped_topk, topk_group=topk_group, num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function) + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) def forward_cuda( - self, - layer: torch.nn.Module, - x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None ) -> torch.Tensor: topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, @@ -114,7 +120,9 @@ def forward_cuda( renormalize=renormalize, topk_group=topk_group, num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function) + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) return fused_experts(hidden_states=x, w1=layer.w13_weight, @@ -128,21 +136,29 @@ def forward_cpu(self, *args, **kwargs): "The CPU backend currently does not support MoE.") def forward_tpu( - self, - layer: torch.nn.Module, - x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + 
top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None ) -> torch.Tensor: assert not use_grouped_topk assert num_expert_group is None assert topk_group is None assert custom_routing_function is None + if scoring_func != "softmax": + raise NotImplementedError( + "Only softmax scoring function is supported for TPU.") + if e_score_correction_bias is not None: + raise NotImplementedError( + "Expert score correction bias is not supported for TPU.") return fused_moe_pallas(hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -156,7 +172,7 @@ def forward_tpu( class FusedMoE(torch.nn.Module): """FusedMoE layer for MoE models. - This layer contains both MergedColumnParallel weights (gate_up_proj / + This layer contains both MergedColumnParallel weights (gate_up_proj / w13) and RowParallelLinear weights (down_proj/ w2). Note: Mixtral uses w1, w2, and w3 for gate, up, and down_proj. We @@ -190,6 +206,8 @@ def __init__( tp_size: Optional[int] = None, prefix: str = "", custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ): super().__init__() @@ -210,6 +228,12 @@ def __init__( self.num_expert_group = num_expert_group self.topk_group = topk_group self.custom_routing_function = custom_routing_function + self.scoring_func = scoring_func + self.e_score_correction_bias = e_score_correction_bias + + if self.scoring_func != "softmax" and not self.use_grouped_topk: + raise ValueError("Only softmax scoring function is supported for " + "non-grouped topk.") if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = ( @@ -446,7 +470,9 @@ def select_experts(hidden_states: torch.Tensor, renormalize: bool, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None): + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None): from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, grouped_topk) @@ -460,7 +486,9 @@ def select_experts(hidden_states: torch.Tensor, topk=top_k, renormalize=renormalize, num_expert_group=num_expert_group, - topk_group=topk_group) + topk_group=topk_group, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) elif custom_routing_function is None: topk_weights, topk_ids = fused_topk(hidden_states=hidden_states, gating_output=router_logits, @@ -489,7 +517,9 @@ def forward(self, hidden_states: torch.Tensor, use_grouped_topk=self.use_grouped_topk, topk_group=self.topk_group, num_expert_group=self.num_expert_group, - custom_routing_function=self.custom_routing_function) + custom_routing_function=self.custom_routing_function, + scoring_func=self.scoring_func, + e_score_correction_bias=self.e_score_correction_bias) if self.reduce_results and self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 5dfd86727a02a..4362468c1db69 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -605,6 +605,8 @@ def apply( topk_group: Optional[int] = None, num_expert_group: Optional[int] = 
None, custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts @@ -617,7 +619,10 @@ def apply( renormalize=renormalize, topk_group=topk_group, num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function) + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias, + ) return fused_experts( x, diff --git a/vllm/model_executor/models/deepseek_v3.py b/vllm/model_executor/models/deepseek_v3.py new file mode 100644 index 0000000000000..333dc019b4d99 --- /dev/null +++ b/vllm/model_executor/models/deepseek_v3.py @@ -0,0 +1,650 @@ +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
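# Reference sketch for the routing change above (`route_scores` is a
# hypothetical helper, not defined anywhere in this patch): it isolates the
# scoring step that the new `scoring_func` and `e_score_correction_bias`
# arguments add to `grouped_topk` in fused_moe.py. Deepseek V3's config selects
# sigmoid scoring together with the "noaux_tc" top-k method (see DeepseekV3MoE
# below); the group-limited top-k that follows this step is unchanged from the
# Deepseek-V2 path.

from typing import Optional

import torch


def route_scores(gating_output: torch.Tensor,
                 scoring_func: str = "softmax",
                 e_score_correction_bias: Optional[torch.Tensor] = None
                 ) -> torch.Tensor:
    # gating_output: [num_tokens, n_routed_experts] router logits.
    if scoring_func == "softmax":
        scores = torch.softmax(gating_output, dim=-1)
    elif scoring_func == "sigmoid":
        scores = gating_output.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")
    if e_score_correction_bias is not None:
        # The per-expert bias is broadcast over tokens and added before the
        # group selection, mirroring scores.add_(...) in the patched
        # grouped_topk.
        scores = scores + e_score_correction_bias.unsqueeze(0)
    return scores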
+"""Inference-only DeepseekV3 model.""" +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (get_pp_group, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class DeepseekV3MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj") + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class DeepseekV3MoE(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.routed_scaling_factor = config.routed_scaling_factor + self.n_shared_experts = config.n_shared_experts + self.routed_scaling_factor = config.routed_scaling_factor + if self.tp_size > config.n_routed_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.n_routed_experts}.") + + if config.hidden_act != "silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. 
" + "Only silu is supported for now.") + + self.gate = ReplicatedLinear(config.hidden_size, + config.n_routed_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate") + if config.topk_method == "noaux_tc": + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(config.n_routed_experts)) + else: + self.gate.e_score_correction_bias = None + + self.experts = FusedMoE( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + e_score_correction_bias=self.gate.e_score_correction_bias) + + if config.n_shared_experts is not None: + intermediate_size = (config.moe_intermediate_size * + config.n_shared_experts) + self.shared_experts = DeepseekV3MLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=False, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + if self.n_shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits) * self.routed_scaling_factor + if shared_output is not None: + final_hidden_states = final_hidden_states + shared_output + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(num_tokens, hidden_dim) + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + import math + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class DeepseekV3Attention(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: int, + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + self.num_heads = num_heads + tp_size = get_tensor_model_parallel_world_size() + assert num_heads % tp_size == 0 + self.num_local_heads = num_heads // tp_size + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear(self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj") + self.q_a_layernorm = RMSNorm(self.q_lora_rank, + eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear(q_lora_rank, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, 
+ prefix=f"{prefix}.q_b_proj") + else: + self.q_proj = ColumnParallelLinear(self.hidden_size, + self.num_heads * + self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj") + + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa") + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj") + # O projection. + self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj") + rope_scaling["rope_type"] = 'deepseek_yarn' + self.rotary_emb = get_rope(qk_rope_head_dim, + rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False) + + if rope_scaling: + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + # self.attn = Attention(self.num_heads, + # self.qk_head_dim, + # self.scaling, + # num_kv_heads=self.num_heads) + + # TODO, support head_size 192 + self.attn = Attention(self.num_local_heads, + 256, + self.scaling, + num_kv_heads=self.num_local_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn") + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + if self.q_lora_rank is not None: + q = self.q_a_proj(hidden_states)[0] + q = self.q_a_layernorm(q) + q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, + self.qk_head_dim) + else: + q = self.q_proj(hidden_states)[0].view(-1, self.num_local_heads, + self.qk_head_dim) + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], + dim=-1) + latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0] + kv_a, _ = latent_cache.split( + [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + latent_cache = latent_cache.unsqueeze(1) + kv_a = self.kv_a_layernorm(kv_a.contiguous()) + kv = self.kv_b_proj(kv_a)[0] + kv = kv.view(-1, self.num_local_heads, + self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k_pe = latent_cache[:, :, self.kv_lora_rank:] + q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe) + q[..., self.qk_nope_head_dim:] = q_pe + k = torch.empty_like(q) + k[..., :self.qk_nope_head_dim] = k_nope + k[..., self.qk_nope_head_dim:] = k_pe + q = torch.nn.functional.pad(q, [0, 256 - self.qk_head_dim], + value=0).view(-1, + self.num_local_heads * 256) + k = torch.nn.functional.pad(k, [0, 256 - self.qk_head_dim], + value=0).view(-1, + self.num_local_heads * 256) + v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim], + value=0).view(-1, + self.num_local_heads * 256) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = attn_output.view( + -1, self.num_local_heads, 256)[..., :self.v_head_dim].reshape( + -1, self.num_local_heads * self.v_head_dim) + output, _ = self.o_proj(attn_output) + return output + + +class DeepseekV3DecoderLayer(nn.Module): + + def __init__( + self, + config: 
PretrainedConfig, + prefix: str, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # DecoderLayers are created with `make_layers` which passes the prefix + # with the layer's index. + layer_idx = int(prefix.split(sep='.')[-1]) + self.self_attn = DeepseekV3Attention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=config.v_head_dim, + q_lora_rank=config.q_lora_rank + if hasattr(config, "q_lora_rank") else None, + kv_lora_rank=config.kv_lora_rank, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + if (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0): + self.mlp = DeepseekV3MoE( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + else: + self.mlp = DeepseekV3MLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +# TODO(simon): check whether we support torch compile for Deepseek V3 +# @support_torch_compile +class DeepseekV3Model(nn.Module): + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: DeepseekV3DecoderLayer( + config, + prefix, + cache_config=cache_config, + quant_config=quant_config, + ), + prefix=f"{prefix}.layers") + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + 
make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class DeepseekV3ForCausalLM(nn.Module, SupportsPP): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + self.model = DeepseekV3Model(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # 
Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + # TODO(simon): support nextn predict layers + if self.config.num_nextn_predict_layers > 0: + assert self.config.num_nextn_predict_layers == 1 + layer_idx = self.config.num_hidden_layers + if name.startswith(f"model.layers.{layer_idx}"): + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + if name not in params_dict: + for key in params_dict: + print(key) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index b32a3421d5841..feb33bb373c3e 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -45,6 +45,7 @@ "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"), + "DeepseekV3ForCausalLM": ("deepseek_v3", "DeepseekV3ForCausalLM"), "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"), "FalconForCausalLM": ("falcon", "FalconForCausalLM"), "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
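
With the registry entry above in place, the model is reachable through the usual vLLM entry points. A minimal usage sketch (the checkpoint id, parallelism degree, and prompt are illustrative; `enforce_eager` is shown explicitly even though the config.py change above already forces eager mode for `deepseek_v3`):

    from vllm import LLM, SamplingParams

    llm = LLM(model="deepseek-ai/DeepSeek-V3",
              trust_remote_code=True,
              tensor_parallel_size=8,  # adjust to your hardware
              enforce_eager=True)      # CUDA graphs are not supported for Deepseek V3 yet
    outputs = llm.generate(
        ["Explain mixture-of-experts routing in one sentence."],
        SamplingParams(temperature=0.0, max_tokens=64))
    print(outputs[0].outputs[0].text)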