|
18 | 18 | fused_moe as iterative_moe)
|
19 | 19 | from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
20 | 20 | marlin_quantize)
|
| 21 | +from vllm.model_executor.layers.quantization.utils.quant_utils import ( |
| 22 | + quantize_weights) |
21 | 23 | from vllm.model_executor.models.mixtral import MixtralMoE
|
22 | 24 | from vllm.platforms import current_platform
|
23 | 25 | from vllm.scalar_type import scalar_types
|
@@ -55,6 +57,95 @@ def test_fused_moe(
|
55 | 57 | rtol=0)
|
56 | 58 |
|
57 | 59 |
|
| 60 | +@pytest.mark.parametrize("m", [1, 32, 222]) |
| 61 | +@pytest.mark.parametrize("n", [128, 1024, 2048]) |
| 62 | +@pytest.mark.parametrize("k", [128, 1024]) |
| 63 | +@pytest.mark.parametrize("e", NUM_EXPERTS) |
| 64 | +@pytest.mark.parametrize("topk", TOP_KS) |
| 65 | +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) |
| 66 | +@pytest.mark.parametrize("group_size", [64, 128]) |
| 67 | +@pytest.mark.parametrize("has_zp", [True, False]) |
| 68 | +@pytest.mark.parametrize("weight_bits", [4, 8]) |
| 69 | +def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, |
| 70 | + dtype: torch.dtype, group_size: int, has_zp: bool, |
| 71 | + weight_bits: int): |
| 72 | + print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits) |
| 73 | + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 |
| 74 | + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 |
| 75 | + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 |
| 76 | + score = torch.randn((m, e), device="cuda", dtype=dtype) |
| 77 | + |
| 78 | + if weight_bits == 4: |
| 79 | + pack_factor = 2 |
| 80 | + quant_type = scalar_types.uint4 if has_zp else scalar_types.uint4b8 |
| 81 | + elif weight_bits == 8: |
| 82 | + pack_factor = 1 |
| 83 | + quant_type = scalar_types.uint8 if has_zp else scalar_types.uint8b128 |
| 84 | + |
| 85 | + w1_ref = w1.clone() |
| 86 | + w2_ref = w2.clone() |
| 87 | + w1_qweight = torch.empty((e, 2 * n, k // pack_factor), |
| 88 | + device="cuda", |
| 89 | + dtype=torch.uint8) |
| 90 | + w2_qweight = torch.empty((e, k, n // pack_factor), |
| 91 | + device="cuda", |
| 92 | + dtype=torch.uint8) |
| 93 | + w1_scales = torch.empty((e, 2 * n, k // group_size), |
| 94 | + device="cuda", |
| 95 | + dtype=dtype) |
| 96 | + w2_scales = torch.empty((e, k, n // group_size), |
| 97 | + device="cuda", |
| 98 | + dtype=dtype) |
| 99 | + w1_qzeros = torch.empty((e, 2 * n // pack_factor, k // group_size), |
| 100 | + device="cuda", |
| 101 | + dtype=torch.uint8) |
| 102 | + w2_qzeros = torch.empty((e, k // pack_factor, n // group_size), |
| 103 | + device="cuda", |
| 104 | + dtype=torch.uint8) |
| 105 | + |
| 106 | + for i in range(e * 2): |
| 107 | + expert_id = i % e |
| 108 | + if i // e == 0: |
| 109 | + w, w_ref, w_qweight, w_scales, w_qzeros = \ |
| 110 | + w1, w1_ref, w1_qweight, w1_scales, w1_qzeros |
| 111 | + else: |
| 112 | + w, w_ref, w_qweight, w_scales, w_qzeros = \ |
| 113 | + w2, w2_ref, w2_qweight, w2_scales, w2_qzeros |
| 114 | + weight, qweight, scales, qzeros = quantize_weights( |
| 115 | + w[expert_id].T, quant_type, group_size, has_zp, False) |
| 116 | + weight = weight.T |
| 117 | + qweight = qweight.T.contiguous().to(torch.uint8) |
| 118 | + scales = scales.T |
| 119 | + if has_zp: |
| 120 | + qzeros = qzeros.T.contiguous().to(torch.uint8) |
| 121 | + if weight_bits == 4: |
| 122 | + qweight = qweight[:, 1::2] * 16 + qweight[:, ::2] |
| 123 | + if has_zp: |
| 124 | + qzeros = qzeros[1::2, :] * 16 + qzeros[::2, :] |
| 125 | + |
| 126 | + w_ref[expert_id] = weight |
| 127 | + w_qweight[expert_id] = qweight |
| 128 | + w_scales[expert_id] = scales |
| 129 | + if has_zp: |
| 130 | + w_qzeros[expert_id] = qzeros |
| 131 | + |
| 132 | + triton_output = fused_moe(a, |
| 133 | + w1_qweight, |
| 134 | + w2_qweight, |
| 135 | + score, |
| 136 | + topk, |
| 137 | + renormalize=False, |
| 138 | + use_int4_w4a16=weight_bits == 4, |
| 139 | + use_int8_w8a16=weight_bits == 8, |
| 140 | + w1_scale=w1_scales, |
| 141 | + w2_scale=w2_scales, |
| 142 | + w1_zp=w1_qzeros if has_zp else None, |
| 143 | + w2_zp=w2_qzeros if has_zp else None, |
| 144 | + block_shape=[0, group_size]) |
| 145 | + torch_output = torch_moe(a, w1_ref, w2_ref, score, topk) |
| 146 | + torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) |
| 147 | + |
| 148 | + |
58 | 149 | @pytest.mark.parametrize("dtype",
|
59 | 150 | [torch.float32, torch.float16, torch.bfloat16])
|
60 | 151 | @torch.inference_mode()
|
|
0 commit comments