Commit 2f7b51f

niushengxiao committed
feat: kv fp8 quant calibration for fa3 and flashinfer
1 parent 58b7fd4 commit 2f7b51f

20 files changed: +1937 -26 lines
Lines changed: 157 additions & 0 deletions
@@ -0,0 +1,157 @@
import torch

import triton
import triton.language as tl


# Copy K[i] into Out[Dest_loc[i]], quantizing to fp8 e4m3 with one scale per head.
@triton.jit
def _fwd_kernel_destindex_copy_kv_per_head_fp8(
    K,
    Dest_loc,
    Out,
    scale,
    stride_k_bs,
    stride_k_h,
    stride_k_d,
    stride_o_bs,
    stride_o_h,
    stride_o_d,
    head_num,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_HEAD: tl.constexpr,
    FP8_MIN: tl.constexpr,
    FP8_MAX: tl.constexpr,
):
    cur_index = tl.program_id(0)
    offs_h = tl.arange(0, BLOCK_HEAD)
    offs_d = tl.arange(0, BLOCK_DMODEL)

    dest_index = tl.load(Dest_loc + cur_index).to(tl.int64)

    k_ptrs = K + cur_index * stride_k_bs + stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :]
    o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]

    # to fp8
    scale_ptrs = scale + offs_h
    scales = tl.load(scale_ptrs, mask=offs_h < head_num, other=1.0)
    k = tl.load(k_ptrs, mask=offs_h[:, None] < head_num, other=0.0)
    k_scale = k / scales[:, None]
    k_fp8 = tl.clamp(k_scale, min=FP8_MIN, max=FP8_MAX).to(tl.float8e4nv)

    tl.store(o_ptrs, k_fp8, mask=offs_h[:, None] < head_num)
    return


# Same copy, but all heads share a single scalar scale.
@triton.jit
def _fwd_kernel_destindex_copy_kv_per_tensor_fp8(
    K,
    Dest_loc,
    Out,
    scalar_scale,
    stride_k_bs,
    stride_k_h,
    stride_k_d,
    stride_o_bs,
    stride_o_h,
    stride_o_d,
    head_num,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_HEAD: tl.constexpr,
    FP8_MIN: tl.constexpr,
    FP8_MAX: tl.constexpr,
):
    cur_index = tl.program_id(0)
    offs_h = tl.arange(0, BLOCK_HEAD)
    offs_d = tl.arange(0, BLOCK_DMODEL)

    dest_index = tl.load(Dest_loc + cur_index).to(tl.int64)

    k_ptrs = K + cur_index * stride_k_bs + stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :]
    o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]

    scale = tl.load(scalar_scale)

    k = tl.load(k_ptrs, mask=offs_h[:, None] < head_num, other=0.0)
    k_scale = k / scale
    k_fp8 = tl.clamp(k_scale, min=FP8_MIN, max=FP8_MAX).to(tl.float8e4nv)

    tl.store(o_ptrs, k_fp8, mask=offs_h[:, None] < head_num)
    return


@torch.no_grad()
def destindex_copy_kv_fp8(K, DestLoc, scales, Out):
    # scales is None: plain cast; 0-dim tensor: per-tensor scale; 1-dim tensor: per-head scales.
    if scales is None:
        Out[DestLoc] = K.to(torch.float8_e4m3fn)
        return

    seq_len = DestLoc.shape[0]
    head_num = K.shape[1]
    head_dim = K.shape[2]
    assert K.shape[1] == Out.shape[1] and K.shape[2] == Out.shape[2]
    BLOCK_HEAD = triton.next_power_of_2(head_num)
    grid = (seq_len,)
    num_warps = 1

    if scales.dim() == 0:
        _fwd_kernel_destindex_copy_kv_per_tensor_fp8[grid](
            K,
            DestLoc,
            Out,
            scales,
            K.stride(0),
            K.stride(1),
            K.stride(2),
            Out.stride(0),
            Out.stride(1),
            Out.stride(2),
            head_num,
            BLOCK_DMODEL=head_dim,
            BLOCK_HEAD=BLOCK_HEAD,
            FP8_MIN=torch.finfo(torch.float8_e4m3fn).min,
            FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
            num_warps=num_warps,
            num_stages=1,
        )
    else:
        _fwd_kernel_destindex_copy_kv_per_head_fp8[grid](
            K,
            DestLoc,
            Out,
            scales,
            K.stride(0),
            K.stride(1),
            K.stride(2),
            Out.stride(0),
            Out.stride(1),
            Out.stride(2),
            head_num,
            BLOCK_DMODEL=head_dim,
            BLOCK_HEAD=BLOCK_HEAD,
            FP8_MIN=torch.finfo(torch.float8_e4m3fn).min,
            FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
            num_warps=num_warps,
            num_stages=1,
        )
    return


if __name__ == "__main__":
    import torch.nn.functional as F
    from lightllm.utils.vllm_utils import vllm_ops

    B, N_CTX, H, HEAD_DIM = 32, 1024, 16, 128
    dtype = torch.bfloat16
    NUM = B
    dest_loc = torch.arange(NUM).cuda() * 2
    kv = torch.randn((len(dest_loc), H, HEAD_DIM), dtype=dtype).cuda()
    out = torch.zeros((B * N_CTX, H, HEAD_DIM), dtype=torch.uint8).cuda()
    scale = kv.abs().amax(dim=(0, 2)).to(torch.float32) / 448
    destindex_copy_kv_fp8(kv, dest_loc, scale, out.view(torch.float8_e4m3fn))

    assert torch.allclose(
        out[:, :, :HEAD_DIM][dest_loc].view(torch.float8_e4m3fn).float() * scale.view(H, 1).expand(NUM, H, 1),
        kv.float(),
        atol=1e-5,
        rtol=1e-1,
    )
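
The self-test above only exercises the per-head path. Below is a minimal sketch of the per-tensor path: a 0-dim scale tensor dispatches to _fwd_kernel_destindex_copy_kv_per_tensor_fp8. It assumes destindex_copy_kv_fp8 is importable from this module; the sizes and variable names are illustrative, not part of the commit.

# Sketch: per-tensor quantized copy. A 0-dim scale selects the per-tensor kernel,
# so every head is divided by the same scalar before the fp8 cast.
# The 448 divisor mirrors the e4m3 max used in the test above.
import torch

H, HEAD_DIM, NUM, TOTAL_SLOTS = 16, 128, 32, 32 * 1024  # illustrative sizes
dest_loc = torch.arange(NUM).cuda() * 2
kv = torch.randn((NUM, H, HEAD_DIM), dtype=torch.bfloat16).cuda()
out = torch.zeros((TOTAL_SLOTS, H, HEAD_DIM), dtype=torch.uint8).cuda()

scalar_scale = (kv.abs().amax() / 448).to(torch.float32)  # 0-dim tensor
destindex_copy_kv_fp8(kv, dest_loc, scalar_scale, out.view(torch.float8_e4m3fn))

# Dequantize the written slots and compare with the original values.
deq = out.view(torch.float8_e4m3fn)[dest_loc].float() * scalar_scale
assert torch.allclose(deq, kv.float(), atol=1e-5, rtol=1e-1)
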
Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
import torch

import triton
import triton.language as tl


# Pass 1: one program per (batch, head) scans that batch's token range, records the
# absolute max, and stores scale = amax / FP8_MAX (or 1.0 if the range is all zeros).
# When SET_BATCH_IDS, it also writes each token's batch id into BatchIds.
@triton.jit
def _per_head_max_reduce_kernel(
    Q,
    Scales,
    BatchIds,
    StartLoc,
    stride_q_t,
    stride_q_h,
    stride_scales_b,
    SET_BATCH_IDS: tl.constexpr,
    FP8_MAX: tl.constexpr,
    BLOCK_T: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    b_id = tl.program_id(0)
    h_id = tl.program_id(1)

    max_val = 0.0

    start_loc = tl.load(StartLoc + b_id)
    end_loc = tl.load(StartLoc + b_id + 1)
    for t_offset in range(start_loc, end_loc, BLOCK_T):
        t_idx = t_offset + tl.arange(0, BLOCK_T)
        q_range = tl.arange(0, BLOCK_D)
        q_ptrs = Q + t_idx[:, None] * stride_q_t + h_id * stride_q_h + q_range[None, :]
        mask = (t_idx[:, None] < end_loc) & (q_range[None, :] < stride_q_h)
        q_vals = tl.load(q_ptrs, mask=mask, other=0.0)
        max_val = tl.maximum(tl.max(q_vals.abs()), max_val)
        if SET_BATCH_IDS:
            tl.store(BatchIds + t_idx, b_id, mask=t_idx < end_loc)

    scale = tl.where(max_val > 0, max_val / FP8_MAX, 1.0)
    scale_ptr = Scales + b_id * stride_scales_b + h_id
    tl.store(scale_ptr, scale)


# Pass 2: one program per (token, head) divides by its batch/head scale,
# clamps to the fp8 range, and casts to e4m3.
@triton.jit
def _apply_quantization_kernel(
    Q,
    Q_out,
    BatchIds,
    Scales,
    stride_q_t,
    stride_q_h,
    stride_qout_t,
    stride_qout_h,
    stride_scales_b,
    FP8_MIN: tl.constexpr,
    FP8_MAX: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    t_id = tl.program_id(0)
    h_id = tl.program_id(1)

    batch_id = tl.load(BatchIds + t_id)
    scale_ptr = Scales + batch_id * stride_scales_b + h_id
    scale = tl.load(scale_ptr)

    q_range = tl.arange(0, BLOCK_D)
    q_ptrs = Q + t_id * stride_q_t + h_id * stride_q_h + q_range
    qout_ptrs = Q_out + t_id * stride_qout_t + h_id * stride_qout_h + q_range
    mask = q_range < stride_q_h
    q_vals = tl.load(q_ptrs, mask=mask, other=0.0)
    q_scaled = q_vals / scale
    q_clamped = tl.clamp(q_scaled, min=FP8_MIN, max=FP8_MAX).to(tl.float8e4nv)
    tl.store(qout_ptrs, q_clamped, mask=q_range < stride_qout_h)


@torch.no_grad()
def q_per_head_fp8_quant(q, seq_lens, b1_start_loc):
    T, H, D = q.shape
    B = seq_lens.shape[0]
    device = q.device

    q_out = torch.empty_like(q, dtype=torch.float8_e4m3fn)
    scales = torch.empty((B, H), dtype=torch.float32, device=device)
    batch_ids = torch.zeros((T,), dtype=torch.int32, device=device)

    BLOCK_D = triton.next_power_of_2(D)
    BLOCK_T = 256
    num_warps = 4
    num_stages = 2
    _per_head_max_reduce_kernel[(B, H)](
        q,
        scales,
        batch_ids,
        b1_start_loc,
        q.stride(0),
        q.stride(1),
        scales.stride(0),
        FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
        SET_BATCH_IDS=B > 1,
        BLOCK_T=BLOCK_T,
        BLOCK_D=BLOCK_D,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    _apply_quantization_kernel[(T, H)](
        q,
        q_out,
        batch_ids,
        scales,
        q.stride(0),
        q.stride(1),
        q_out.stride(0),
        q_out.stride(1),
        scales.stride(0),
        FP8_MIN=torch.finfo(torch.float8_e4m3fn).min,
        FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
        BLOCK_D=BLOCK_D,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return q_out, scales


# Pure-PyTorch reference used to check the Triton kernels.
def ref_q_per_head_fp8_quant(q, seq_lens):
    min_fp8 = torch.finfo(torch.float8_e4m3fn).min
    max_fp8 = torch.finfo(torch.float8_e4m3fn).max
    B = seq_lens.size(0)
    device = q.device
    batch_ids = torch.repeat_interleave(torch.arange(B, device=device), seq_lens)
    max_per_time_head = q.abs().amax(dim=2)
    max_per_bh = torch.zeros((B, max_per_time_head.size(1)), device=device, dtype=max_per_time_head.dtype)
    max_per_bh.scatter_reduce_(
        0,
        batch_ids.unsqueeze(-1).expand(-1, max_per_time_head.size(1)),
        max_per_time_head,
        reduce="amax",
        include_self=False,
    )
    scales = torch.where(max_per_bh > 0, max_per_bh / max_fp8, torch.ones_like(max_per_bh)).to(torch.float32)
    scale_expanded = scales[batch_ids].view(-1, scales.size(1), 1)
    q_q = (q / scale_expanded).clamp(min_fp8, max_fp8).to(torch.float8_e4m3fn)
    return q_q, scales


if __name__ == "__main__":
    B, T, H, D = 200, 1000, 4, 7 * 128
    seq_lens = torch.ones((B,), dtype=torch.int32).cuda() * T // B
    start_locs = torch.zeros(B + 1, dtype=torch.int32).cuda()
    start_locs[1:] = seq_lens.cumsum(dim=0)
    q = torch.randn((T, H, D), dtype=torch.float32).cuda()

    q_out, scales = q_per_head_fp8_quant(q, seq_lens, start_locs)
    q_out1, scales1 = ref_q_per_head_fp8_quant(q, seq_lens)
    assert torch.allclose(scales, scales1, atol=1e-10, rtol=0)
    assert torch.allclose(q_out.int(), q_out1.int(), atol=1e-10, rtol=0)
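
The test above uses uniform sequence lengths. Below is a short sketch of calling q_per_head_fp8_quant with non-uniform lengths and dequantizing the result with the returned (B, H) scales; it assumes the function is importable from this module, and the sizes and names are illustrative.

# Sketch: variable-length batches. b1_start_loc must hold B + 1 cumulative token
# offsets so kernel program (b, h) can scan tokens start_locs[b]:start_locs[b + 1].
import torch

H, D = 4, 128
seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32).cuda()
T = int(seq_lens.sum())
start_locs = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32).cuda()
start_locs[1:] = seq_lens.cumsum(dim=0)
q = torch.randn((T, H, D), dtype=torch.float32).cuda()

q_out, scales = q_per_head_fp8_quant(q, seq_lens, start_locs)  # scales: (B, H) float32

# Dequantize: each token reuses the per-head scale of the batch it belongs to.
batch_ids = torch.repeat_interleave(torch.arange(seq_lens.numel(), device=q.device), seq_lens)
q_deq = q_out.float() * scales[batch_ids].unsqueeze(-1)  # approximately equal to q
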

lightllm/common/fp8kv_mem_manager.py

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
import torch

from .mem_manager import MemoryManager


class FP8KVMemoryManager(MemoryManager):
    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9):
        # Store the quantized kv as uint8 here for easy compatibility with various torch operators.
        # FP8 quantization currently uses an offline scheme; kv_buffer does not store the scales.
        super().__init__(size, torch.uint8, head_num, head_dim, layer_num, always_copy, mem_fraction)
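
A hypothetical instantiation sketch (argument values are illustrative and not taken from this commit): the dtype argument is kept for interface compatibility, but the parent allocator always receives torch.uint8, so the fp8 payloads written by the copy kernels can later be reinterpreted with .view(torch.float8_e4m3fn).

# Illustrative only; constructor values are placeholders.
import torch

mem_manager = FP8KVMemoryManager(
    size=16 * 1024,        # number of kv cache token slots
    dtype=torch.bfloat16,  # accepted but ignored: storage is forced to torch.uint8
    head_num=16,
    head_dim=128,
    layer_num=32,
)
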
