|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | +"""Tests for QuantFP8 Group Quantization implementation.""" |
| 4 | + |
| 5 | +import pytest |
| 6 | +import torch |
| 7 | + |
| 8 | +from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8 |
| 9 | +from vllm.model_executor.layers.quantization.utils.quant_utils import ( |
| 10 | + GroupShape) |
| 11 | +from vllm.platforms import current_platform |
| 12 | + |
| 13 | + |
| 14 | +@pytest.mark.parametrize("batch_size", [16, 32]) |
| 15 | +@pytest.mark.parametrize("hidden_dim", |
| 16 | + [256, 512, 513]) # Include non-divisible |
| 17 | +@pytest.mark.parametrize("group_size", [32, 64, 128]) |
| 18 | +@pytest.mark.parametrize("seed", [42]) |
| 19 | +@torch.inference_mode() |
| 20 | +def test_quantfp8_group_basic(batch_size: int, hidden_dim: int, |
| 21 | + group_size: int, seed: int) -> None: |
| 22 | + current_platform.seed_everything(seed) |
| 23 | + |
| 24 | + x = torch.randn( |
| 25 | + (batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8 |
| 26 | + |
| 27 | + # Create QuantFP8 with group quantization |
| 28 | + group_shape = GroupShape(1, group_size) |
| 29 | + quant_op = QuantFP8(static=False, |
| 30 | + group_shape=group_shape, |
| 31 | + column_major_scales=False) |
| 32 | + |
| 33 | + expected_num_groups = (hidden_dim + group_size - 1) // group_size |
| 34 | + |
| 35 | + # Test CUDA implementation (only supports divisible dimensions) |
| 36 | + if hidden_dim % group_size == 0: |
| 37 | + x_quant_cuda, scales_cuda = quant_op.forward_cuda(x.clone()) |
| 38 | + assert x_quant_cuda.shape == x.shape |
| 39 | + assert scales_cuda.shape == (batch_size, expected_num_groups) |
| 40 | + |
| 41 | + # Test PyTorch native implementation |
| 42 | + x_quant_native, scales_native = quant_op.forward_native(x.clone()) |
| 43 | + assert x_quant_native.shape == x.shape |
| 44 | + assert scales_native.shape == (batch_size, expected_num_groups) |
| 45 | + |
| 46 | + # Test column_major_scales |
| 47 | + quant_op_col = QuantFP8(static=False, |
| 48 | + group_shape=group_shape, |
| 49 | + column_major_scales=True) |
| 50 | + _, scales_col = quant_op_col.forward_native(x.clone()) |
| 51 | + assert scales_col.shape == (expected_num_groups, batch_size) |
| 52 | + |
| 53 | + |
| 54 | +@pytest.mark.parametrize("seed", [42]) |
| 55 | +@torch.inference_mode() |
| 56 | +def test_quantfp8_group_multidimensional(seed: int) -> None: |
| 57 | + current_platform.seed_everything(seed) |
| 58 | + |
| 59 | + group_size = 64 |
| 60 | + |
| 61 | + # Test with 3D input |
| 62 | + batch1, batch2, hidden_dim = 4, 8, 512 |
| 63 | + x_3d = torch.randn( |
| 64 | + (batch1, batch2, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8 |
| 65 | + |
| 66 | + group_shape = GroupShape(1, group_size) |
| 67 | + quant_op = QuantFP8(static=False, |
| 68 | + group_shape=group_shape, |
| 69 | + column_major_scales=False) |
| 70 | + |
| 71 | + x_quant, scales = quant_op.forward_native(x_3d.clone()) |
| 72 | + assert x_quant.shape == x_3d.shape |
| 73 | + assert scales.shape == (batch1, batch2, hidden_dim // group_size) |
| 74 | + |
| 75 | + # Test column_major_scales with multi-dim |
| 76 | + quant_op_col = QuantFP8(static=False, |
| 77 | + group_shape=group_shape, |
| 78 | + column_major_scales=True) |
| 79 | + _, scales_col = quant_op_col.forward_native(x_3d.clone()) |
| 80 | + assert scales_col.shape == (batch1, hidden_dim // group_size, batch2) |
| 81 | + |
| 82 | + # Test with 4D input |
| 83 | + batch1, batch2, batch3, hidden_dim = 2, 3, 4, 256 |
| 84 | + x_4d = torch.randn((batch1, batch2, batch3, hidden_dim), |
| 85 | + dtype=torch.bfloat16, |
| 86 | + device="cuda") * 8 |
| 87 | + |
| 88 | + x_quant_4d, scales_4d = quant_op.forward_native(x_4d.clone()) |
| 89 | + assert x_quant_4d.shape == x_4d.shape |
| 90 | + assert scales_4d.shape == (batch1, batch2, batch3, |
| 91 | + hidden_dim // group_size) |
| 92 | + |
| 93 | + _, scales_4d_col = quant_op_col.forward_native(x_4d.clone()) |
| 94 | + assert scales_4d_col.shape == (batch1, batch2, hidden_dim // group_size, |
| 95 | + batch3) |
| 96 | + |
| 97 | + |
| 98 | +@pytest.mark.parametrize("batch_size", [32]) |
| 99 | +@pytest.mark.parametrize("hidden_dim", [1024]) |
| 100 | +@pytest.mark.parametrize("group_size", [128]) |
| 101 | +@pytest.mark.parametrize("seed", [42]) |
| 102 | +@torch.inference_mode() |
| 103 | +def test_quantfp8_group_cuda_native_consistency(batch_size: int, |
| 104 | + hidden_dim: int, |
| 105 | + group_size: int, |
| 106 | + seed: int) -> None: |
| 107 | + """Compare CUDA and native implementations for consistency.""" |
| 108 | + current_platform.seed_everything(seed) |
| 109 | + |
| 110 | + x = torch.randn( |
| 111 | + (batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8 |
| 112 | + |
| 113 | + group_shape = GroupShape(1, group_size) |
| 114 | + quant_op = QuantFP8(static=False, |
| 115 | + group_shape=group_shape, |
| 116 | + column_major_scales=False) |
| 117 | + |
| 118 | + # Run both implementations |
| 119 | + x_quant_cuda, scales_cuda = quant_op.forward_cuda(x.clone()) |
| 120 | + x_quant_native, scales_native = quant_op.forward_native(x.clone()) |
| 121 | + |
| 122 | + # Check shapes match |
| 123 | + assert x_quant_cuda.shape == x_quant_native.shape |
| 124 | + assert scales_cuda.shape == scales_native.shape |
| 125 | + |
| 126 | + # Scales should match |
| 127 | + assert torch.allclose(scales_cuda, scales_native, rtol=1e-9, atol=1e-8) |
| 128 | + |
| 129 | + # Quantized values should mostly match, with rare rounding differences |
| 130 | + # FP8 rounding at boundaries can differ between CUDA and PyTorch |
| 131 | + diff_count = (x_quant_cuda != x_quant_native).sum().item() |
| 132 | + diff_ratio = diff_count / x_quant_cuda.numel() |
| 133 | + assert diff_ratio < 0.002, f"Too many differences: {diff_ratio:.4%}" |
| 134 | + |
| 135 | + |
| 136 | +@pytest.mark.parametrize("seed", [42]) |
| 137 | +@torch.inference_mode() |
| 138 | +def test_quantfp8_group_edge_cases(seed: int) -> None: |
| 139 | + current_platform.seed_everything(seed) |
| 140 | + |
| 141 | + batch_size = 16 |
| 142 | + group_size = 64 |
| 143 | + |
| 144 | + # Test with single group (group_size >= hidden_dim) |
| 145 | + x_small = torch.randn( |
| 146 | + (batch_size, 32), dtype=torch.bfloat16, device="cuda") * 8 |
| 147 | + group_shape = GroupShape(1, group_size) |
| 148 | + quant_op = QuantFP8(static=False, |
| 149 | + group_shape=group_shape, |
| 150 | + column_major_scales=False) |
| 151 | + |
| 152 | + x_quant_small, scales_small = quant_op.forward_native(x_small.clone()) |
| 153 | + assert x_quant_small.shape == x_small.shape |
| 154 | + assert scales_small.shape == (batch_size, 1) |
| 155 | + |
| 156 | + # Test with zero inputs |
| 157 | + x_zero = torch.zeros((batch_size, 256), |
| 158 | + dtype=torch.bfloat16, |
| 159 | + device="cuda") |
| 160 | + x_quant_zero, scales_zero = quant_op.forward_native(x_zero.clone()) |
| 161 | + assert x_quant_zero.shape == x_zero.shape |
| 162 | + assert (scales_zero > 0).all(), "Scales should be clamped to minimum" |
| 163 | + |
| 164 | + # Test very large values |
| 165 | + x_large = torch.full((batch_size, 256), |
| 166 | + 1000.0, |
| 167 | + dtype=torch.bfloat16, |
| 168 | + device="cuda") |
| 169 | + x_quant_large, scales_large = quant_op.forward_native(x_large.clone()) |
| 170 | + assert x_quant_large.shape == x_large.shape |
| 171 | + # FP8 max is typically 448 or 224, so scales should be > 1 |
| 172 | + assert (scales_large > 1.0).all(), "Large values should have scales > 1" |
| 173 | + |
| 174 | + |
| 175 | +@pytest.mark.parametrize( |
| 176 | + "batch_size,hidden_dim,group_size", |
| 177 | + [ |
| 178 | + (16, 256, 16), # Small |
| 179 | + (64, 1024, 64), # Medium |
| 180 | + (128, 2048, 128), # Large |
| 181 | + (8, 513, 64), # Non-divisible (native only) |
| 182 | + ]) |
| 183 | +@pytest.mark.parametrize("seed", [42]) |
| 184 | +@torch.inference_mode() |
| 185 | +def test_quantfp8_group_various_configs(batch_size: int, hidden_dim: int, |
| 186 | + group_size: int, seed: int) -> None: |
| 187 | + current_platform.seed_everything(seed) |
| 188 | + |
| 189 | + x = torch.randn( |
| 190 | + (batch_size, hidden_dim), dtype=torch.bfloat16, device="cuda") * 8 |
| 191 | + group_shape = GroupShape(1, group_size) |
| 192 | + quant_op = QuantFP8(static=False, |
| 193 | + group_shape=group_shape, |
| 194 | + column_major_scales=False) |
| 195 | + |
| 196 | + expected_num_groups = (hidden_dim + group_size - 1) // group_size |
| 197 | + |
| 198 | + x_quant_native, scales_native = quant_op.forward_native(x.clone()) |
| 199 | + assert x_quant_native.shape == x.shape |
| 200 | + assert scales_native.shape == (batch_size, expected_num_groups) |
| 201 | + |
| 202 | + if hidden_dim % group_size == 0: |
| 203 | + x_quant_cuda, scales_cuda = quant_op.forward_cuda(x.clone()) |
| 204 | + assert x_quant_cuda.shape == x.shape |
| 205 | + assert scales_cuda.shape == (batch_size, expected_num_groups) |
| 206 | + assert torch.allclose(scales_cuda, scales_native, rtol=1e-9, atol=1e-8) |
0 commit comments