
Commit c483b1e

hiworldwzj and shihaobai authored
deepseekv3 bmm noquant and fix moe gemm bug. (#745)
Co-authored-by: shihaobai <baishihao@sensetime.com>
Co-authored-by: shihaobai <42648726+shihaobai@users.noreply.github.com>
1 parent 808d832 commit c483b1e

5 files changed: +79 -13 lines changed

lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py

Lines changed: 7 additions & 3 deletions
@@ -73,16 +73,20 @@ def _post_load_weights(self) -> None:
                 and (not self.static_activation or self.input_scale is not None)
             ):
                 if self.weight_scale.ndim > 1:
-                    self.weight_scale = self.weight_scale.transpose(0, 1).cuda(self.device_id_)
+                    # make the k dim more contiguous; most split-k kernels will likely be faster
+                    self.weight_scale = self.weight_scale.cuda(self.device_id_).transpose(0, 1)
                 self.weight = [
-                    self.weight.transpose(0, 1).cuda(self.device_id_),
+                    # make the k dim more contiguous; most split-k kernels will likely be faster
+                    self.weight.cuda(self.device_id_).transpose(0, 1),
                     self.weight_scale,
                     self.input_scale,
                 ]
             else:
                 self.weight = self.quant_method.quantize(self.weight.to(self.data_type_).cuda(self.device_id_))
             return
-        self.weight = self.weight.to(self.data_type_).transpose(0, 1).cuda(self.device_id_)
+
+        # make the k dim more contiguous; most split-k kernels will likely be faster
+        self.weight = self.weight.to(self.data_type_).cuda(self.device_id_).transpose(0, 1)
 
 
 class MMWeight(MMWeightTpl):
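Note on the layout comment above: after this change the weight is moved to the GPU in its original row-major layout and only then viewed as (k, n), so consecutive elements along the k dimension stay adjacent in memory, which is what the comment means by keeping the k dim contiguous for split-k kernels. A minimal sketch of how the resulting layout can be inspected, assuming the checkpoint weight is in the usual (out_features, in_features) = (n, k) row-major layout; the shapes are illustrative, not taken from the repo, and a GPU is assumed:

```python
import torch

# Illustrative weight in the usual HF Linear layout (out_features, in_features) = (n, k).
w_cpu = torch.randn(4096, 1024, dtype=torch.bfloat16)  # (n, k), row-major

# Order used after this commit: move to the GPU first, then take a transposed view.
w = w_cpu.cuda().transpose(0, 1)  # logical shape (k, n)

# Stride 1 along dim 0 means consecutive k elements are adjacent in memory.
print(w.shape, w.stride())    # torch.Size([1024, 4096]) (1, 1024)
print(w.is_contiguous())      # False: a view over the row-major (n, k) storage, not a repacked copy
```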

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 5 additions & 5 deletions
@@ -331,7 +331,7 @@ def grouped_matmul_kernel(
     for step_k in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
         # hint to Triton compiler to do proper loop pipelining
         # tl.multiple_of(a_ptrs, [16, 16])
-        tl.multiple_of(b_ptrs, [16, 16])
+        # tl.multiple_of(b_ptrs, [16, 16])
 
         if use_fp8_w8a8:
             a = tl.load(a_ptrs, mask=(offs_am[None, :] < cur_m) & (offs_k[:, None] < k))
@@ -464,10 +464,10 @@ def grouped_matmul(
             token_input_scale,
             expert_to_weights_scale,
             expert_to_weights_scale.stride(0)
-            if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 2
+            if expert_to_weights_scale is not None and expert_to_weights_scale.ndim >= 1
             else 0,
             expert_to_weights_scale.stride(1)
-            if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 2
+            if expert_to_weights_scale is not None and expert_to_weights_scale.ndim >= 2
             else 0,
             expert_to_weights_scale.stride(2)
             if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 3
@@ -532,10 +532,10 @@ def grouped_matmul(
             token_input_scale,
             expert_to_weights_scale,
             expert_to_weights_scale.stride(0)
-            if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 2
+            if expert_to_weights_scale is not None and expert_to_weights_scale.ndim >= 1
             else 0,
             expert_to_weights_scale.stride(1)
-            if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 2
+            if expert_to_weights_scale is not None and expert_to_weights_scale.ndim >= 2
             else 0,
             expert_to_weights_scale.stride(2)
             if expert_to_weights_scale is not None and expert_to_weights_scale.ndim == 3
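The second and third hunks fix how the strides of `expert_to_weights_scale` are forwarded to the Triton kernel: with the old `ndim == 2` checks, a 3-D (block-wise) scale tensor had 0 passed for its first two strides, whereas with `ndim >= 1` / `ndim >= 2` every dimension the tensor actually has contributes its real stride and 0 is used only for dimensions it lacks. A small sketch of the same guard pattern written as a standalone helper; the helper name and shapes below are ours, purely illustrative:

```python
from typing import Optional

import torch


def scale_strides(scale: Optional[torch.Tensor], num_dims: int = 3) -> tuple:
    """Return one stride per expected dim, padding with 0 for dims the scale tensor lacks."""
    if scale is None:
        return (0,) * num_dims
    return tuple(scale.stride(i) if scale.ndim > i else 0 for i in range(num_dims))


print(scale_strides(torch.ones(8)))          # per-expert scalar scales -> (1, 0, 0)
print(scale_strides(torch.ones(8, 256)))     # per-channel scales       -> (256, 1, 0)
print(scale_strides(torch.ones(8, 16, 32)))  # block-wise scales        -> (512, 32, 1)
print(scale_strides(None))                   # no scales at all         -> (0, 0, 0)
```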

lightllm/common/quantization/vllm_quant.py

Lines changed: 0 additions & 1 deletion
@@ -198,7 +198,6 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_
                 dtype=input_tensor.dtype,
             )
         else:
-            qweight = qweight.t().contiguous().t()
             input_scale = input_scale.t().contiguous().t()
         torch.ops._C.cutlass_scaled_mm(out, qinput_tensor, qweight, input_scale, weight_scale, bias)
         return out
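On the deleted line: `t().contiguous().t()` is the common idiom for repacking a 2-D tensor into column-major storage without changing its logical shape. After the weight-loading change in mm_weight.py above, the weight presumably already reaches this path in that layout, so the extra repack (a full copy of the weight) is dropped, while `input_scale` is still repacked. A quick illustration of what the idiom does:

```python
import torch

w = torch.randn(128, 64)
w_col = w.t().contiguous().t()  # same logical shape, column-major storage

print(w.shape, w.stride())          # torch.Size([128, 64]) (64, 1)  -- row-major
print(w_col.shape, w_col.stride())  # torch.Size([128, 64]) (1, 128) -- column-major
print(torch.equal(w, w_col))        # True: only the memory layout differs
```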

lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py

Lines changed: 8 additions & 4 deletions
@@ -18,6 +18,7 @@
     ROWBMMWeightNoTp,
 )
 from functools import partial
+from ..triton_kernel.weight_dequant import weight_dequant
 
 
 class Deepseek2TransformerLayerWeight(TransformerLayerWeight):
@@ -116,8 +117,15 @@ def _load_vb_scale(self, kv_b_proj_scale_, block_size):
     def load_hf_weights(self, weights):
         if f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight" in weights:
             kv_b_proj_ = weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight"]
+            # for deepseek_v3, the bmm operator is not quantized
+            if self.quant_cfg.quantized_weight:
+                kv_b_proj_ = weight_dequant(
+                    kv_b_proj_.cuda(),
+                    weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + self.weight_scale_suffix].cuda(),
+                ).cpu()
             weights[f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight"] = self._load_kb(kv_b_proj_)
             weights[f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight"] = self._load_vb(kv_b_proj_)
+
         if (
             self.quant_cfg.quantized_weight
             and f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + self.weight_scale_suffix in weights
@@ -184,15 +192,11 @@ def _init_qkvo(self):
             f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight",
             self.data_type_,
             split_n_embed=self.tp_q_head_num_,
-            weight_scale_suffix=self.weight_scale_suffix,
-            act_scale_suffix=self.act_scale_suffix,
         )
         self.v_b_proj_ = ROWBMMWeight(
             f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight",
             self.data_type_,
             split_n_embed=self.tp_q_head_num_,
-            weight_scale_suffix=self.weight_scale_suffix,
-            act_scale_suffix=self.act_scale_suffix,
         )
         if self.enable_cc_method:
             self.cc_kv_b_proj_ = ROWMMWeight(
lightllm/models/deepseek2/triton_kernel/weight_dequant.py (new file)

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
+# adapt from
+# https://github.com/deepseek-ai/DeepSeek-V3/blob/f09f5fa321f5a421704136c0463b1eaca6557712/inference/kernel.py
+import torch
+import triton
+import triton.language as tl
+from triton import Config
+
+
+def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
+    """
+    Dequantizes the given weight tensor using the provided scale tensor.
+
+    Args:
+        x (torch.Tensor): The quantized weight tensor of shape (M, N).
+        s (torch.Tensor): The per-block scale tensor of shape (ceil(M / block_size), ceil(N / block_size)).
+        block_size (int, optional): The block size to use for dequantization. Defaults to 128.
+
+    Returns:
+        torch.Tensor: The dequantized weight tensor of the same shape as `x`.
+
+    Raises:
+        AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
+    """
+    assert x.is_contiguous() and s.is_contiguous(), "Input tensors must be contiguous"
+    assert x.dim() == 2 and s.dim() == 2, "Input tensors must have 2 dimensions"
+    M, N = x.size()
+    y = torch.empty_like(x, dtype=torch.get_default_dtype())
+    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE"]), triton.cdiv(N, meta["BLOCK_SIZE"]))
+    weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
+    return y.to(torch.bfloat16)
+
+
+@triton.jit
+def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
+    """
+    Dequantizes weights using the provided scaling factors and stores the result.
+
+    Args:
+        x_ptr (tl.pointer): Pointer to the quantized weights.
+        s_ptr (tl.pointer): Pointer to the scaling factors.
+        y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
+        M (int): Number of rows in the weight matrix.
+        N (int): Number of columns in the weight matrix.
+        BLOCK_SIZE (tl.constexpr): Size of the block for tiling.
+
+    Returns:
+        None
+    """
+    pid_m = tl.program_id(axis=0)
+    pid_n = tl.program_id(axis=1)
+    n = tl.cdiv(N, BLOCK_SIZE)
+    offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    offs = offs_m[:, None] * N + offs_n[None, :]
+    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
+    x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
+    s = tl.load(s_ptr + pid_m * n + pid_n)
+    y = x * s
+    tl.store(y_ptr + offs, y, mask=mask)
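To make the block-wise math explicit: every BLOCK_SIZE x BLOCK_SIZE tile of `x` is multiplied by a single scalar taken from `s`. A plain-PyTorch reference (our sketch, not part of the commit) that should reproduce the kernel's result:

```python
import torch


def weight_dequant_ref(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    """Expand each per-block scale over its block_size x block_size tile, then scale x elementwise."""
    m, n = x.shape
    s_full = s.repeat_interleave(block_size, dim=0)[:m]
    s_full = s_full.repeat_interleave(block_size, dim=1)[:, :n]
    return (x.float() * s_full.float()).to(torch.bfloat16)


x = torch.randn(256, 384).to(torch.float8_e4m3fn)
s = torch.rand(2, 3)
print(weight_dequant_ref(x, s).shape)  # torch.Size([256, 384])
# On a GPU this can be checked against the Triton kernel added above:
# assert torch.equal(weight_dequant_ref(x, s), weight_dequant(x.cuda(), s.cuda()).cpu())
```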
