Skip to content

Commit e8bc353

Browse files
authored
[1/N] Initial vllm-ext evaluation support (MXFP4 MOE) (#935)
Signed-off-by: yiliu30 <yi4.liu@intel.com>
1 parent 282aab6 commit e8bc353

File tree

12 files changed

+2682
-0
lines changed

12 files changed

+2682
-0
lines changed
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Copyright (c) 2025 Intel Corporation
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# ==---------------------------------------------------------------------------==
16+
# Apply the extension
17+
# ==---------------------------------------------------------------------------==
18+
19+
20+
def apply():
    """Install the auto-round extension into vLLM.

    Replaces ``AutoRoundConfig`` on vLLM's ``auto_round`` quantization module
    with ``AutoRoundExtensionConfig`` and then imports the ``envs_ext`` module.
    All imports are function-local, so nothing is loaded until this is called.
    """
    import vllm.model_executor.layers.quantization.auto_round as auto_round_module

    from auto_round_extension.vllm_ext.auto_round_ext import AutoRoundExtensionConfig

    # Monkey-patch: later lookups of `auto_round_module.AutoRoundConfig`
    # resolve to the extension class.
    auto_round_module.AutoRoundConfig = AutoRoundExtensionConfig

    # The imported name is not used here; the import is kept for whatever the
    # `envs_ext` module does when it is first loaded.
    from auto_round_extension.vllm_ext.envs_ext import extra_environment_variables
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#!/bin/bash
2+
3+
# Copyright (c) 2025 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
# Locate the vllm_ext `sitecustomize.py` and prepend its directory to
# PYTHONPATH. This script is meant to be *sourced* so the export reaches the
# calling shell; it therefore never calls `exit`.
#
# Status when sourced: 0 if sitecustomize.py was found, 1 otherwise.

# Define the relative path for the `auto-round` installation
AUTO_ROUND_PATH="auto_round/../auto_round_extension/vllm_ext/sitecustomize.py"

# Try to find the pip installation location.
# `sed` keeps everything after "Location: ", so paths with spaces survive.
PIP_LOCATION=$(pip show auto-round 2>/dev/null | sed -n 's/^Location: //p')

# Tracks the outcome so that a successful lookup skips the fallback and the
# warning even when the script is executed instead of sourced (in which case
# `return` is unavailable and cannot stop the script early).
_AR_SITECUSTOMIZE_STATUS=1

if [ -n "$PIP_LOCATION" ]; then
    SITE_CUSTOMIZE_PATH="$PIP_LOCATION/$AUTO_ROUND_PATH"
    echo "Checking for sitecustomize.py at: $SITE_CUSTOMIZE_PATH"

    if [ -f "$SITE_CUSTOMIZE_PATH" ]; then
        echo "Found sitecustomize.py at: $SITE_CUSTOMIZE_PATH"
        # `${PYTHONPATH:+:...}` avoids a dangling ':' (an empty path entry)
        # when PYTHONPATH was previously unset or empty.
        export PYTHONPATH="$(dirname "$SITE_CUSTOMIZE_PATH")${PYTHONPATH:+:$PYTHONPATH}"
        echo "PYTHONPATH set to: $PYTHONPATH"
        _AR_SITECUSTOMIZE_STATUS=0
    fi
fi

# Fallback: check current directory
LOCAL_SITE_CUSTOMIZE="./sitecustomize.py"
if [ "$_AR_SITECUSTOMIZE_STATUS" -ne 0 ]; then
    if [ -f "$LOCAL_SITE_CUSTOMIZE" ]; then
        echo "Found sitecustomize.py at current directory."
        export PYTHONPATH="$(pwd)${PYTHONPATH:+:$PYTHONPATH}"
        echo "PYTHONPATH set to: $PYTHONPATH"
        _AR_SITECUSTOMIZE_STATUS=0
    else
        echo "Warning: sitecustomize.py not found in pip installation or current directory."
    fi
fi

# Do not exit the shell: `return` only works when sourced; when the script is
# executed directly it fails silently and `|| true` ends the script normally.
return "$_AR_SITECUSTOMIZE_STATUS" 2>/dev/null || true
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Copyright (c) 2025 Intel Corporation
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Any
16+
17+
import torch
18+
from vllm.logger import init_logger
19+
from vllm.model_executor.layers.fused_moe import FusedMoE
20+
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
21+
from vllm.model_executor.layers.quantization.auto_round import AutoRoundConfig
22+
23+
from auto_round.schemes import QuantizationScheme
24+
from auto_round_extension.vllm_ext.quant_method_moe import AutoRoundMoEMethod
25+
26+
logger = init_logger(__name__)
27+
28+
29+
class AutoRoundExtensionConfig(AutoRoundConfig):
    """``AutoRoundConfig`` variant that also accepts the ``mx_fp`` dtype and
    the ``auto_round:llm_compressor`` format, and that carries parsed
    ``QuantizationScheme`` objects for the model and for individual layers.
    """

    SUPPORTED_DTYPES = AutoRoundConfig.SUPPORTED_DTYPES.union({"mx_fp"})
    SUPPORTED_FORMATS = AutoRoundConfig.SUPPORTED_FORMATS.union({"auto_round:llm_compressor"})

    def get_quant_method(self, layer: torch.nn.Module, prefix: str):
        """Pick the quantization method for ``layer``.

        :param layer: module being set up
        :param prefix: name prefix forwarded to the MoE method factory
        :return: the result of ``AutoRoundMoEMethod.get_moe_method`` for
            ``FusedMoE`` layers, an ``UnquantizedLinearMethod`` for
            ``LinearBase`` layers, and ``None`` for anything else
        """
        # FIXME: (yi) make it compatible with `AutoRoundConfig`
        if isinstance(layer, FusedMoE):
            return AutoRoundMoEMethod.get_moe_method(self, layer, prefix)
        if isinstance(layer, LinearBase):
            return UnquantizedLinearMethod()
        return None

    @staticmethod
    def _parse_quant_scheme(config: dict):
        """Build a ``QuantizationScheme`` from the entries of ``config`` whose
        keys are listed by ``QuantizationScheme.get_attributes()``; all other
        keys are ignored.
        """
        known_attrs = QuantizationScheme.get_attributes()
        scheme_kwargs = {name: val for name, val in config.items() if name in known_attrs}
        return QuantizationScheme.from_dict(scheme_kwargs)

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> AutoRoundConfig:
        """Create the config through the parent ``from_config`` and attach
        parsed schemes to it.

        :param config: raw quantization config dictionary
        :return: the parent-built config with two extra attributes:
            ``quant_scheme`` (parsed from ``config``) and ``layer_schemes``
            (one parsed scheme per entry of the config's ``extra_config``;
            empty when ``extra_config`` is absent or ``None``)
        """
        ar_config = super().from_config(config)
        # TODO: (yi) refine below implementation
        global_scheme = AutoRoundExtensionConfig._parse_quant_scheme(config)

        per_layer_overrides = getattr(ar_config, "extra_config", None)
        if per_layer_overrides is None:
            per_layer_schemes = {}
        else:
            per_layer_schemes = {
                layer_name: AutoRoundExtensionConfig._parse_quant_scheme(layer_config)
                for layer_name, layer_config in per_layer_overrides.items()
            }

        ar_config.quant_scheme = global_scheme
        ar_config.layer_schemes = per_layer_schemes
        return ar_config
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright (c) 2025 Intel Corporation
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
from typing import Any, Callable
17+
18+
from vllm.logger import init_logger
19+
20+
logger = init_logger(__name__)

# String values that switch a boolean environment flag on.
_TRUE_STRINGS = ("1", "true", "True")


def _bool_env(var_name: str, default: str) -> Callable[[], bool]:
    """Return a getter that reads ``var_name`` (falling back to ``default``)
    and reports whether its value is one of the accepted "true" strings."""
    return lambda: os.getenv(var_name, default) in _TRUE_STRINGS


# Define extra environment variables
extra_environment_variables: dict[str, Callable[[], Any]] = {
    "VLLM_MXFP4_PRE_UNPACK_WEIGHTS": _bool_env("VLLM_MXFP4_PRE_UNPACK_WEIGHTS", "1"),
    "VLLM_ENABLE_STATIC_MOE": _bool_env("VLLM_ENABLE_STATIC_MOE", "1"),
    "VLLM_AR_MXFP4_MODULAR_MOE": _bool_env("VLLM_AR_MXFP4_MODULAR_MOE", "0"),
}
# Add the extra environment variables to vllm.envs
import vllm.envs as envs
from vllm.envs import environment_variables

# Merge the environment variables (extension entries win on a name clash)
all_environment_variables = {**environment_variables, **extra_environment_variables}

# Each getter is evaluated once, here, and the resulting value is stored as a
# plain attribute on `vllm.envs`.
for name, value_fn in extra_environment_variables.items():
    setattr(envs, name, value_fn())

logger.warning_once(f"Added extra environment variables: {list(extra_environment_variables.keys())}")
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Copyright (c) 2025 Intel Corporation
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Optional
16+
17+
import torch
18+
19+
# Per-device cache of the lookup tensor, kept at module level to fix a
# cuda graph issue. Keys are ``str(device)``.
_DEVICE_E2M1_TENSORS = {}

# Magnitudes representable in FP4 (E2M1 format); list position is the index
# used by the pack/unpack helpers.
_E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def get_e2m1_tensor(device):
    """Return the float32 E2M1 lookup tensor for ``device``.

    The tensor is built on first request for a device and the same object is
    handed back on every later request for that device.
    """
    cache_key = str(device)
    lookup = _DEVICE_E2M1_TENSORS.get(cache_key)
    if lookup is None:
        lookup = torch.tensor(_E2M1_VALUES, dtype=torch.float32, device=device)
        _DEVICE_E2M1_TENSORS[cache_key] = lookup
    return lookup
32+
33+
34+
def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
    """Pack a 2-D tensor of FP4 (E2M1) values into uint8, two values per byte.

    Every element is mapped to the index of the closest magnitude in the
    E2M1 lookup table (3 bits) and the sign is stored in bit 3, giving a
    4-bit code. Codes are paired in flattened order: the first of each pair
    goes to the low nibble, the second to the high nibble.

    :param x: tensor of shape ``(m, n)`` holding the values to encode
    :return: uint8 tensor of shape ``(m, n // 2)``
    :raises AssertionError: if ``x`` has an odd number of elements
    """
    m, n = x.shape

    # Lookup table mapping index 0-7 to the representable magnitudes
    kE2M1 = get_e2m1_tensor(x.device)

    # Find closest valid FP4 value index for each element
    abs_x = torch.abs(x)
    abs_diff_x = torch.abs(abs_x.unsqueeze(-1) - kE2M1)  # [m, n, 8]
    abs_indices = torch.argmin(abs_diff_x, dim=-1)  # [m, n]

    # Apply sign bit (bit 3) to get final 4-bit representation
    indices = abs_indices + (torch.signbit(x).to(torch.long) << 3)

    # Flatten to prepare for packing pairs of values
    indices = indices.reshape(-1)

    # Two codes share one byte, so an even element count is required
    # (no padding is performed).
    assert indices.numel() % 2 == 0, f"Expected even number of elements, got {indices.numel()}"

    # Reshape to pair consecutive elements
    indices = indices.reshape(-1, 2)

    # Pack pairs of 4-bit values into 8-bit values
    packed = (indices[:, 0] | (indices[:, 1] << 4)).to(torch.uint8)

    return packed.reshape(m, n // 2)
63+
64+
65+
def unpack_fp4_from_uint8(
    a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16
) -> torch.Tensor:
    """
    Expand packed uint8 data into fp4 values.

    Every byte carries two 4-bit codes; the low nibble is emitted first and
    the high nibble second. Within a code, bits 0-2 index the E2M1 lookup
    table and bit 3 is the sign.

    :param a: uint8 tensor to unpack
    :param m: dim 0 size of the unpacked result
    :param n: dim 1 size of the unpacked result
    :param dtype: dtype the unpacked tensor is cast to
    :return: tensor of shape ``(m, n)``
    :raises AssertionError: if ``a`` is not a uint8 tensor
    """
    assert a.dtype == torch.uint8, f"expected uint8, got {a.dtype}"

    # Split every byte into its two 4-bit codes
    packed = a.flatten()
    low_nibbles = packed & 0x0F
    high_nibbles = (packed & 0xF0) >> 4

    # Interleave as (low, high) per byte so the codes come out in element order
    codes = torch.stack((low_nibbles, high_nibbles), dim=1).flatten()

    # Bit 3 is the sign, bits 0-2 index the magnitude table
    is_negative = (codes & 0x08).to(torch.bool)
    magnitude_idx = (codes & 0x07).to(torch.long)

    # Look magnitudes up on the input's device and apply the sign
    magnitudes = get_e2m1_tensor(a.device)[magnitude_idx]
    values = magnitudes * torch.where(is_negative, -1.0, 1.0)

    return values.reshape(m, n).to(dtype=dtype)
100+
101+
102+
def cast_to_fp4(x):
    """Snap each element of ``x`` to an FP4 magnitude, keeping its sign.

    Magnitudes are bucketed as follows: [0, 0.25] -> 0, (0.25, 0.75) -> 0.5,
    [0.75, 1.25] -> 1, (1.25, 1.75) -> 1.5, [1.75, 2.5] -> 2,
    (2.5, 3.5) -> 3, [3.5, 5] -> 4 and everything above 5 -> 6.
    The input tensor itself is not modified.
    """
    signs = torch.sign(x)
    magnitude = torch.abs(x)  # fresh tensor, so the writes below never touch the input

    # The buckets are disjoint, so every mask can be taken from the untouched
    # magnitudes before any value is overwritten.
    buckets = (
        ((magnitude >= 0.0) & (magnitude <= 0.25), 0.0),
        ((magnitude > 0.25) & (magnitude < 0.75), 0.5),
        ((magnitude >= 0.75) & (magnitude <= 1.25), 1.0),
        ((magnitude > 1.25) & (magnitude < 1.75), 1.5),
        ((magnitude >= 1.75) & (magnitude <= 2.5), 2.0),
        ((magnitude > 2.5) & (magnitude < 3.5), 3.0),
        ((magnitude >= 3.5) & (magnitude <= 5.0), 4.0),
        (magnitude > 5.0, 6.0),
    )
    for mask, snapped in buckets:
        magnitude[mask] = snapped

    return magnitude * signs

0 commit comments

Comments
 (0)