Commit e540520

committed 0609

1 parent 6a4a100 commit e540520

12 files changed: +170 -82 lines changed

lightllm-kernel/lightllm_kernel/ops/fusion.py

Lines changed: 14 additions & 6 deletions
@@ -2,20 +2,28 @@
 from typing import Optional, Tuple
 from . import _C
 
+
 def pre_tp_norm_bf16(input: torch.Tensor) -> torch.Tensor:
-    """ Calculate powersum along embedding dimension of the input """
+    """Calculate powersum along embedding dimension of the input"""
     return _C.pre_tp_norm_bf16(input)
 
-def post_tp_norm_bf16(input: torch.tensor, weight: torch.Tensor, tp_variance: torch.Tensor, embed_dim: int, eps: float) -> torch.Tensor:
-    """ Apply rmsnorm on given input, with weight and pre calculated powersum """
+
+def post_tp_norm_bf16(
+    input: torch.tensor, weight: torch.Tensor, tp_variance: torch.Tensor, embed_dim: int, eps: float
+) -> torch.Tensor:
+    """Apply rmsnorm on given input, with weight and pre calculated powersum"""
     return _C.post_tp_norm_bf16(input, weight, tp_variance, embed_dim, eps)
 
-def add_norm_quant_bf16_fp8(input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float) -> Tuple[torch.Tensor, torch.Tensor]:
-    """ Apply add_norm_quant on given input, with residual and weight """
+
+def add_norm_quant_bf16_fp8(
+    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Apply add_norm_quant on given input, with residual and weight"""
     return _C.add_norm_quant_bf16_fp8(input, residual, weight, eps)
 
+
 def gelu_per_token_quant_bf16_fp8(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-    """ Apply gelu on given input and quantize it from bf16 to fp8 using per token quant method """
+    """Apply gelu on given input and quantize it from bf16 to fp8 using per token quant method"""
     output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
     scales = torch.empty(size=(input.shape[0], 1), device=input.device, dtype=torch.float32)
     _C.gelu_per_token_quant_bf16_fp8(output, input, scales)
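Note: the pre/post split above suggests a tensor-parallel RMSNorm, with the per-token power sum reduced across ranks between the two kernels. A minimal composition sketch (the all_reduce step and shard layout are assumptions, not a documented contract of these ops):

    import torch
    import torch.distributed as dist
    from lightllm_kernel.ops import pre_tp_norm_bf16, post_tp_norm_bf16

    def tp_rmsnorm(x_shard: torch.Tensor, w_shard: torch.Tensor, embed_dim: int, eps: float = 1e-6) -> torch.Tensor:
        # Local sum of squares per token over this rank's shard of the embedding dim.
        power_sum = pre_tp_norm_bf16(x_shard)
        # Combine into the global power sum across tensor-parallel ranks (assumed step).
        dist.all_reduce(power_sum)
        # Normalize against the full embed_dim and apply this rank's weight shard.
        return post_tp_norm_bf16(x_shard, w_shard, power_sum, embed_dim, eps)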

lightllm-kernel/lightllm_kernel/ops/gemm.py

Lines changed: 11 additions & 3 deletions
@@ -2,7 +2,15 @@
 from typing import Optional
 from . import _C
 
-def cutlass_scaled_mm_bias_ls(c: torch.Tensor, a: torch.Tensor, b: torch.Tensor,
-        a_scales: torch.Tensor, b_scales: torch.Tensor, bias: Optional[torch.Tensor], ls: Optional[torch.Tensor]) -> None :
-    """ Apply scaled mm on the given input, with optional bias and ls weight """
+
+def cutlass_scaled_mm_bias_ls(
+    c: torch.Tensor,
+    a: torch.Tensor,
+    b: torch.Tensor,
+    a_scales: torch.Tensor,
+    b_scales: torch.Tensor,
+    bias: Optional[torch.Tensor],
+    ls: Optional[torch.Tensor],
+) -> None:
+    """Apply scaled mm on the given input, with optional bias and ls weight"""
     return _C.cutlass_scaled_mm(c, a, b, a_scales, b_scales, bias, ls)
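Note: a call sketch mirroring test/gemm/cutlass_scaled_mm_test.py. The output c is written in place; a is (M, K) fp8 with per-token scales, b is fp8 and column-major (a transposed (N, K) tensor), and, per the torch reference in that test, ls multiplies the biased result. Shapes and the ones-scales are illustrative:

    import torch
    from lightllm_kernel.ops import cutlass_scaled_mm_bias_ls

    M, N, K = 16, 4096, 1024
    a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn)
    b = torch.randn(N, K, device="cuda", dtype=torch.bfloat16).to(torch.float8_e4m3fn).t()
    a_scales = torch.ones(M, 1, device="cuda", dtype=torch.float32)  # per-token
    b_scales = torch.ones(N, 1, device="cuda", dtype=torch.float32)  # per-channel
    bias = torch.randn(N, device="cuda", dtype=torch.bfloat16)
    ls = torch.randn(N, device="cuda", dtype=torch.bfloat16)
    c = torch.empty(M, N, device="cuda", dtype=torch.bfloat16)
    cutlass_scaled_mm_bias_ls(c, a, b, a_scales, b_scales, bias, ls)
    # c ~ ((a @ b) rescaled by a_scales/b_scales + bias) * ls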

lightllm-kernel/lightllm_kernel/ops/norm.py

Lines changed: 2 additions & 2 deletions
@@ -2,6 +2,6 @@
 from typing import Optional
 from . import _C
 
-def rmsnorm_bf16(X: torch.Tensor, W: torch.Tensor, eps: float=1e-12) -> torch.Tensor:
-    """ Apply rmsnorm on given X, with weight W and eps """
+
+def rmsnorm_bf16(X: torch.Tensor, W: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    return _C.rmsnorm_align16_bf16(X, W, eps)
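Note: judging by the torch reference used in this commit's tests, the expected semantics match torch's built-in RMSNorm (an assumed equivalence, up to kernel rounding):

    import torch
    from lightllm_kernel.ops import rmsnorm_bf16

    X = torch.randn(8, 1024, device="cuda", dtype=torch.bfloat16)
    W = torch.randn(1024, device="cuda", dtype=torch.bfloat16)
    y_kernel = rmsnorm_bf16(X, W)  # eps defaults to 1e-12
    y_ref = torch.nn.functional.rms_norm(X, (X.shape[-1],), W, eps=1e-12)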

lightllm-kernel/lightllm_kernel/ops/quant.py

Lines changed: 2 additions & 1 deletion
@@ -2,8 +2,9 @@
 from typing import Optional, Tuple
 from . import _C
 
+
 def per_token_quant_bf16_fp8(input: torch.tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-    """ Quantize the given input using per token quant method """
+    """Quantize the given input using per token quant method"""
     output = torch.empty_like(input, dtype=torch.float8_e4m3fn)
     scales = torch.empty(size=(input.shape[0], 1), device=input.device, dtype=torch.float32)
     _C.per_token_quant_bf16_fp8(output, input, scales)
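Note: a plain-torch sketch of what per-token quantization computes (an assumed reference for the kernel's semantics, not its exact rounding): each row gets one scale so that its largest magnitude maps to the fp8 e4m3 maximum (448):

    import torch

    def per_token_quant_ref(x: torch.Tensor):
        # One scale per token (row), derived from the row's absolute maximum.
        amax = x.abs().amax(dim=-1, keepdim=True).to(torch.float32)
        scales = (amax / torch.finfo(torch.float8_e4m3fn).max).clamp(min=1e-12)
        q = (x.to(torch.float32) / scales).to(torch.float8_e4m3fn)
        return q, scales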

lightllm-kernel/test/fusion/add_norm_quant_test.py

Lines changed: 54 additions & 28 deletions
@@ -10,12 +10,13 @@ def torch_add_norm_quant_bf16_fp8(X, R, W, eps=1e-6):
     # 1. Add residual
     X = X.add_(R)
     # 2. rmsnorm
-    normalized = torch.nn.functional.rms_norm(X, (N, ), W, eps=eps)
+    normalized = torch.nn.functional.rms_norm(X, (N,), W, eps=eps)
     # 3. per token quant
     quantized, scales = ops.scaled_fp8_quant(normalized, scale=None, use_per_token_if_dynamic=True)
 
     return quantized, scales
 
+
 class TestFusedAddNormQuantBF16(unittest.TestCase):
     def setUp(self):
         """Set up common test parameters."""
@@ -31,40 +32,65 @@ def test_accuracy(self):
         for batch in self.batchs:
             for seqLen in self.seqLens:
                 for embed_dim in self.embed_dims:
-                    with self.subTest(shape=[batch, seqLen, embed_dim]):
-                        X1 = torch.rand(size=[batch, seqLen, embed_dim], device=self.device, dtype=self.dtype) - 0.5
-                        X2 = X1.clone()
-                        R1 = torch.rand(size=[batch, seqLen, embed_dim], device=self.device, dtype=self.dtype) - 0.5
-                        R2 = R1.clone()
-                        W = torch.rand(size=[embed_dim], device=self.device, dtype=self.dtype) - 0.5
-                        output_real, scales_real = torch_add_norm_quant_bf16_fp8(X1.reshape(-1, X1.shape[2]), R1.reshape(-1, R1.shape[2]), W, self.eps)
-                        output_pred, scales_pred = add_norm_quant_bf16_fp8(X2.reshape(-1, X1.shape[2]), R2.reshape(-1, R2.shape[2]), W, self.eps)
+                    with self.subTest(shape=[batch, seqLen, embed_dim]):
+                        X1 = torch.rand(size=[batch, seqLen, embed_dim], device=self.device, dtype=self.dtype) - 0.5
+                        X2 = X1.clone()
+                        R1 = torch.rand(size=[batch, seqLen, embed_dim], device=self.device, dtype=self.dtype) - 0.5
+                        R2 = R1.clone()
+                        W = torch.rand(size=[embed_dim], device=self.device, dtype=self.dtype) - 0.5
+                        output_real, scales_real = torch_add_norm_quant_bf16_fp8(
+                            X1.reshape(-1, X1.shape[2]), R1.reshape(-1, R1.shape[2]), W, self.eps
+                        )
+                        output_pred, scales_pred = add_norm_quant_bf16_fp8(
+                            X2.reshape(-1, X1.shape[2]), R2.reshape(-1, R2.shape[2]), W, self.eps
+                        )
 
-                        self.assertTrue(
-                            error(output_real, output_pred) < 0.01,
-                            f"Accuracy test failed for size {batch}, {seqLen}, {embed_dim}. output_real={output_real}, output_pred={output_pred}"
-                        )
-                        self.assertTrue(
-                            error(scales_real, scales_pred) < 0.01,
-                            f"Accuracy test failed for size {batch}, {seqLen}, {embed_dim}. scales_real={scales_real}, scales_pred={scales_pred}"
-                        )
+                        self.assertTrue(
+                            error(output_real, output_pred) < 0.01,
+                            f"Accuracy test failed for size {batch}, {seqLen}, {embed_dim}. "
+                            f"output_real={output_real}, output_pred={output_pred}",
+                        )
+                        self.assertTrue(
+                            error(scales_real, scales_pred) < 0.01,
+                            f"Accuracy test failed for size {batch}, {seqLen}, {embed_dim}. "
+                            f"scales_real={scales_real}, scales_pred={scales_pred}",
+                        )
 
     def test_performance(self):
         """Test the performance of FusedAddNormQuant using benchmark."""
         for batch in self.batchs:
             for seqLen in self.seqLens:
                 for embed_dim in self.embed_dims:
-                    with self.subTest(shape=[batch, seqLen, embed_dim]):
-                        X1 = torch.rand(size=[batch, seqLen, embed_dim], device=self.device, dtype=self.dtype) - 0.5
-                        X2 = torch.rand(size=[batch, seqLen, embed_dim], device=self.device, dtype=self.dtype) - 0.5
-                        R1 = torch.rand(size=[batch, seqLen, embed_dim], device=self.device, dtype=self.dtype) - 0.5
-                        R2 = R1.clone()
-                        W = torch.rand(size=[embed_dim], device=self.device, dtype=self.dtype) - 0.5
+                    with self.subTest(shape=[batch, seqLen, embed_dim]):
+                        X1 = torch.rand(size=[batch, seqLen, embed_dim], device=self.device, dtype=self.dtype) - 0.5
+                        X2 = torch.rand(size=[batch, seqLen, embed_dim], device=self.device, dtype=self.dtype) - 0.5
+                        R1 = torch.rand(size=[batch, seqLen, embed_dim], device=self.device, dtype=self.dtype) - 0.5
+                        R2 = R1.clone()
+                        W = torch.rand(size=[embed_dim], device=self.device, dtype=self.dtype) - 0.5
+
+                        shape = [[batch, seqLen, embed_dim]]
+                        tflops = 0.0
+                        benchmark(
+                            torch_add_norm_quant_bf16_fp8,
+                            shape,
+                            tflops,
+                            100,
+                            X1.reshape(-1, X1.shape[2]),
+                            R1.reshape(-1, R1.shape[2]),
+                            W,
+                            self.eps,
+                        )
+                        benchmark(
+                            add_norm_quant_bf16_fp8,
+                            shape,
+                            tflops,
+                            100,
+                            X2.reshape(-1, X1.shape[2]),
+                            R2.reshape(-1, R2.shape[2]),
+                            W,
+                            self.eps,
+                        )
 
-                        shape = [[batch, seqLen, embed_dim]]
-                        tflops = 0.0
-                        benchmark(torch_add_norm_quant_bf16_fp8, shape, tflops, 100, X1.reshape(-1, X1.shape[2]), R1.reshape(-1, R1.shape[2]), W, self.eps)
-                        benchmark(add_norm_quant_bf16_fp8, shape, tflops, 100, X2.reshape(-1, X1.shape[2]), R2.reshape(-1, R2.shape[2]), W, self.eps)
 
 if __name__ == "__main__":
-    unittest.main()
\ No newline at end of file
+    unittest.main()
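Note: the accuracy gates above use error(real, pred) < 0.01 from test.utils. A typical definition for such a metric is the relative L2 distance (an assumption; the repo's helper may differ):

    import torch

    def error(a: torch.Tensor, b: torch.Tensor) -> float:
        # Cast fp8/bf16 inputs to fp32 before taking norms.
        a, b = a.to(torch.float32), b.to(torch.float32)
        return (torch.norm(a - b) / torch.norm(a)).item()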

lightllm-kernel/test/fusion/gelu_per_token_quant_test.py

Lines changed: 12 additions & 6 deletions
@@ -4,10 +4,12 @@
 from lightllm_kernel.ops import per_token_quant_bf16_fp8, gelu_per_token_quant_bf16_fp8
 from test.utils import benchmark, error
 
+
 def gelu_quant(x):
     y = gelu_fwd(x)
     return per_token_quant_bf16_fp8(y)
 
+
 class TestGeluQuantBF16(unittest.TestCase):
     def setUp(self):
         """Set up common test parameters."""
@@ -21,20 +23,23 @@ def test_accuracy(self):
         for token in self.tokens:
             for hiddenDim in self.hiddenDims:
                 with self.subTest(shape=[token, hiddenDim]):
-                    input = torch.normal(mean=0.0, std=10, size=[token, hiddenDim], device=self.device, dtype=self.dtype)
+                    input = torch.normal(
+                        mean=0.0, std=10, size=[token, hiddenDim], device=self.device, dtype=self.dtype
+                    )
 
                     y_real, scales_real = gelu_quant(input)
                     y_pred, scales_pred = gelu_per_token_quant_bf16_fp8(input)
-
+
                     self.assertTrue(
                         error(scales_real, scales_pred) < 0.01,
-                        f"Accuracy test failed for size {token}, {hiddenDim}. scales_real={scales_real}, scales_pred={scales_pred}"
+                        f"Accuracy test failed for size {token}, {hiddenDim}. "
+                        f"scales_real={scales_real}, scales_pred={scales_pred}",
                     )
                     self.assertTrue(
                         error(y_real, y_pred) < 0.01,
-                        f"Accuracy test failed for size {token}, {hiddenDim}. y_real={y_real}, y_pred={y_pred}"
+                        f"Accuracy test failed for size {token}, {hiddenDim}. " f"y_real={y_real}, y_pred={y_pred}",
                     )
-
+
     def test_performance(self):
         """Test the performance of gelu_per_token_quant using benchmark."""
         for token in self.tokens:
@@ -46,5 +51,6 @@ def test_performance(self):
                     benchmark(gelu_per_token_quant_bf16_fp8, shape, tflops, 100, input)
                     benchmark(gelu_quant, shape, tflops, 100, input)
 
+
 if __name__ == "__main__":
-    unittest.main()
\ No newline at end of file
+    unittest.main()

lightllm-kernel/test/fusion/post_tp_norm_test.py

Lines changed: 4 additions & 2 deletions
@@ -12,6 +12,7 @@ def post_tp_norm(input, weight, tp_variance, embed_dim, eps):
     out = weight * input.to(torch.bfloat16)
     return out
 
+
 class TestPostTpNormBF16(unittest.TestCase):
     def setUp(self):
         """Set up common test parameters."""
@@ -34,7 +35,7 @@ def test_accuracy(self):
                     y_pred = post_tp_norm_bf16(X, W, V, self.embed_dim, self.eps)
                     self.assertTrue(
                         error(y_pred, y_real) < 0.01,
-                        f"Accuracy test failed for size {batch}, {size}. y_real={y_real}, y_pred={y_pred}"
+                        f"Accuracy test failed for size {batch}, {size}. y_real={y_real}, y_pred={y_pred}",
                     )
 
     def test_performance(self):
@@ -50,5 +51,6 @@ def test_performance(self):
                     benchmark(post_tp_norm_bf16, shape, tflops, 100, X, W, V, self.embed_dim, self.eps)
                     benchmark(post_tp_norm, shape, tflops, 100, X, W, V, self.embed_dim, self.eps)
 
+
 if __name__ == "__main__":
-    unittest.main()
\ No newline at end of file
+    unittest.main()

lightllm-kernel/test/fusion/pre_tp_norm_test.py

Lines changed: 5 additions & 3 deletions
@@ -9,6 +9,7 @@ def pre_tp_norm(input):
     tp_variance = input.pow(2).sum(-1, keepdim=False)
     return tp_variance
 
+
 class TestPreTpNormBF16(unittest.TestCase):
     def setUp(self):
         """Set up common test parameters."""
@@ -27,20 +28,21 @@ def test_accuracy(self):
                     y_pred = pre_tp_norm_bf16(X)
                     self.assertTrue(
                         error(y_pred, y_real) < 0.01,
-                        f"Accuracy test failed for size {batch}, {size}. y_real={y_real}, y_pred={y_pred}"
+                        f"Accuracy test failed for size {batch}, {size}. y_real={y_real}, y_pred={y_pred}",
                     )
 
     def test_performance(self):
         for batch in self.batchs:
             for size in self.sizes:
                 with self.subTest(shape=[batch, size]):
                     X = torch.rand(size=[batch, size], device=self.device, dtype=self.dtype) - 0.5
-                    W = torch.rand(size=[size], device=self.device, dtype=self.dtype) - 0.5
+                    # W = torch.rand(size=[size], device=self.device, dtype=self.dtype) - 0.5
 
                     shape = [[batch, size], [size], [batch, size]]
                     tflops = 0.0
                     benchmark(pre_tp_norm_bf16, shape, tflops, 100, X)
                     benchmark(pre_tp_norm, shape, tflops, 100, X)
 
+
 if __name__ == "__main__":
-    unittest.main()
\ No newline at end of file
+    unittest.main()

lightllm-kernel/test/gemm/cutlass_scaled_mm_test.py

Lines changed: 40 additions & 10 deletions
@@ -10,6 +10,7 @@ def torch_cutlass_scale_gemm_with_ls(x_q, w_q_t, x_scale, w_scale, out_dtype=tor
     y_pred = y_pred_tmp * ls
     return y_pred
 
+
 class TestQuantBF16(unittest.TestCase):
     def setUp(self):
         """Set up common test parameters."""
@@ -18,7 +19,6 @@ def setUp(self):
         self.device = "cuda"
         self.dtype = torch.bfloat16
 
-
     def test_accuracy(self):
         """Test the accuracy of cutlass_scaled_mm_bias_ls"""
         for token in self.tokens:
@@ -29,10 +29,11 @@ def test_accuracy(self):
                     input = torch.randn(size=[M, K], device=self.device, dtype=self.dtype)
                     x_q, x_scale = ops.scaled_fp8_quant(input, scale=None, scale_ub=None, use_per_token_if_dynamic=True)
 
-
                     # Generate the weight tensor w_q (N x K); transposed it becomes K x N (column-major)
                     weight = torch.randn(size=[N, K], device=self.device, dtype=self.dtype)
-                    w_q, w_scale = ops.scaled_fp8_quant(weight, scale=None, scale_ub=None, use_per_token_if_dynamic=True)
+                    w_q, w_scale = ops.scaled_fp8_quant(
+                        weight, scale=None, scale_ub=None, use_per_token_if_dynamic=True
+                    )
 
                     # Transpose; w_q_t is column-major
                     w_q_t = w_q.t()
@@ -43,11 +44,13 @@ def test_accuracy(self):
                     ls = torch.randn(size=[N], device=self.device, dtype=torch.bfloat16)
 
                     cutlass_scaled_mm_bias_ls(y_pred, x_q, w_q_t, x_scale, w_scale, bias=bias, ls=ls)
-                    y_real = torch_cutlass_scale_gemm_with_ls(x_q, w_q_t, x_scale, w_scale, out_dtype=torch.bfloat16, bias=bias, ls=ls)
+                    y_real = torch_cutlass_scale_gemm_with_ls(
+                        x_q, w_q_t, x_scale, w_scale, out_dtype=torch.bfloat16, bias=bias, ls=ls
+                    )
 
                     self.assertTrue(
                         error(y_pred, y_real) < 0.01,
-                        f"Accuracy test failed for size {token}, {hiddenDim}. y_pred={y_pred}, y_real={y_real}"
+                        f"Accuracy test failed for size {token}, {hiddenDim}. y_pred={y_pred}, y_real={y_real}",
                     )
 
     def test_performance(self):
@@ -62,7 +65,9 @@ def test_performance(self):
 
                     # Generate the weight tensor w_q (N x K); transposed it becomes K x N (column-major)
                     weight = torch.randn(size=[N, K], device=self.device, dtype=self.dtype) - 0.5
-                    w_q, w_scale = ops.scaled_fp8_quant(weight, scale=None, scale_ub=None, use_per_token_if_dynamic=True)
+                    w_q, w_scale = ops.scaled_fp8_quant(
+                        weight, scale=None, scale_ub=None, use_per_token_if_dynamic=True
+                    )
 
                     bias = torch.randn(size=[N], device=self.device, dtype=torch.bfloat16)
                     ls = torch.randn(size=[N], device=self.device, dtype=torch.bfloat16)
@@ -72,9 +77,34 @@
 
                     y_pred = torch.empty((M, N), dtype=input.dtype, device=input.device)
                     shape = [[token, hiddenDim]]
-                    tflops = 2 * token * (3 * hiddenDim) * hiddenDim / 1024**4
-                    benchmark(cutlass_scaled_mm_bias_ls, shape, tflops, 100, y_pred, x_q, w_q_t, x_scale, w_scale, bias=bias, ls=ls)
-                    benchmark(torch_cutlass_scale_gemm_with_ls, shape, tflops, 100, x_q, w_q_t, x_scale, w_scale, out_dtype=torch.bfloat16, bias=bias, ls=ls)  # 495 GB/s without bias, 482 GB/s with bias
+                    tflops = 2 * token * (3 * hiddenDim) * hiddenDim / 1024 ** 4
+                    benchmark(
+                        cutlass_scaled_mm_bias_ls,
+                        shape,
+                        tflops,
+                        100,
+                        y_pred,
+                        x_q,
+                        w_q_t,
+                        x_scale,
+                        w_scale,
+                        bias=bias,
+                        ls=ls,
+                    )
+                    benchmark(
+                        torch_cutlass_scale_gemm_with_ls,
+                        shape,
+                        tflops,
+                        100,
+                        x_q,
+                        w_q_t,
+                        x_scale,
+                        w_scale,
+                        out_dtype=torch.bfloat16,
+                        bias=bias,
+                        ls=ls,
+                    )  # 495 GB/s without bias, 482 GB/s with bias
+
 
 if __name__ == "__main__":
-    unittest.main()
\ No newline at end of file
+    unittest.main()
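Note: the benchmark's work estimate, tflops = 2 * token * (3 * hiddenDim) * hiddenDim / 1024 ** 4, reads as the 2*M*N*K FLOP count of a GEMM with M = token, N = 3 * hiddenDim, and K = hiddenDim (the factor of 3 suggesting a fused QKV-style projection), expressed in binary tera-FLOPs.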
