Commit 92c274a

Authored and committed by niushengxiao
fix: replace single float with two floats for per tensor quant
1 parent 1600ad0 commit 92c274a
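
In short: the per-tensor (flashinfer) path used to calibrate one abs-max, and hence one FP8 scale, shared by the whole KV buffer of a layer; after this commit it keeps two floats per layer, one for the K half and one for the V half of the buffer. A minimal standalone sketch of the idea (illustrative names only; it assumes the scale is the abs-max divided by the float8_e4m3 maximum, which is what the qmin/qmax fields in the calibration config suggest):

import torch

def per_tensor_kv_scales(kv_buffer: torch.Tensor, head_num: int) -> torch.Tensor:
    # kv_buffer: [token_num, 2 * head_num, head_dim], K heads first, then V heads.
    # Old behaviour: one abs-max (one scale) over the whole buffer.
    # New behaviour: one abs-max per half, i.e. two floats [k_scale, v_scale].
    k_max = kv_buffer[:, :head_num, :].abs().amax().to(torch.float32)
    v_max = kv_buffer[:, head_num:, :].abs().amax().to(torch.float32)
    qmax = torch.finfo(torch.float8_e4m3fn).max
    return torch.stack([k_max, v_max]) / qmax  # shape [2]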

File tree

5 files changed (+594, -498 lines)


lightllm/common/basemodel/triton_kernel/destindex_copy_kv_fp8.py

Lines changed: 19 additions & 78 deletions
@@ -42,43 +42,6 @@ def _fwd_kernel_destindex_copy_kv_per_head_fp8(
     return


-@triton.jit
-def _fwd_kernel_destindex_copy_kv_per_tensor_fp8(
-    K,
-    Dest_loc,
-    Out,
-    scalar_scale,
-    stride_k_bs,
-    stride_k_h,
-    stride_k_d,
-    stride_o_bs,
-    stride_o_h,
-    stride_o_d,
-    head_num,
-    BLOCK_DMODEL: tl.constexpr,
-    BLOCK_HEAD: tl.constexpr,
-    FP8_MIN: tl.constexpr,
-    FP8_MAX: tl.constexpr,
-):
-    cur_index = tl.program_id(0)
-    offs_h = tl.arange(0, BLOCK_HEAD)
-    offs_d = tl.arange(0, BLOCK_DMODEL)
-
-    dest_index = tl.load(Dest_loc + cur_index).to(tl.int64)
-
-    k_ptrs = K + cur_index * stride_k_bs + stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :]
-    o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]
-
-    scale = tl.load(scalar_scale)
-
-    k = tl.load(k_ptrs, mask=offs_h[:, None] < head_num, other=0.0)
-    k_scale = k / scale
-    k_fp8 = tl.clamp(k_scale, min=FP8_MIN, max=FP8_MAX).to(tl.float8e4nv)
-
-    tl.store(o_ptrs, k_fp8, mask=offs_h[:, None] < head_num)
-    return
-
-
 @torch.no_grad()
 def destindex_copy_kv_fp8(K, DestLoc, scales, Out):
     if scales is None:
@@ -93,47 +56,25 @@ def destindex_copy_kv_fp8(K, DestLoc, scales, Out):
     grid = (seq_len,)
     num_warps = 1

-    if scales.dim() == 0:
-        _fwd_kernel_destindex_copy_kv_per_tensor_fp8[grid](
-            K,
-            DestLoc,
-            Out,
-            scales,
-            K.stride(0),
-            K.stride(1),
-            K.stride(2),
-            Out.stride(0),
-            Out.stride(1),
-            Out.stride(2),
-            head_num,
-            BLOCK_DMODEL=head_dim,
-            BLOCK_HEAD=BLOCK_HEAD,
-            FP8_MIN=torch.finfo(torch.float8_e4m3fn).min,
-            FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
-            num_warps=num_warps,
-            num_stages=1,
-        )
-    else:
-        _fwd_kernel_destindex_copy_kv_per_head_fp8[grid](
-            K,
-            DestLoc,
-            Out,
-            scales,
-            K.stride(0),
-            K.stride(1),
-            K.stride(2),
-            Out.stride(0),
-            Out.stride(1),
-            Out.stride(2),
-            head_num,
-            BLOCK_DMODEL=head_dim,
-            BLOCK_HEAD=BLOCK_HEAD,
-            FP8_MIN=torch.finfo(torch.float8_e4m3fn).min,
-            FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
-            num_warps=num_warps,
-            num_stages=1,
-        )
-    return
+    _fwd_kernel_destindex_copy_kv_per_head_fp8[grid](
+        K,
+        DestLoc,
+        Out,
+        scales,
+        K.stride(0),
+        K.stride(1),
+        K.stride(2),
+        Out.stride(0),
+        Out.stride(1),
+        Out.stride(2),
+        head_num,
+        BLOCK_DMODEL=head_dim,
+        BLOCK_HEAD=BLOCK_HEAD,
+        FP8_MIN=torch.finfo(torch.float8_e4m3fn).min,
+        FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
+        num_warps=num_warps,
+        num_stages=1,
+    )


 if __name__ == "__main__":
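
With the per-tensor kernel gone, destindex_copy_kv_fp8 always launches the per-head kernel; per-tensor scales reach it already expanded to one value per head (see the repeat_interleave added in mem_manager.py below). The element-wise math is unchanged: divide by the scale, clamp to the FP8 range, cast. A torch-level sketch of that math (a reference illustration, not the Triton kernel itself):

import torch

def fp8_quantize_reference(k: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Mirrors the kernel body: k / scale, clamp to the representable
    # float8_e4m3fn range, then cast.
    finfo = torch.finfo(torch.float8_e4m3fn)
    k_scaled = (k / scale).clamp(min=finfo.min, max=finfo.max)
    return k_scaled.to(torch.float8_e4m3fn)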

lightllm/common/mem_manager.py

Lines changed: 12 additions & 4 deletions
@@ -27,8 +27,9 @@ def __init__(self, layer_num, head_num):
         self.qmax = torch.finfo(torch.float8_e4m3fn).max
         self.model_arch = get_model_architectures(get_env_start_args().model_dir)
         self.layer_num = layer_num
+        self.head_num = head_num
         self.total_head_num = head_num * dist.get_world_size() if dist.is_initialized() else head_num
-        self.scales_shape = [layer_num, 2 * head_num] if get_env_start_args().enable_fa3 else [layer_num]
+        self.scales_shape = [layer_num, 2 * head_num] if get_env_start_args().enable_fa3 else [layer_num, 2]
         self.scales = None
         self.scales_list = []
         self.abs_max = None
@@ -62,9 +63,11 @@ def __init__(self, layer_num, head_num):
                     f"not match current model head num {self.total_head_num}"
                 )
             if get_env_start_args().enable_fa3:
-                assert len(cfg["scales_shape"]) == 2, "this config is not for fa3 backend"
+                if cfg["quant_type"] != "per_head":
+                    raise ValueError(f"quant type {cfg['quant_type']} in config not match fa3 backend")
             else:
-                assert len(cfg["scales_shape"]) == 1, "this config is not for flashinfer backend"
+                if cfg["quant_type"] != "per_tensor":
+                    raise ValueError(f"quant type {cfg['quant_type']} in config not match flashinfer backend")

             self.qmin = cfg["qmin"]
             self.qmax = cfg["qmax"]
@@ -73,6 +76,8 @@ def __init__(self, layer_num, head_num):
             full_scales_list = cfg["scales"]
             self.scales_list = full_scales_list
             self.scales = torch.tensor(self.scales_list, dtype=torch.float32, device="cuda").view(self.scales_shape)
+            if not get_env_start_args().enable_fa3:
+                self.scales = torch.repeat_interleave(self.scales, self.head_num, dim=-1)
             if get_env_start_args().enable_fa3 and dist.is_initialized() and dist.get_world_size() > 1:
                 half_head = self.total_head_num // 2
                 start_head = dist.get_rank() * head_num
@@ -103,7 +108,9 @@ def update_calibration_data(self, kv_buffer: torch.Tensor, layer_index: int):
         if get_env_start_args().enable_fa3:
             kv_max = kv_buffer.abs().amax(dim=(0, 2)).to(torch.float32)
         else:
-            kv_max = kv_buffer.abs().amax(dim=()).to(torch.float32)
+            k_max = kv_buffer[:, : self.head_num, :].abs().amax(dim=()).to(torch.float32)
+            v_max = kv_buffer[:, self.head_num :, :].abs().amax(dim=()).to(torch.float32)
+            kv_max = torch.tensor([k_max, v_max], device="cuda", dtype=torch.float32)
         self.abs_max[layer_index] = torch.maximum(self.abs_max[layer_index], kv_max)
         if self.count == self.warmup_counts + self.inference_counts - 1 and layer_index == self.layer_num - 1:
             final_abs_max = self.abs_max
@@ -136,6 +143,7 @@ def _export_calibration_data(self):
         cfg = {
             "version": "1.0",
             "architectures": self.model_arch,
+            "quant_type": "per_head" if get_env_start_args().enable_fa3 else "per_tensor",
             "qmin": self.qmin,
             "qmax": self.qmax,
             "num_layers": self.layer_num,

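The per-tensor scales stay as two floats per layer in the exported config, but the non-fa3 branch above expands them to one value per head before they reach the copy kernel. A toy-shaped sketch of that expansion (illustrative numbers):

import torch

layer_num, head_num = 2, 4
# [layer_num, 2]: one K scale and one V scale per layer.
scales = torch.tensor([[0.04, 0.05],
                       [0.03, 0.06]], dtype=torch.float32)
# Repeat each scale head_num times along the last dim -> [layer_num, 2 * head_num],
# the per-head layout the copy kernel expects.
per_head_scales = torch.repeat_interleave(scales, head_num, dim=-1)
assert per_head_scales.shape == (layer_num, 2 * head_num)
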
lightllm/models/llama/layer_infer/transformer_layer_infer.py

Lines changed: 4 additions & 4 deletions
@@ -225,8 +225,8 @@ def _context_attention_flashinfer_kernel_fp8(
         k = kv[:, :, : self.tp_k_head_num_, :].view(torch.float8_e4m3fn)
         v = kv[:, :, self.tp_k_head_num_ :, :].view(torch.float8_e4m3fn)
         offline_scales = infer_state.mem_manager.offline_fp8_quant_manager.scales_list
-        k_descale = offline_scales[self.layer_num_] if offline_scales is not None else None
-        v_descale = offline_scales[self.layer_num_] if offline_scales is not None else None
+        k_descale = offline_scales[self.layer_num_][0] if offline_scales is not None else None
+        v_descale = offline_scales[self.layer_num_][1] if offline_scales is not None else None
         infer_state.prefill_wrapper.run(
             q.view(q.shape[0], -1, self.head_dim_),
             (k, v),
@@ -517,8 +517,8 @@ def _token_decode_attention_flashinfer_fp8(self, q, infer_state: LlamaFlashInfer
         k = kv[:, :, : self.tp_k_head_num_, :].view(torch.float8_e4m3fn)
         v = kv[:, :, self.tp_k_head_num_ :, :].view(torch.float8_e4m3fn)
         offline_scales = infer_state.mem_manager.offline_fp8_quant_manager.scales_list
-        k_descale = offline_scales[self.layer_num_] if offline_scales is not None else None
-        v_descale = offline_scales[self.layer_num_] if offline_scales is not None else None
+        k_descale = offline_scales[self.layer_num_][0] if offline_scales is not None else None
+        v_descale = offline_scales[self.layer_num_][1] if offline_scales is not None else None
         infer_state.decode_wrapper.run(
             q.view(calcu_shape1),
             (k, v),
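
Since scales_list now holds [k_scale, v_scale] per layer on the per-tensor path, the flashinfer wrappers receive separate K and V descales. A minimal sketch of the lookup (illustrative values, two layers):

# Per-tensor path: scales_list[layer] is a two-element list [k_scale, v_scale]
# instead of a single float, so K and V each get their own descale.
scales_list = [[0.04, 0.05], [0.03, 0.06]]
layer_num_ = 1
k_descale = scales_list[layer_num_][0] if scales_list is not None else None
v_descale = scales_list[layer_num_][1] if scales_list is not None else None
assert (k_descale, v_descale) == (0.03, 0.06)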
