Commit ac47153

added MXFP4 quantizer support to directly load GPT-OSS models via QEFFAutoModelForCausalLM (#577)

* added mxfp4 quantizer to match weights keys
* added transform to dequantize mxfp4 to float32

Signed-off-by: Onkar Chougule <ochougul@qti.qualcomm.com>

1 parent b35d672 commit ac47153
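
What this enables in practice: a GPT-OSS checkpoint with MXFP4-packed weights can now be loaded in a single step, without the previously required FP4-to-BF16 conversion script. A minimal sketch based on the commit message and the updated examples/gpt_oss.py below (which shows the full export/compile flow):

    from QEfficient import QEFFAutoModelForCausalLM

    # The MXFP4 quantizer added in this commit matches the packed weight keys,
    # so the public checkpoint loads directly.
    qeff_model = QEFFAutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b")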

File tree

7 files changed: +266 -28 lines changed

QEfficient/__init__.py
QEfficient/transformers/quantizers/__init__.py
QEfficient/transformers/quantizers/auto.py
QEfficient/transformers/quantizers/quant_transforms.py
QEfficient/transformers/quantizers/quantizer_mxfp4.py
QEfficient/transformers/quantizers/quantizer_utils.py
examples/gpt_oss.py


QEfficient/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -8,14 +8,13 @@
 import os
 import warnings
 
-from QEfficient.utils import custom_format_warning
-
 # For faster downloads via hf_transfer
 # This code is put above import statements as this needs to be executed before
 # hf_transfer is imported (will happen on line 15 via leading imports)
 os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
 # Placeholder for all non-transformer models registered in QEfficient
 import QEfficient.utils.model_registery  # noqa: F401
+from QEfficient.utils import custom_format_warning
 from QEfficient.utils.logging_utils import logger
 
 # custom warning for the better logging experience

QEfficient/transformers/quantizers/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -4,3 +4,7 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #
 # -----------------------------------------------------------------------------
+
+from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers
+
+__all__ = ["replace_transformers_quantizers"]
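
For context, replace_transformers_quantizers (now re-exported here) swaps the stock transformers quantizers for their QEff counterparts using the mappings extended in auto.py below. A rough sketch of using it with plain transformers; the zero-argument call is an assumption based on this re-export, and QEFFAutoModelForCausalLM is expected to perform the swap internally anyway:

    from transformers import AutoModelForCausalLM

    from QEfficient.transformers.quantizers import replace_transformers_quantizers

    # Re-register the "mxfp4" quant method (alongside the other QEff-supported methods)
    # so it resolves to QEffMxfp4HfQuantizer / QEffMxfp4Config.
    replace_transformers_quantizers()
    # The packed uint8 blocks/scales then load into QEffMxfp4GptOssExperts modules as-is.
    model = AutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b")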

QEfficient/transformers/quantizers/auto.py

Lines changed: 7 additions & 1 deletion
@@ -9,7 +9,8 @@
 from transformers.quantizers.quantizer_awq import AwqQuantizer
 from transformers.quantizers.quantizer_compressed_tensors import CompressedTensorsHfQuantizer
 from transformers.quantizers.quantizer_gptq import GptqHfQuantizer
-from transformers.utils.quantization_config import AwqConfig, CompressedTensorsConfig, GPTQConfig
+from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
+from transformers.utils.quantization_config import AwqConfig, CompressedTensorsConfig, GPTQConfig, Mxfp4Config
 
 from QEfficient.transformers.quantizers.quantizer_awq import QEffAwqConfig, QEffAwqQuantizer
 from QEfficient.transformers.quantizers.quantizer_compressed_tensors import (
@@ -19,30 +20,35 @@
     QEffFP8Quantizer,
 )
 from QEfficient.transformers.quantizers.quantizer_gptq import QEffGPTQConfig, QEffGPTQQuantizer
+from QEfficient.transformers.quantizers.quantizer_mxfp4 import QEffMxfp4Config, QEffMxfp4HfQuantizer
 
 QEFF_AUTO_QUANTIZER_MAPPING = {
     "awq": QEffAwqQuantizer,
     "gptq": QEffGPTQQuantizer,
     "compressed-tensors": QEffCompressedTensorsFP8Quantizer,
     "fp8": QEffFP8Quantizer,
+    "mxfp4": QEffMxfp4HfQuantizer,
 }
 QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING = {
     "awq": QEffAwqConfig,
     "gptq": QEffGPTQConfig,
     "compressed-tensors": QEffCompressedTensorsConfig,
     "fp8": QEffFP8Config,
+    "mxfp4": QEffMxfp4Config,
 }
 DUPLICATE_AUTO_QUANTIZER_MAPPING = {
     "awq": AwqQuantizer,
     "gptq": GptqHfQuantizer,
     "compressed-tensors": CompressedTensorsHfQuantizer,
     "fp8": None,
+    "mxfp4": Mxfp4HfQuantizer,
 }
 DUPLICATE_AUTO_QUANTIZATION_CONFIG_MAPPING = {
     "awq": AwqConfig,
     "gptq": GPTQConfig,
     "compressed-tensors": CompressedTensorsConfig,
     "fp8": None,
+    "mxfp4": Mxfp4Config,
 }
 
QEfficient/transformers/quantizers/quant_transforms.py

Lines changed: 32 additions & 1 deletion
@@ -7,13 +7,19 @@
 
 import torch
 from torch import nn
+from transformers.models.gpt_oss.modeling_gpt_oss import GptOssExperts
 
 from QEfficient.base.pytorch_transforms import ModuleMutatorTransform
 from QEfficient.customop.matmulnbits import QuantLinearORT
 from QEfficient.transformers.quantizers.awq import WQLinear_GEMM
 from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ
 from QEfficient.transformers.quantizers.quantizer_compressed_tensors import FP8DeQuantLinear
-from QEfficient.transformers.quantizers.quantizer_utils import dequantize_gptq, unpack_weights
+from QEfficient.transformers.quantizers.quantizer_mxfp4 import QEffMxfp4GptOssExperts
+from QEfficient.transformers.quantizers.quantizer_utils import (
+    convert_moe_packed_tensors,
+    dequantize_gptq,
+    unpack_weights,
+)
 
 
 class AwqToMatmulNbitsTransform(ModuleMutatorTransform):
@@ -115,3 +121,28 @@ def mutate(cls, original_module, parent_module):
         if original_module.bias is not None:
             dequant_linear_layer.bias = torch.nn.Parameter(original_module.bias.float())
         return dequant_linear_layer
+
+
+class Mxfp4GptOssExpertDequantizeTransform(ModuleMutatorTransform):
+    """
+    Dequantizes the weights of a QEffMxfp4GptOssExperts module and replaces it with a
+    transformers GptOssExperts module holding the dequantized weights
+    """
+
+    _match_class = QEffMxfp4GptOssExperts
+
+    @classmethod
+    def mutate(cls, original_module, parent_module):
+        dequant_module = GptOssExperts(original_module.config)
+        dequant_module.gate_up_proj = torch.nn.Parameter(
+            convert_moe_packed_tensors(
+                original_module.gate_up_proj_blocks, original_module.gate_up_proj_scales, dtype=torch.float32
+            )
+        )
+        dequant_module.down_proj = torch.nn.Parameter(
+            convert_moe_packed_tensors(
+                original_module.down_proj_blocks, original_module.down_proj_scales, dtype=torch.float32
+            )
+        )
+        dequant_module.gate_up_proj_bias = original_module.gate_up_proj_bias
+        dequant_module.down_proj_bias = original_module.down_proj_bias
+        return dequant_module
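
How the new transform would typically be driven: ModuleMutatorTransform subclasses in QEfficient are applied through a classmethod that walks the module tree and calls mutate on every matched module. A minimal sketch, assuming the same apply(model) -> (model, transformed) interface used by the other transforms in this file:

    from QEfficient.transformers.quantizers.quant_transforms import Mxfp4GptOssExpertDequantizeTransform

    # Every QEffMxfp4GptOssExperts instance is replaced by a float32 GptOssExperts
    # whose gate_up_proj / down_proj hold the dequantized MXFP4 weights.
    model, transformed = Mxfp4GptOssExpertDequantizeTransform.apply(model)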

QEfficient/transformers/quantizers/quantizer_mxfp4.py

Lines changed: 148 additions & 0 deletions

@@ -0,0 +1,148 @@
+import re
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from transformers.quantizers.quantizer_mxfp4 import Mxfp4HfQuantizer
+from transformers.utils.quantization_config import Mxfp4Config
+
+from QEfficient.transformers.quantizers.quantizer_utils import convert_moe_packed_tensors, get_keys_to_not_convert
+from QEfficient.utils.logging_utils import logger
+
+
+class QEffMxfp4GptOssExperts(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+
+        self.num_experts = config.num_local_experts
+        self.intermediate_size = config.intermediate_size
+        self.hidden_size = config.hidden_size
+
+        self.gate_up_proj_blocks = nn.Parameter(
+            torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, 16, dtype=torch.uint8),
+            requires_grad=False,
+        )
+        self.gate_up_proj_scales = nn.Parameter(
+            torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, dtype=torch.uint8),
+            requires_grad=False,
+        )
+        self.gate_up_proj_bias = nn.Parameter(
+            torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32), requires_grad=False
+        )
+
+        self.down_proj_blocks = nn.Parameter(
+            torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16), dtype=torch.uint8),
+            requires_grad=False,
+        )
+        self.down_proj_scales = nn.Parameter(
+            torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, dtype=torch.uint8),
+            requires_grad=False,
+        )
+        self.down_proj_bias = nn.Parameter(
+            torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32), requires_grad=False
+        )
+        self.alpha = 1.702
+        self.limit = 7.0
+
+        self.gate_up_proj_precision_config = None
+        self.down_proj_precision_config = None
+
+    def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor:
+        gate_up_proj = convert_moe_packed_tensors(
+            self.gate_up_proj_blocks, self.gate_up_proj_scales, dtype=torch.float32
+        )
+        down_proj = convert_moe_packed_tensors(self.down_proj_blocks, self.down_proj_scales, dtype=torch.float32)
+        batch_size = hidden_states.shape[0]
+        hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)
+        num_experts = routing_weights.shape[1]
+        hidden_states = hidden_states.repeat(num_experts, 1)
+        hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
+        gate_up = torch.bmm(hidden_states, gate_up_proj) + self.gate_up_proj_bias[..., None, :]
+        gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+        gate = gate.clamp(min=None, max=self.limit)
+        up = up.clamp(min=-self.limit, max=self.limit)
+        glu = gate * torch.sigmoid(gate * self.alpha)
+        next_states = torch.bmm(((up + 1) * glu), down_proj)
+        next_states = next_states + self.down_proj_bias[..., None, :]
+        next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
+        next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
+        next_states = next_states.sum(dim=0)
+        return next_states
+
+
+def should_convert_module(current_key_name, patterns):
+    current_key_name_str = ".".join(current_key_name)
+    if not any(
+        re.match(f"{key}\\.", current_key_name_str) or re.match(f"{key}", current_key_name_str) for key in patterns
+    ):
+        return True
+    return False
+
+
+class QEffMxfp4Config(Mxfp4Config):
+    """
+    Currently there is no need to change the implementation of Mxfp4Config.
+    This is a placeholder for the future, when we may want to change it.
+    """
+
+    pass
+
+
+class QEffMxfp4HfQuantizer(Mxfp4HfQuantizer):
+    def validate_environment(self, *args, **kwargs):
+        return True
+
+    def update_torch_dtype(self, torch_dtype):
+        if torch_dtype not in [None, torch.float32]:
+            logger.warning(f"Requested dtype {torch_dtype} is not supported, overriding to None")
+        return None
+
+    def _process_model_before_weight_loading(
+        self,
+        model: torch.nn.Module,
+        keep_in_fp32_modules: Optional[list[str]] = None,
+        **kwargs,
+    ):
+        self.modules_to_not_convert = get_keys_to_not_convert(model)
+        self.modules_to_not_convert = (
+            ["lm_head"] if self.modules_to_not_convert is None else self.modules_to_not_convert
+        )
+        self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert)
+        self.modules_to_not_convert = list(set(self.modules_to_not_convert))
+        config = model.config
+
+        # -- Defining a local method as it uses a lot of local variables --
+        def _replace_with_mxfp4_linear(
+            model,
+            modules_to_not_convert=None,
+            current_key_name=None,
+            quantization_config=None,
+            has_been_replaced=False,
+        ):
+            if current_key_name is None:
+                current_key_name = []
+
+            for name, module in model.named_children():
+                current_key_name.append(name)
+                if not should_convert_module(current_key_name, modules_to_not_convert):
+                    current_key_name.pop(-1)
+                    continue
+                if module.__class__.__name__ == "GptOssExperts" and not quantization_config.dequantize:
+                    model._modules[name] = QEffMxfp4GptOssExperts(config)
+                    has_been_replaced = True
+                if len(list(module.children())) > 0:
+                    _, has_been_replaced = _replace_with_mxfp4_linear(
+                        module,
+                        modules_to_not_convert,
+                        current_key_name,
+                        quantization_config,
+                        has_been_replaced=has_been_replaced,
+                    )
+                current_key_name.pop(-1)
+            return model, has_been_replaced
+
+        _replace_with_mxfp4_linear(
+            model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
+        )
+        model.config.quantization_config = self.quantization_config
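
To make the packed layout concrete, here is a toy instantiation of QEffMxfp4GptOssExperts on a made-up config (the sizes below are illustrative only, not GPT-OSS's real dimensions). Every group of 32 FP4 values is stored as 16 packed uint8 bytes plus one uint8 power-of-two scale, and forward dequantizes on the fly:

    from types import SimpleNamespace

    import torch

    from QEfficient.transformers.quantizers.quantizer_mxfp4 import QEffMxfp4GptOssExperts

    # Hypothetical toy config: 2 experts, hidden_size 64, intermediate_size 64.
    cfg = SimpleNamespace(num_local_experts=2, hidden_size=64, intermediate_size=64)
    experts = QEffMxfp4GptOssExperts(cfg)

    # Packed storage: (experts, 2 * intermediate, hidden // 32, 16) bytes plus matching scales.
    print(experts.gate_up_proj_blocks.shape)  # torch.Size([2, 128, 2, 16])
    print(experts.gate_up_proj_scales.shape)  # torch.Size([2, 128, 2])

    hidden = torch.randn(1, 4, 64)                      # (batch, seq, hidden)
    routing = torch.softmax(torch.randn(4, 2), dim=-1)  # (tokens, experts)
    out = experts(hidden, routing_weights=routing)      # dense MoE: all experts, weighted sum
    print(out.shape)                                    # torch.Size([1, 4, 64])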

QEfficient/transformers/quantizers/quantizer_utils.py

Lines changed: 68 additions & 0 deletions
@@ -378,3 +378,71 @@ def repack_zeros(qzeros, bits):
                break
     qzeros = qzeros.T
     return qzeros
+
+
+FP4_VALUES = [
+    +0.0,
+    +0.5,
+    +1.0,
+    +1.5,
+    +2.0,
+    +3.0,
+    +4.0,
+    +6.0,
+    -0.0,
+    -0.5,
+    -1.0,
+    -1.5,
+    -2.0,
+    -3.0,
+    -4.0,
+    -6.0,
+]
+
+
+def convert_moe_packed_tensors(
+    blocks,
+    scales,
+    *,
+    dtype: torch.dtype = torch.bfloat16,
+    rows_per_chunk: int = 32768 * 1024,
+) -> torch.Tensor:
+    """
+    reference for this function is taken from: https://github.com/huggingface/transformers/tree/main/src/transformers/models/gpt_oss#L98
+    """
+    import math
+
+    scales = scales.to(torch.int32) - 127
+
+    assert blocks.shape[:-1] == scales.shape, f"{blocks.shape=} does not match {scales.shape=}"
+
+    lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device)
+
+    *prefix_shape, G, B = blocks.shape
+    rows_total = math.prod(prefix_shape) * G
+
+    blocks = blocks.reshape(rows_total, B)
+    scales = scales.reshape(rows_total, 1)
+
+    out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device)
+
+    for r0 in range(0, rows_total, rows_per_chunk):
+        r1 = min(r0 + rows_per_chunk, rows_total)
+
+        blk = blocks[r0:r1]
+        exp = scales[r0:r1]
+
+        # nibble indices -> int64
+        idx_lo = (blk & 0x0F).to(torch.long)
+        idx_hi = (blk >> 4).to(torch.long)
+
+        sub = out[r0:r1]
+        sub[:, 0::2] = lut[idx_lo]
+        sub[:, 1::2] = lut[idx_hi]
+
+        torch.ldexp(sub, exp, out=sub)
+        del idx_lo, idx_hi, blk, exp
+
+    out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2)
+    out = out.to(dtype).permute(0, 2, 1).contiguous()
+    return out
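
A tiny worked example (not part of the commit) of the decode scheme implemented above: each uint8 in blocks packs two FP4 indices into FP4_VALUES, low nibble first, and each uint8 in scales is an exponent biased by 127 that is applied with torch.ldexp:

    import torch

    from QEfficient.transformers.quantizers.quantizer_utils import convert_moe_packed_tensors

    blocks = torch.zeros(1, 1, 1, 16, dtype=torch.uint8)
    blocks[0, 0, 0, 0] = 0x21  # low nibble 0x1 -> +0.5, high nibble 0x2 -> +1.0
    scales = torch.full((1, 1, 1), 128, dtype=torch.uint8)  # 128 - 127 = 1, i.e. multiply by 2

    out = convert_moe_packed_tensors(blocks, scales, dtype=torch.float32)
    print(out.shape)                   # torch.Size([1, 32, 1]) after the final transpose
    print(out[0, 0, 0], out[0, 1, 0])  # tensor(1.) tensor(2.); remaining entries are 0.0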

examples/gpt_oss.py

Lines changed: 6 additions & 24 deletions
@@ -5,41 +5,23 @@
 #
 # -----------------------------------------------------------------------------
 
-## BEFORE RUNNING PLS, RUN THE CONVERT SCRIPT TO CONVERT THE SAFETENSORS FROM FP4 to BF16
-## SEE DETAILS HERE: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_oss/convert_gpt_oss_weights_to_hf.py
-## ONCE CONVERTED, PASS THE MODIFIED WEIGHTS TO THE MODEL_ID BELOW
-import torch
-from transformers import AutoConfig, GptOssForCausalLM, TextStreamer
+from transformers import AutoTokenizer, TextStreamer
 
 from QEfficient import QEFFAutoModelForCausalLM
-from QEfficient.utils._utils import load_hf_tokenizer
-from QEfficient.utils.constants import Constants
-from QEfficient.utils.run_utils import ApiRunner
 
-torch.manual_seed(42)
-model_id = "CONVERTED_WEIGHTS"  # See Comments above to convert saftensors to BF16
-config = AutoConfig.from_pretrained(model_id)
+model_id = "openai/gpt-oss-20b"
 
-model = GptOssForCausalLM.from_pretrained(
-    model_id, torch_dtype=torch.float32, attn_implementation="eager", config=config
-)
-model.eval()
-
-tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_id)
-config = model.config
-batch_size = len(Constants.INPUT_STR)
-
-api_runner = ApiRunner(batch_size, tokenizer, config, Constants.INPUT_STR, Constants.PROMPT_LEN, Constants.CTX_LEN)
+qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id)
+tokenizer = AutoTokenizer.from_pretrained(model_id)
 
-qeff_model = QEFFAutoModelForCausalLM(model, continuous_batching=False)
 onnx_model_path = qeff_model.export()
 qpc_path = qeff_model.compile(
-    prefill_seq_len=32,
+    prefill_seq_len=1,  # Currently we can get best perf using PL=1 i.e. decode-only model, prefill optimizations are being worked on.
     ctx_len=256,
     num_cores=16,
     mxfp6_matmul=True,
     mxint8_kv_cache=True,
-    num_devices=4,
+    num_devices=8,
     mos=1,
     aic_enable_depth_first=True,
     num_speculative_tokens=None,
