
Commit 808d832: Improve the accuracy of deepseekv3 (#744)
1 parent: 00e4de7

File tree: 11 files changed, +143 / -132 lines

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py

Lines changed: 3 additions & 1 deletion
@@ -167,7 +167,7 @@ def _load_hf_weights_etp(self, weights):
         expert_gate_up_proj_last = None
         expert_down_proj_last = None
         if self.e_score_correction_bias_name in weights:
-            self.e_score_correction_bias = self._cuda(self.e_score_correction_bias_name)
+            self.e_score_correction_bias = self._cuda(weights[self.e_score_correction_bias_name])

         for i_experts_ep in range(n_expert_ep):
             expert_up_proj = None
@@ -223,6 +223,8 @@ def load_hf_weights(self, weights):
         if os.environ.get("ETP_MODE_ENABLED") == "true":
             self._load_hf_weights_etp(weights)
         else:
+            if self.e_score_correction_bias_name in weights:
+                self.e_score_correction_bias = self._cuda(weights[self.e_score_correction_bias_name])
             for i_experts in range(self.n_routed_experts):
                 w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight"
                 w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight"

lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py

Lines changed: 2 additions & 1 deletion
@@ -500,7 +500,8 @@ def dequant_weight(self, weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
         weight = weight.to(self.data_type_)
         block_size = weight.shape[-1] // scale.shape[-1]
         w_shape = weight.shape
-        scale = scale.unsqueeze(-1).repeat(1, 1, 1, block_size).reshape(w_shape[0], w_shape[1], -1)
+        s_shape = scale.shape
+        scale = scale.unsqueeze(-1).repeat(1, 1, 1, block_size).reshape(s_shape[0], s_shape[1], -1)
         scale = scale.unsqueeze(2).repeat(1, 1, block_size, 1).reshape(w_shape)
         return (weight * scale).to(self.data_type_)
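Note on the fix: the intermediate reshape must follow the scale's own leading dimensions (s_shape), not the weight's, otherwise the per-block scales land on the wrong rows. A minimal sketch of the intended expansion, assuming a 3-D weight of shape [E, N, K] and a block-wise scale of shape [E, N // block, K // block] (shapes here are illustrative, not taken from the repo):

    import torch

    E, N, K, block = 2, 8, 16, 4
    weight = torch.randn(E, N, K)                   # block-quantized weight, already cast to the compute dtype
    scale = torch.rand(E, N // block, K // block)   # one scale per (block x block) tile

    # Equivalent to the fixed repeat/reshape path: broadcast each tile scale
    # over its block x block region, then dequantize elementwise.
    full_scale = scale.repeat_interleave(block, dim=1).repeat_interleave(block, dim=2)
    dequant = weight * full_scale                   # shape [E, N, K]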

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 23 additions & 6 deletions
@@ -33,6 +33,7 @@
 from .moe_kernel_configs import MoeGroupedGemmKernelConfig
 from .moe_silu_and_mul import silu_and_mul_fwd
 from .moe_sum_reduce import moe_sum_reduce
+from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8

 FFN_MOE_CHUNK_SIZE = 8 * 1024

@@ -223,7 +224,7 @@ def grouped_matmul_kernel(
     n,  # int
     expert_num,  # int
     topk_num,  # int
-    token_scale_ptr,  # [1,]
+    token_scale_ptr,  # [1,] for per-tensor quant, or [token_num, hidden_dim // block_size] for per-token, per-group quant
     weight_scale_ptr,  # [expert_num,] or [expert_num, n // block_size_n, k // block_size_k]
     weight_scale_stride0,
     weight_scale_stride1,
@@ -306,7 +307,7 @@ def grouped_matmul_kernel(

     if use_fp8_w8a8:
         if block_size_k > 0 and block_size_n > 0:
-            a_scale = tl.load(token_scale_ptr, eviction_policy="evict_last")
+            a_scale_ptrs = token_scale_ptr + (a_m_index // topk_num) * (token_stride_0 // block_size_k)
             offs_bsn = offs_bn // block_size_n
             b_scale_ptrs = weight_scale_ptr + expert_id * weight_scale_stride0 + offs_bsn * weight_scale_stride1
         else:
@@ -342,8 +343,9 @@ def grouped_matmul_kernel(
         if use_fp8_w8a8:
             if block_size_k > 0 and block_size_n > 0:
                 offs_ks = step_k * BLOCK_SIZE_K // block_size_k
+                a_scale = tl.load(a_scale_ptrs + offs_ks, mask=offs_am < cur_m, other=0.0)
                 b_scale = tl.load(b_scale_ptrs + offs_ks * weight_scale_stride2)
-                accumulator += tl.dot(b, a) * a_scale * b_scale[:, None]
+                accumulator += tl.dot(b, a) * b_scale[:, None] * a_scale[None, :]
             else:
                 accumulator = tl.dot(b, a, acc=accumulator)
         else:
@@ -387,6 +389,7 @@ def grouped_matmul(
     expert_token_limit: int,
     mul_routed_weight: bool,
     use_fp8_w8a8: bool,
+    alloc_tensor_func=torch.empty,
     **run_config,
 ):
     """
@@ -417,7 +420,6 @@ def grouped_matmul(
     if expert_to_weights_scale.ndim == 3:
         block_size_n = expert_weights.shape[1] // expert_to_weights_scale.shape[1]
         block_size_k = expert_weights.shape[2] // expert_to_weights_scale.shape[2]
-
     if not run_config:
         run_config = MoeGroupedGemmKernelConfig.try_to_get_best_config(
             M=token_inputs.shape[0],
@@ -436,8 +438,22 @@ def grouped_matmul(
     num_warps = run_config["num_warps"]
     num_stages = run_config["num_stages"]

+    if block_size_k != 0:
+        # With block-wise quantization, the K tile size must not exceed the quantization block size.
+        BLOCK_SIZE_K = min(BLOCK_SIZE_K, block_size_k)
+        assert BLOCK_SIZE_K == triton.next_power_of_2(BLOCK_SIZE_K)
+
     if use_fp8_w8a8:
-        token_inputs, token_input_scale = ops.scaled_fp8_quant(token_inputs, token_input_scale)
+        # When the weights are block-wise quantized, the activations are also quantized per token with the same group size.
+        if block_size_k == 0:
+            token_inputs, token_input_scale = ops.scaled_fp8_quant(token_inputs, token_input_scale)
+        else:
+            _m, _k = token_inputs.shape
+            assert _k % block_size_k == 0
+            input_scale = alloc_tensor_func((_m, _k // block_size_k), dtype=torch.float32, device=token_inputs.device)
+            qinput_tensor = alloc_tensor_func((_m, _k), dtype=expert_weights.dtype, device=token_inputs.device)
+            per_token_group_quant_fp8(token_inputs, block_size_k, qinput_tensor, input_scale)
+            token_inputs, token_input_scale = qinput_tensor, input_scale

     kernel = grouped_matmul_kernel.warmup(
         expert_token_limit,
@@ -579,7 +595,6 @@ def fused_experts_impl(
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
     assert w2.is_contiguous(), "Expert weights2 must be contiguous"
     assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
-
     num_tokens, _ = hidden_states.shape
     E, N, _ = w1.shape
     CHUNK_SIZE = FFN_MOE_CHUNK_SIZE
@@ -632,6 +647,7 @@ def fused_experts_impl(
         expert_token_limit=2 ** 31 - 1,
         mul_routed_weight=False,
         use_fp8_w8a8=use_fp8_w8a8,
+        alloc_tensor_func=alloc_tensor_func,
         **run_config,
     )

@@ -650,6 +666,7 @@ def fused_experts_impl(
         expert_token_limit=2 ** 31 - 1,
         mul_routed_weight=True,
         use_fp8_w8a8=use_fp8_w8a8,
+        alloc_tensor_func=alloc_tensor_func,
         **run_config,
     )
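For context, the activation-side change switches from a single per-tensor FP8 scale to per-token, per-group scales whenever the weights are block-wise quantized. A rough pure-PyTorch sketch of the math that the per_token_group_quant_fp8 Triton kernel performs (the reference function below and its name are illustrative assumptions, not the repo's API):

    import torch

    def per_token_group_quant_fp8_ref(x: torch.Tensor, group_size: int, eps: float = 1e-10):
        # Quantize every contiguous group of `group_size` values in each row to FP8,
        # keeping one float32 scale per (token, group).
        m, k = x.shape
        assert k % group_size == 0
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        groups = x.view(m, k // group_size, group_size).float()
        scale = groups.abs().amax(dim=-1, keepdim=True).clamp(min=eps) / fp8_max
        q = (groups / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
        return q.view(m, k), scale.view(m, k // group_size)

Inside the kernel, a_scale is then loaded per token row and per K-group (a_scale_ptrs + offs_ks), which matches the [token_num, hidden_dim // block_size] layout noted in the updated comment.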

lightllm/common/fused_moe/grouped_topk.py

Lines changed: 38 additions & 13 deletions
@@ -6,17 +6,22 @@


 @triton.jit
-def _compare_and_swap(x, ids, flip, i: tl.core.constexpr, n_dims: tl.core.constexpr):
+def _compare_and_swap(x, x_1, ids, flip, i: tl.core.constexpr, n_dims: tl.core.constexpr):
     n_outer: tl.core.constexpr = x.numel >> n_dims
     shape: tl.core.constexpr = [n_outer * 2 ** i, 2, 2 ** (n_dims - i - 1)]
     y = tl.core.reshape(x, shape)
+    y_1 = tl.core.reshape(x_1, shape)
     # slice left/right with 'stride' 2**(n_dims - i - 1)
     mask = tl.core.arange(0, 2)[None, :, None]
     left = tl.core.broadcast_to(sum(y * (1 - mask), 1)[:, None, :], shape)
     right = tl.core.broadcast_to(sum(y * mask, 1)[:, None, :], shape)
     left = tl.core.reshape(left, x.shape)
     right = tl.core.reshape(right, x.shape)

+    left_1 = tl.core.broadcast_to(sum(y_1 * (1 - mask), 1)[:, None, :], shape)
+    right_1 = tl.core.broadcast_to(sum(y_1 * mask, 1)[:, None, :], shape)
+    left_1 = tl.core.reshape(left_1, x_1.shape)
+    right_1 = tl.core.reshape(right_1, x_1.shape)
     # idx
     y_idx = tl.core.reshape(ids, shape)
     left_idx = tl.core.broadcast_to(sum(y_idx * (1 - mask), 1)[:, None, :], shape)
@@ -36,11 +41,18 @@ def _compare_and_swap(x, ids, flip, i: tl.core.constexpr, n_dims: tl.core.constexpr):

     new_ids = ids ^ tl.core.where(cond, left_idx ^ right_idx, zeros_like(ids))

-    return ret.to(x.dtype, bitcast=True), new_ids
+    # swap x_1 using the same condition
+    idtype_1 = tl.core.get_int_dtype(bitwidth=x_1.dtype.primitive_bitwidth, signed=True)
+    ileft_1 = left_1.to(idtype_1, bitcast=True)
+    iright_1 = right_1.to(idtype_1, bitcast=True)
+    ix_1 = x_1.to(idtype, bitcast=True)
+    ret_1 = ix_1 ^ tl.core.where(cond, ileft_1 ^ iright_1, zeros_like(ix_1))
+
+    return ret.to(x.dtype, bitcast=True), ret_1.to(x_1.dtype, bitcast=True), new_ids


 @triton.jit
-def _bitonic_merge(x, ids, stage: tl.core.constexpr, order: tl.core.constexpr, n_dims: tl.core.constexpr):
+def _bitonic_merge(x, x_1, ids, stage: tl.core.constexpr, order: tl.core.constexpr, n_dims: tl.core.constexpr):
     """
     order_type 0 == ascending
     order_type 1 == descending
@@ -60,21 +72,21 @@ def _bitonic_merge(x, ids, stage: tl.core.constexpr, order: tl.core.constexpr, n_dims: tl.core.constexpr):
     flip = order
     # perform `stage` rounds of `compare-and-swap`
     for i in tl.core.static_range(stage):
-        x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims)
-    return x, ids
+        x, x_1, ids = _compare_and_swap(x, x_1, ids, flip, i + (n_dims - stage), n_dims)
+    return x, x_1, ids


 @triton.jit
-def argsort(x, ids, dim: tl.core.constexpr = None, descending: tl.core.constexpr = tl.core.CONSTEXPR_0):
+def argsort(x, x_1, ids, dim: tl.core.constexpr = None, descending: tl.core.constexpr = tl.core.CONSTEXPR_0):
     # handle default dimension or check that it is the most minor dim
     _dim: tl.core.constexpr = len(x.shape) - 1 if dim is None else dim
     tl.core.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported")
     # iteratively run bitonic merge-sort steps
     n_dims: tl.core.constexpr = _log2(x.shape[_dim])

     for i in tl.core.static_range(1, n_dims + 1):
-        x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims)
-    return x, ids
+        x, x_1, ids = _bitonic_merge(x, x_1, ids, i, 2 if i < n_dims else descending, n_dims)
+    return x, x_1, ids


 @triton.jit
@@ -106,6 +118,7 @@ def grouped_topk_kernel(
     EXPERT_GROUP_NUM: tl.constexpr,  # tl.next_power_two_of(group_num)
     EXPERT_GROUP_SIZE: tl.constexpr,  # tl.next_power_two_of(group_expert_num)
     RENORMALIZE: tl.constexpr,
+    GROUP_SCORE_USED_TOPK_NUM: tl.constexpr,
 ):
     token_index = tl.program_id(axis=0)
     offs_n = tl.arange(0, EXPERT_BLOCK_SIZE)
@@ -115,12 +128,14 @@ def grouped_topk_kernel(
         other=-10000000.0,
     ).to(tl.float32)
     if IS_SIGMOID:
-        scores = tl.sigmoid(hidden_states)
+        old_scores = tl.sigmoid(hidden_states)
     else:
-        scores = tl.softmax(hidden_states)
+        old_scores = tl.softmax(hidden_states)

     if HAS_CORRECTION_BIAS:
-        scores += tl.load(correction_bias_ptr + offs_n, mask=offs_n < total_expert_num, other=-10000000.0)
+        scores = old_scores + tl.load(correction_bias_ptr + offs_n, mask=offs_n < total_expert_num, other=-10000000.0)
+    else:
+        scores = old_scores

     offs_group = tl.arange(0, EXPERT_GROUP_NUM)
     offs_group_v = tl.arange(0, EXPERT_GROUP_SIZE)
@@ -134,7 +149,15 @@ def grouped_topk_kernel(
         other=-10000000.0,
     )  # [group, group_size]

-    group_value = tl.max(group_scores, axis=1)  # [group,]
+    group_value = tl.sum(
+        tl.where(
+            (offs_group < group_num)[:, None] & (offs_group_v < GROUP_SCORE_USED_TOPK_NUM)[None, :],
+            tl.sort(group_scores, dim=1, descending=True),
+            0.0,
+        ),
+        axis=1,
+    )
+
     sorted_group_value = tl.sort(group_value, descending=True)
     group_topk_value = tl.sum(tl.where(offs_group == group_topk_num - 1, sorted_group_value, 0.0))
     mask_group_scores = tl.where(
@@ -155,7 +178,7 @@ def grouped_topk_kernel(
     mask_scores = tl.load(
         scores_buffer_ptr + scores_stride_m * token_index + offs_n, mask=offs_n < total_expert_num, other=-10000000.0
     )
-    sorted_scores, sorted_indexes = argsort(mask_scores, offs_n, descending=True)
+    _, sorted_scores, sorted_indexes = argsort(mask_scores, old_scores, offs_n, descending=True)

     if RENORMALIZE:
         sum_scores = tl.sum(tl.where(offs_n < topk_num, sorted_scores, 0.0))
@@ -184,6 +207,7 @@ def triton_grouped_topk(
     num_expert_group: int = 0,
     topk_group: int = 0,
     scoring_func: str = "softmax",
+    group_score_used_topk_num=2,
 ):

     if correction_bias is not None:
@@ -225,6 +249,7 @@ def triton_grouped_topk(
         EXPERT_GROUP_NUM=triton.next_power_of_2(num_expert_group),
         EXPERT_GROUP_SIZE=triton.next_power_of_2(total_expert_num // num_expert_group),
         RENORMALIZE=renormalize,
+        GROUP_SCORE_USED_TOPK_NUM=group_score_used_topk_num,
         num_warps=1,
         num_stages=1,
     )
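The group_value change replaces the per-group maximum with the sum of the top GROUP_SCORE_USED_TOPK_NUM expert scores in each group (2 for DeepSeek-V3), matching its node-limited routing. A small PyTorch illustration with assumed shapes (8 groups of 4 experts for one token; values are made up):

    import torch

    group_scores = torch.rand(8, 4)                                 # [n_group, experts_per_group]

    old_group_value = group_scores.max(dim=-1).values               # previous metric: best single expert per group
    new_group_value = group_scores.topk(2, dim=-1).values.sum(-1)   # new metric: sum of the top-2 experts per group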

lightllm/common/fused_moe/topk_select.py

Lines changed: 46 additions & 3 deletions
@@ -70,9 +70,9 @@ def grouped_topk(
         scores = torch.sigmoid(gating_output)
     else:
         scores = torch.softmax(gating_output, dim=-1)
-
+    old_scores = scores
     if correction_bias is not None:
-        scores.add_(correction_bias)
+        scores = scores + correction_bias

     num_token = scores.shape[0]
     group_scores = scores.view(num_token, num_expert_group, -1).max(dim=-1).values  # [n, n_group]
@@ -85,7 +85,43 @@ def grouped_topk(
         .reshape(num_token, -1)
     )  # [n, e]
     tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
-    topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+    _, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+    topk_weights = old_scores.gather(1, topk_ids)
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+
+    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+
+
+# biased_grouped_topk adapted from sgl-project/sglang/python/sglang/srt/layers/moe/topk.py
+def biased_grouped_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    correction_bias: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+    scoring_func: str = "sigmoid",
+):
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+    scores = gating_output.sigmoid()
+    num_token = scores.shape[0]
+    scores_for_choice = scores.view(num_token, -1) + correction_bias.unsqueeze(0)
+    group_scores = (
+        scores_for_choice.view(num_token, num_expert_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
+    )  # [n, n_group]
+    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[1]  # [n, top_k_group]
+    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
+    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
+    score_mask = (
+        group_mask.unsqueeze(-1)
+        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
+        .reshape(num_token, -1)
+    )  # [n, e]
+    tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
+    _, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+    topk_weights = scores.gather(1, topk_ids)

     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
@@ -161,6 +197,11 @@ def select_experts(
                scoring_func=scoring_func,
            )
        else:
+            group_score_topk_num = 1
+            # for deepseek v3
+            if topk_group == 4 and num_expert_group == 8 and top_k == 8:
+                group_score_topk_num = 2
+
            topk_weights, topk_ids = triton_grouped_topk(
                hidden_states=hidden_states,
                gating_output=router_logits,
@@ -170,7 +211,9 @@ def select_experts(
                num_expert_group=num_expert_group,
                topk_group=topk_group,
                scoring_func=scoring_func,
+                group_score_used_topk_num=group_score_topk_num,
            )
+
    elif custom_routing_function is None:
        topk_weights, topk_ids = fused_topk(
            hidden_states=hidden_states, gating_output=router_logits, topk=top_k, renormalize=renormalize
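The accuracy-relevant detail in both grouped_topk and biased_grouped_topk is that the correction bias only influences which experts are selected; the combination weights are gathered from the unbiased sigmoid scores. A toy example with made-up numbers:

    import torch

    scores = torch.tensor([[0.70, 0.20, 0.10]])   # sigmoid router outputs (illustrative)
    bias = torch.tensor([0.00, 0.60, 0.00])       # e_score_correction_bias (illustrative)

    biased = scores + bias                         # used only for expert selection
    topk_ids = biased.topk(2, dim=-1).indices      # picks experts 1 and 0
    topk_weights = scores.gather(1, topk_ids)      # weights come from the unbiased scores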

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 5 additions & 8 deletions
@@ -32,6 +32,7 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
         self.tp_v_head_num_ = 1
         self.qk_nope_head_dim = network_config["qk_nope_head_dim"]
         self.qk_rope_head_dim = network_config["qk_rope_head_dim"]
+        self.v_head_dim = network_config["v_head_dim"]
         self.q_lora_rank = network_config["q_lora_rank"]
         self.kv_lora_rank = network_config["kv_lora_rank"]

@@ -196,16 +197,12 @@ def _decompress_kv(

         # CC
         compressed_kv = compressed_kv.view(-1, layer_weight.kv_lora_rank).contiguous()
-        k_nope = self.alloc_tensor(
-            [compressed_kv.shape[0], self.tp_q_head_num_, self.qk_nope_head_dim],
+        kv_nope = self.alloc_tensor(
+            [compressed_kv.shape[0], self.tp_q_head_num_, (self.qk_nope_head_dim + self.v_head_dim)],
             dtype=compressed_kv.dtype,
         )
-        v = self.alloc_tensor(
-            k_nope.shape,
-            dtype=compressed_kv.dtype,
-        )
-        layer_weight.cc_k_b_proj_.mm(compressed_kv, out=k_nope.reshape(compressed_kv.shape[0], -1))
-        layer_weight.cc_v_b_proj_.mm(compressed_kv, out=v.reshape(compressed_kv.shape[0], -1))
+        layer_weight.cc_kv_b_proj_.mm(compressed_kv, out=kv_nope.reshape(compressed_kv.shape[0], -1))
+        k_nope, v = torch.split(kv_nope, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
         return k_nope, k_rope, v

     def _context_attention_kernel_with_CC(
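The _decompress_kv change folds the separate k_b and v_b projections into a single cc_kv_b_proj_ GEMM and splits the result along the last dimension, saving one allocation and one matmul launch per call. A shape-level sketch with assumed dimensions (plain tensors standing in for the repo's weight objects):

    import torch

    tokens, heads = 4, 16
    kv_lora_rank, qk_nope_head_dim, v_head_dim = 512, 128, 128

    compressed_kv = torch.randn(tokens, kv_lora_rank)
    fused_weight = torch.randn(kv_lora_rank, heads * (qk_nope_head_dim + v_head_dim))

    # One fused projection, then split into the k_nope and v halves per head.
    kv_nope = (compressed_kv @ fused_weight).view(tokens, heads, qk_nope_head_dim + v_head_dim)
    k_nope, v = torch.split(kv_nope, [qk_nope_head_dim, v_head_dim], dim=-1)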
