
Commit e64026e

jinzhen-lin authored and GWS0428 committed
[Kernel] add triton fused moe kernel for gptq/awq (vllm-project#12185)
1 parent e8eada8 commit e64026e

File tree

4 files changed: +874 −55 lines changed


tests/kernels/test_moe.py

Lines changed: 91 additions & 0 deletions
@@ -18,6 +18,8 @@
     fused_moe as iterative_moe)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
     marlin_quantize)
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    quantize_weights)
 from vllm.model_executor.models.mixtral import MixtralMoE
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
@@ -55,6 +57,95 @@ def test_fused_moe(
                                rtol=0)
 
 
+@pytest.mark.parametrize("m", [1, 32, 222])
+@pytest.mark.parametrize("n", [128, 1024, 2048])
+@pytest.mark.parametrize("k", [128, 1024])
+@pytest.mark.parametrize("e", NUM_EXPERTS)
+@pytest.mark.parametrize("topk", TOP_KS)
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("group_size", [64, 128])
+@pytest.mark.parametrize("has_zp", [True, False])
+@pytest.mark.parametrize("weight_bits", [4, 8])
+def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
+                        dtype: torch.dtype, group_size: int, has_zp: bool,
+                        weight_bits: int):
+    print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits)
+    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
+    w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
+    w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
+    score = torch.randn((m, e), device="cuda", dtype=dtype)
+
+    if weight_bits == 4:
+        pack_factor = 2
+        quant_type = scalar_types.uint4 if has_zp else scalar_types.uint4b8
+    elif weight_bits == 8:
+        pack_factor = 1
+        quant_type = scalar_types.uint8 if has_zp else scalar_types.uint8b128
+
+    w1_ref = w1.clone()
+    w2_ref = w2.clone()
+    w1_qweight = torch.empty((e, 2 * n, k // pack_factor),
+                             device="cuda",
+                             dtype=torch.uint8)
+    w2_qweight = torch.empty((e, k, n // pack_factor),
+                             device="cuda",
+                             dtype=torch.uint8)
+    w1_scales = torch.empty((e, 2 * n, k // group_size),
+                            device="cuda",
+                            dtype=dtype)
+    w2_scales = torch.empty((e, k, n // group_size),
+                            device="cuda",
+                            dtype=dtype)
+    w1_qzeros = torch.empty((e, 2 * n // pack_factor, k // group_size),
+                            device="cuda",
+                            dtype=torch.uint8)
+    w2_qzeros = torch.empty((e, k // pack_factor, n // group_size),
+                            device="cuda",
+                            dtype=torch.uint8)
+
+    for i in range(e * 2):
+        expert_id = i % e
+        if i // e == 0:
+            w, w_ref, w_qweight, w_scales, w_qzeros = \
+                w1, w1_ref, w1_qweight, w1_scales, w1_qzeros
+        else:
+            w, w_ref, w_qweight, w_scales, w_qzeros = \
+                w2, w2_ref, w2_qweight, w2_scales, w2_qzeros
+        weight, qweight, scales, qzeros = quantize_weights(
+            w[expert_id].T, quant_type, group_size, has_zp, False)
+        weight = weight.T
+        qweight = qweight.T.contiguous().to(torch.uint8)
+        scales = scales.T
+        if has_zp:
+            qzeros = qzeros.T.contiguous().to(torch.uint8)
+        if weight_bits == 4:
+            qweight = qweight[:, 1::2] * 16 + qweight[:, ::2]
+            if has_zp:
+                qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :]
+
+        w_ref[expert_id] = weight
+        w_qweight[expert_id] = qweight
+        w_scales[expert_id] = scales
+        if has_zp:
+            w_qzeros[expert_id] = qzeros
+
+    triton_output = fused_moe(a,
+                              w1_qweight,
+                              w2_qweight,
+                              score,
+                              topk,
+                              renormalize=False,
+                              use_int4_w4a16=weight_bits == 4,
+                              use_int8_w8a16=weight_bits == 8,
+                              w1_scale=w1_scales,
+                              w2_scale=w2_scales,
+                              w1_zp=w1_qzeros if has_zp else None,
+                              w2_zp=w2_qzeros if has_zp else None,
+                              block_shape=[0, group_size])
+    torch_output = torch_moe(a, w1_ref, w2_ref, score, topk)
+    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
+
+
 @pytest.mark.parametrize("dtype",
                          [torch.float32, torch.float16, torch.bfloat16])
 @torch.inference_mode()
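
For context on what the new test exercises, here is a minimal sketch (not part of the commit) of the w4a16 weight layout it builds: 4-bit weights are packed two per uint8 byte, low nibble first, and dequantization applies a per-group scale plus either an explicit zero point (has_zp=True) or an implicit bias for the uint4b8 type. The helper names pack_int4 and dequant_w4a16 are illustrative, not vLLM APIs, and the implicit bias of 8 is an assumption read off the scalar type name.

# Illustrative sketch only -- helper names and the implicit bias of 8 for
# uint4b8 are assumptions, not part of vLLM.
from typing import Optional

import torch


def pack_int4(q: torch.Tensor) -> torch.Tensor:
    """Pack uint8 values in [0, 15] of shape [N, K] into [N, K // 2],
    low nibble first -- the same interleaving as
    qweight[:, 1::2] * 16 + qweight[:, ::2] in the test above."""
    return (q[:, 1::2] * 16 + q[:, ::2]).to(torch.uint8)


def dequant_w4a16(packed: torch.Tensor, scales: torch.Tensor,
                  zeros: Optional[torch.Tensor],
                  group_size: int) -> torch.Tensor:
    """Recover floating-point weights from packed 4-bit weights with
    group-wise scales and optional (unpacked) zero points.
    packed: [N, K // 2] uint8, scales: [N, K // group_size],
    zeros: [N, K // group_size] uint8 or None."""
    # Undo the low-nibble-first packing.
    q = torch.empty(packed.shape[0], packed.shape[1] * 2,
                    dtype=torch.uint8, device=packed.device)
    q[:, ::2] = packed & 0x0F
    q[:, 1::2] = packed >> 4
    q = q.to(scales.dtype)
    # Broadcast each group's scale / zero point across its group_size columns.
    s = scales.repeat_interleave(group_size, dim=1)
    if zeros is not None:
        z = zeros.to(scales.dtype).repeat_interleave(group_size, dim=1)
    else:
        z = 8  # assumed implicit zero point of the uint4b8 layout
    return (q - z) * s

Roughly speaking, dequant_w4a16(pack_int4(q), scales, qzeros_or_None, group_size) should reproduce the reference weight the test stores in w_ref (up to the transposes the test applies around quantize_weights).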

0 commit comments
