
Commit 862dd76

Support NextN (MTP) speculative decoding for DeepSeek-V3/R1 (#3582)

1 parent: fb4c9c3

File tree: 7 files changed, +437 -7 lines


python/sglang/srt/configs/model_config.py

Lines changed: 1 addition & 0 deletions

@@ -98,6 +98,7 @@ def __init__(
         if (
             "DeepseekV2ForCausalLM" in self.hf_config.architectures
             or "DeepseekV3ForCausalLM" in self.hf_config.architectures
+            or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
         ):
             self.head_dim = 256
             self.attention_arch = AttentionArch.MLA

Lines changed: 295 additions & 0 deletions

@@ -0,0 +1,295 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Inference-only DeepSeek NextN Speculative Decoding."""
from typing import Iterable, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig
from vllm import _custom_ops as ops

from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_utils import (
    block_quant_to_tensor_quant,
    normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
from sglang.srt.utils import is_hip

is_hip_ = is_hip()


class DeepseekModelNextN(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            enable_tp=not global_server_args_dict["enable_dp_attention"],
        )

        self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)

        self.decoder = DeepseekV2DecoderLayer(
            config, 0, quant_config=quant_config, is_nextn=True
        )

        self.shared_head = nn.Module()
        self.shared_head.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds

        hidden_states = self.eh_proj(
            torch.cat(
                (
                    self.enorm(hidden_states),
                    self.hnorm(forward_batch.spec_info.hidden_states),
                ),
                dim=-1,
            )
        )

        residual = None
        hidden_states, residual = self.decoder(
            positions, hidden_states, forward_batch, residual
        )

        if not forward_batch.forward_mode.is_idle():
            hidden_states, _ = self.shared_head.norm(hidden_states, residual)
        return hidden_states


class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        nn.Module.__init__(self)
        self.config = config
        self.quant_config = quant_config

        self.model = DeepseekModelNextN(config, quant_config)

        if global_server_args_dict["enable_dp_attention"]:
            self.model.shared_head.head = ReplicatedLinear(
                config.hidden_size,
                config.vocab_size,
                bias=False,
            )
            self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
        else:
            self.model.shared_head.head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
            )
            self.logits_processor = LogitsProcessor(config)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, forward_batch)
        return self.logits_processor(
            input_ids, hidden_states, self.model.shared_head.head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        if hasattr(self.config, "num_nextn_predict_layers"):
            num_nextn_layers = self.config.num_nextn_predict_layers
            assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
            assert num_nextn_layers == self.config.num_hidden_layers
        else:
            raise ValueError("num_nextn_predict_layers is not in the config")

        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
        expert_params_mapping = MoEImpl.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts,
        )

        nextn_layer_prefix = "model.layers.0"
        nextn_spec_weight_names = [
            "shared_head.head",
            "shared_head.norm",
            "eh_proj",
            "embed_tokens",
            "enorm",
            "hnorm",
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if not name.startswith(nextn_layer_prefix):
                continue
            else:
                is_decoder = True
                # For nextn specific weights
                for weight_name in nextn_spec_weight_names:
                    if weight_name in name:
                        name = name.replace(nextn_layer_prefix, "model")
                        is_decoder = False
                        break
                # For decoder layer weights
                if is_decoder:
                    name = name.replace(nextn_layer_prefix, "model.decoder")

            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)

        if not global_server_args_dict["disable_mla"]:
            self_attn = self.model.decoder.self_attn
            if hasattr(self_attn.kv_b_proj, "qweight"):
                # AWQ compatible
                w = ops.awq_dequantize(
                    self_attn.kv_b_proj.qweight,
                    self_attn.kv_b_proj.scales,
                    self_attn.kv_b_proj.qzeros,
                    0,
                    0,
                    0,
                ).T
            else:
                w = self_attn.kv_b_proj.weight
            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
            # This may affect the accuracy of fp8 model.
            if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
                torch.float8_e4m3fn,
                torch.float8_e4m3fnuz,
            ):
                weight_block_size = self.quant_config.weight_block_size
                if weight_block_size is not None:
                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                    if is_hip_:
                        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                            weight=w,
                            weight_scale=self_attn.kv_b_proj.weight_scale_inv,
                            input_scale=None,
                        )
                    else:
                        weight = w
                        weight_scale = self_attn.kv_b_proj.weight_scale_inv

                    w, scale = block_quant_to_tensor_quant(
                        weight, weight_scale, weight_block_size
                    )
                    self_attn.w_scale = scale
            w_kc, w_vc = w.unflatten(
                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
            self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
            self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
            if (
                hasattr(self_attn.kv_b_proj, "weight_scale")
                and self_attn.w_scale is None
            ):
                self_attn.w_scale = self_attn.kv_b_proj.weight_scale
                if is_hip_:
                    self_attn.w_scale *= 2.0


EntryClass = [DeepseekV3ForCausalLMNextN]
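
The heart of the new draft module is the embedding/hidden-state fusion in DeepseekModelNextN.forward: the normalized token embedding and the normalized hidden state handed over by the target model (via forward_batch.spec_info.hidden_states) are concatenated and projected back to hidden_size by eh_proj before a single DeepseekV2DecoderLayer runs. The following is a minimal, self-contained sketch of just that fusion step; the RMSNorm stand-in, the class name, and the tensor sizes are illustrative, not SGLang code.

# Minimal sketch of the NextN (MTP) fusion step, using plain PyTorch
# stand-ins rather than the SGLang classes shown above.
import torch
from torch import nn


class SimpleRMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        variance = x.pow(2).mean(-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.eps)


class NextNFusionSketch(nn.Module):
    """Fuses the draft token embedding with the target model's hidden state."""

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.enorm = SimpleRMSNorm(hidden_size)  # normalizes the token embedding
        self.hnorm = SimpleRMSNorm(hidden_size)  # normalizes the target hidden state
        self.eh_proj = nn.Linear(2 * hidden_size, hidden_size, bias=False)

    def forward(self, input_ids: torch.Tensor, target_hidden: torch.Tensor) -> torch.Tensor:
        emb = self.embed_tokens(input_ids)
        fused = torch.cat((self.enorm(emb), self.hnorm(target_hidden)), dim=-1)
        # The real module then runs one DeepseekV2DecoderLayer and shared_head.norm
        # on this projected state before the shared LM head produces draft logits.
        return self.eh_proj(fused)


if __name__ == "__main__":
    sketch = NextNFusionSketch(hidden_size=16, vocab_size=128)
    token_ids = torch.randint(0, 128, (4,))     # four draft positions
    target_hidden = torch.randn(4, 16)          # hidden states from the target model
    print(sketch(token_ids, target_hidden).shape)  # torch.Size([4, 16])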

python/sglang/srt/models/deepseek_v2.py

Lines changed: 4 additions & 1 deletion
@@ -519,6 +519,8 @@ def forward(
         # Triton: Use normal computation for prefill and use weight absorption for extend/decode
         if (
             forward_batch.forward_mode.is_extend()
+            and not forward_batch.forward_mode.is_target_verify()
+            and not forward_batch.forward_mode.is_draft_extend()
             and forward_batch.extend_prefix_lens.sum() == 0
         ):
             return self.forward_normal(positions, hidden_states, forward_batch)

@@ -680,6 +682,7 @@ def __init__(
         config: PretrainedConfig,
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
+        is_nextn: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size

@@ -731,7 +734,7 @@ def __init__(
             quant_config=quant_config,
             layer_id=layer_id,
         )
-        if (
+        if is_nextn or (
             config.n_routed_experts is not None
             and layer_id >= config.first_k_dense_replace
             and layer_id % config.moe_layer_freq == 0
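
The new is_nextn flag bypasses the usual MoE placement rule so that the single NextN layer always builds the MoE MLP, regardless of first_k_dense_replace and moe_layer_freq. Restated as a stand-alone predicate (this helper is illustrative only; the attribute names are the ones used in the diff):

# Illustrative restatement of the updated MLP-selection condition in
# DeepseekV2DecoderLayer.__init__; this helper is not part of the codebase.
def uses_moe_mlp(layer_id: int, config, is_nextn: bool = False) -> bool:
    return is_nextn or (
        config.n_routed_experts is not None
        and layer_id >= config.first_k_dense_replace
        and layer_id % config.moe_layer_freq == 0
    )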

python/sglang/srt/server_args.py

Lines changed: 6 additions & 3 deletions
@@ -262,14 +262,17 @@ def __post_init__(self):
         )

         # Speculative Decoding
-        if self.speculative_algorithm == "EAGLE":
+        if (
+            self.speculative_algorithm == "EAGLE"
+            or self.speculative_algorithm == "NEXTN"
+        ):
             self.prefill_only_one_req = True
             self.disable_cuda_graph_padding = True
             self.disable_radix_cache = True
             self.disable_overlap_schedule = True
             self.chunked_prefill_size = -1
             logger.info(
-                "The radix cache, chunked prefill, and overlap scheduler are disabled because of using eagle speculative decoding."
+                f"The radix cache, chunked prefill, and overlap scheduler are disabled because of using {self.speculative_algorithm} speculative decoding."
             )

         # GGUF

@@ -705,7 +708,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
         parser.add_argument(
             "--speculative-algorithm",
             type=str,
-            choices=["EAGLE"],
+            choices=["EAGLE", "NEXTN"],
             help="Speculative algorithm.",
         )
         parser.add_argument(
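
For reference, the new option is surfaced through the existing --speculative-algorithm flag. The fragment of add_cli_args touched by this diff can be exercised on its own with a plain argparse sketch (self-contained, not importing SGLang):

# Self-contained sketch of the CLI surface touched by this diff; it mirrors
# the --speculative-algorithm argument in ServerArgs.add_cli_args above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--speculative-algorithm",
    type=str,
    choices=["EAGLE", "NEXTN"],
    help="Speculative algorithm.",
)
print(parser.parse_args(["--speculative-algorithm", "NEXTN"]))
# Namespace(speculative_algorithm='NEXTN')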

python/sglang/srt/speculative/eagle_worker.py

Lines changed: 7 additions & 2 deletions
@@ -24,6 +24,7 @@
     fast_topk,
     select_top_k_tokens,
 )
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

 logger = logging.getLogger(__name__)

@@ -57,11 +58,15 @@ def __init__(
         # Parse arguments
         self.topk = server_args.speculative_eagle_topk
         self.speculative_num_steps = server_args.speculative_num_steps
+        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
+            server_args.speculative_algorithm
+        )
         self.server_args = server_args

         # Share the embedding and lm_head
-        embed, head = self.target_worker.model_runner.model.get_embed_and_head()
-        self.model_runner.model.set_embed_and_head(embed, head)
+        if not self.speculative_algorithm.is_nextn():
+            embed, head = self.target_worker.model_runner.model.get_embed_and_head()
+            self.model_runner.model.set_embed_and_head(embed, head)
         self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph

         # Create multi-step attn backends and cuda graph runners
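
The worker now resolves the configured name through SpeculativeAlgorithm.from_string(...) and branches on is_nextn() to skip sharing the target model's embedding and LM head, since the NextN checkpoint carries its own embed_tokens and shared_head.head weights (loaded in load_weights above). spec_info.py itself is not part of this commit view, so the following is only a plausible sketch of the interface those two calls assume:

# Hypothetical sketch of the interface used above; the real enum lives in
# sglang/srt/speculative/spec_info.py, which is not shown in this commit view.
from enum import IntEnum, auto


class SpeculativeAlgorithm(IntEnum):
    EAGLE = auto()
    NEXTN = auto()

    def is_nextn(self) -> bool:
        return self == SpeculativeAlgorithm.NEXTN

    @staticmethod
    def from_string(name: str) -> "SpeculativeAlgorithm":
        return {
            "EAGLE": SpeculativeAlgorithm.EAGLE,
            "NEXTN": SpeculativeAlgorithm.NEXTN,
        }[name]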
