Add GPT-OSS quant support #887
Changes from all commits
@@ -0,0 +1,13 @@

```python
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```
@@ -0,0 +1,127 @@

```python
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn
from transformers.modeling_utils import no_init_weights as skip_weights_initialize
from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssMLP

__all__ = ["get_replacement_info"]


def _update_parameter(
    module: torch.nn.Module,
    name: str,
    data: torch.Tensor,
) -> None:
    param = getattr(module, name)
    param.data.copy_(data)


class GPTOssSingleExpert(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int, dtype: torch.dtype | None = None):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.alpha = 1.702
        self.limit = 7.0
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=True, dtype=dtype)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=True, dtype=dtype)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=True, dtype=dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = self.gate_proj(x)
        up = self.up_proj(x)
        gate = gate.clamp(max=self.limit)
        up = up.clamp(min=-self.limit, max=self.limit)
        glu = gate * torch.sigmoid(gate * self.alpha)
        act = (up + 1) * glu
        return self.down_proj(act)


class SequentialGPTOSSMoE(nn.Module):
    """
    Replaces GPT-OSS fused-expert MoE with per-expert `GPTOssSingleExpert` modules.
    Copies weights from fused tensors and reuses the original router and optional shared_expert.
    """

    def __init__(self, config: GptOssConfig, original: GptOssMLP):
        super().__init__()
        hidden_size = config.hidden_size
        intermediate_size = config.intermediate_size
        dtype_str = getattr(config, "torch_dtype", None) or getattr(config, "dtype", None)
        dtype = torch.bfloat16 if str(dtype_str).endswith("bfloat16") else torch.float32
        top_k = config.num_experts_per_tok
        self.hidden_size = hidden_size
        self.intermediate = intermediate_size
        self.top_k = top_k
        self.router = original.router
        self.shared_expert = getattr(original, "shared_expert", None)

        # Number of experts
        E = original.experts.gate_up_proj.shape[0]
        self.num_experts = E

        # Build per-expert MLPs
        self.experts = nn.ModuleList()
        target_device = next(original.experts.parameters()).device
        with skip_weights_initialize(), torch.device(target_device):
            for _ in range(E):
                self.experts.append(GPTOssSingleExpert(hidden_size, intermediate_size, dtype=dtype))

        gup = original.experts.gate_up_proj  # [E, H, 2I]
        gup_b = original.experts.gate_up_proj_bias  # [E, 2I]
        dwn = original.experts.down_proj  # [E, I, H]
        dwn_b = original.experts.down_proj_bias  # [E, H]

        for i, mlp in enumerate(self.experts):
            _update_parameter(mlp.gate_proj, "weight", original.experts.gate_up_proj[i, :, ::2].T)
            _update_parameter(mlp.up_proj, "weight", original.experts.gate_up_proj[i, :, 1::2].T)
            _update_parameter(mlp.down_proj, "weight", original.experts.down_proj[i].T)

            _update_parameter(mlp.gate_proj, "bias", original.experts.gate_up_proj_bias[i, ::2])
            _update_parameter(mlp.up_proj, "bias", original.experts.gate_up_proj_bias[i, 1::2])
            _update_parameter(mlp.down_proj, "bias", original.experts.down_proj_bias[i])  # [H]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        B, T, H = hidden_states.shape
        x = hidden_states.reshape(-1, H)

        # Use the original router (it returns scores and indices already softmaxed over top-k)
        router_scores, router_indices = self.router(x)  # scores: [tokens, E], indices: [tokens, k]

        out = self.shared_expert(x) if self.shared_expert is not None else torch.zeros_like(x)

        # Accumulate expert outputs for chosen experts only
        for j in range(self.top_k):
            idx = router_indices[:, j]
            w = router_scores[torch.arange(idx.size(0), device=idx.device), idx].unsqueeze(-1)
            unique_experts = torch.unique(idx)
            for e in unique_experts:
                mask = idx == e
                out[mask] += self.experts[e](x[mask]) * w[mask]

        out = out.view(B, T, H)
        router_scores = router_scores.view(B * T, -1)  # shape doesn't matter much; it's ignored by the decoder
        return out, router_scores


def get_replacement_info(config):
    return (
        SequentialGPTOSSMoE,
        config.get_text_config(),
        GptOssMLP.__name__,
    )
```
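The weight copies above rely on the fused `gate_up_proj` tensor interleaving gate and up columns (gate at `::2`, up at `1::2`) rather than stacking them as contiguous halves. As a standalone sanity check, not part of this PR and using random tensors instead of real GPT-OSS weights, the following sketch shows why the strided slices plus transposes reproduce the fused projection in separate `nn.Linear` layers:

```python
# Standalone sketch (random tensors, not real GPT-OSS weights): de-interleave a fused
# [H, 2I] gate_up weight into separate gate/up nn.Linear layers and verify the result.
import torch
from torch import nn

hidden, inter = 8, 16
fused_w = torch.randn(hidden, 2 * inter)  # [H, 2I], columns interleaved: gate, up, gate, up, ...
fused_b = torch.randn(2 * inter)

gate_proj = nn.Linear(hidden, inter, bias=True)
up_proj = nn.Linear(hidden, inter, bias=True)
gate_proj.weight.data.copy_(fused_w[:, ::2].T)  # nn.Linear stores [out_features, in_features]
gate_proj.bias.data.copy_(fused_b[::2])
up_proj.weight.data.copy_(fused_w[:, 1::2].T)
up_proj.bias.data.copy_(fused_b[1::2])

x = torch.randn(4, hidden)
fused_out = x @ fused_w + fused_b  # [4, 2I]
assert torch.allclose(gate_proj(x), fused_out[:, ::2], atol=1e-5)
assert torch.allclose(up_proj(x), fused_out[:, 1::2], atol=1e-5)
```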
@@ -0,0 +1,77 @@

```python
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Note: adapted from # https://github.com/vllm-project/llm-compressor/blob/main/src/llmcompressor/modeling/llama4.py

__all__ = ["get_replacement_info"]


import torch
from transformers.modeling_utils import no_init_weights
from transformers.models.llama4.modeling_llama4 import Llama4TextMLP


class SequentialLlama4TextExperts(torch.nn.ModuleList):
    def __init__(self, config, original):
        self.num_experts = original.gate_up_proj.shape[0]
        with no_init_weights():
            super().__init__([Llama4TextMLP(config) for _ in range(self.num_experts)])
        intermediate_size = original.down_proj.shape[1]

        for i in range(self.num_experts):
            gate_up = original.gate_up_proj[i]
            down = original.down_proj[i]
            gate_proj = gate_up[:, :intermediate_size]
            up_proj = gate_up[:, intermediate_size:]

            self[i].gate_proj.weight.data = gate_proj.t().contiguous()
            self[i].up_proj.weight.data = up_proj.t().contiguous()
            self[i].down_proj.weight.data = down.t().contiguous()


class SequentialLlama4TextMoe(torch.nn.Module):
    def __init__(self, config, original):
        super().__init__()
        self.top_k = config.num_experts_per_tok
        self.hidden_dim = config.hidden_size
        self.num_experts = config.num_local_experts
        self.experts = SequentialLlama4TextExperts(config, original.experts)
        self.router = original.router
        self.shared_expert = original.shared_expert

    def forward(self, hidden_states: torch.Tensor):
        hidden_states = hidden_states.reshape(-1, self.hidden_dim)
        router_logits = self.router(hidden_states)
        if isinstance(router_logits, tuple):
            router_scores, router_logits = router_logits
            router_scores = router_scores.t()
        else:
            # transformers < 4.54.0 only returns router_logits
            router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)

            router_scores = (
                torch.full_like(router_logits, float("-inf"))
                .scatter_(1, router_indices, router_top_value)
                .transpose(0, 1)
            )
            router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)

        out = self.shared_expert(hidden_states)
        for i in range(self.num_experts):
            out += self.experts[i](hidden_states) * router_scores[i].reshape(-1, 1)

        return out, router_logits


def get_replacement_info(config):
    return SequentialLlama4TextMoe, config.get_text_config(), "Llama4TextMoe"
```
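Unlike the interleaved GPT-OSS layout handled earlier, the Llama-4 fused `gate_up_proj` is split here into contiguous halves (`[:, :intermediate_size]` for gate, `[:, intermediate_size:]` for up). A quick standalone check on random tensors, analogous to the GPT-OSS sketch above:

```python
# Standalone check (random tensors): a Llama-4-style fused gate_up weight splits into
# contiguous halves, gate first and up second.
import torch

hidden, inter = 8, 16
gate_up = torch.randn(hidden, 2 * inter)  # [H, 2I]
gate_w = gate_up[:, :inter].t().contiguous()  # -> [I, H], the nn.Linear weight layout
up_w = gate_up[:, inter:].t().contiguous()

x = torch.randn(4, hidden)
fused_out = x @ gate_up
assert torch.allclose(x @ gate_w.t(), fused_out[:, :inter], atol=1e-5)
assert torch.allclose(x @ up_w.t(), fused_out[:, inter:], atol=1e-5)
```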
```diff
@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from auto_round.utils import logger
+import auto_round.modelling as auto_round_modelling
+from auto_round.utils import LazyImport, logger
 
 mllms_with_limited_bs = ("llava", "qwen2_vl", "phi3_v", "mllama")  # Limitations on batch_size
 
```

```diff
@@ -36,71 +37,24 @@
 }
 SPECIAL_SHARED_CACHE_KEYS["MiniMaxText01ForCausalLM"] = ("slope_rate",)
 
-CONVERT_EXPERT_TO_LINEAR_MODELS = ["llama4"]
+CONVERT_EXPERT_TO_LINEAR_MODELS = ["llama4", "gpt_oss"]
```
**Contributor:** It would be better not to categorize it into too many detailed types. A single flag like `model_need_to_convert`, or a similar name, should be sufficient, since some models may require conversion even if they don't have expert layers. We provide a converter function for each model if needed, regardless of which parts need to be converted.

**Contributor (author):** Yeah, I agree, the replacement code could be organized better. Once we support more model replacements, we can refactor that part as needed. For now, how about leaving it as is, since we have some higher-priority tasks to focus on?

**Contributor:** I don't think it will take much effort to change. You could also finish the higher-priority tasks first.

**Contributor (author):** Opened an issue to track #899
```diff
 
 
 def _get_moe_converter(config):
-    import torch
-    from transformers.modeling_utils import no_init_weights
-
-    # https://github.com/vllm-project/llm-compressor/blob/main/src/llmcompressor/modeling/llama4.py
-    if config.model_type == "llama4":
-        from transformers.models.llama4.modeling_llama4 import Llama4TextMLP
-
-        class SequentialLlama4TextExperts(torch.nn.ModuleList):
-            def __init__(self, config, original):
-                self.num_experts = original.gate_up_proj.shape[0]
-                with no_init_weights():
-                    super().__init__([Llama4TextMLP(config) for _ in range(self.num_experts)])
-                intermediate_size = original.down_proj.shape[1]
-
-                for i in range(self.num_experts):
-                    gate_up = original.gate_up_proj[i]
-                    down = original.down_proj[i]
-                    gate_proj = gate_up[:, :intermediate_size]
-                    up_proj = gate_up[:, intermediate_size:]
-
-                    self[i].gate_proj.weight.data = gate_proj.t().contiguous()
-                    self[i].up_proj.weight.data = up_proj.t().contiguous()
-                    self[i].down_proj.weight.data = down.t().contiguous()
-
-        class SequentialLlama4TextMoe(torch.nn.Module):
-            def __init__(self, config, original):
-                super().__init__()
-                self.top_k = config.num_experts_per_tok
-                self.hidden_dim = config.hidden_size
-                self.num_experts = config.num_local_experts
-                self.experts = SequentialLlama4TextExperts(config, original.experts)
-                self.router = original.router
-                self.shared_expert = original.shared_expert
-
-            def forward(self, hidden_states: torch.Tensor):
-                hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-                router_logits = self.router(hidden_states)
-                if isinstance(router_logits, tuple):
-                    router_scores, router_logits = router_logits
-                    router_scores = router_scores.t()
-                else:
-                    # transformers < 4.54.0 only returns router_logits
-                    router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)
-
-                    router_scores = (
-                        torch.full_like(router_logits, float("-inf"))
-                        .scatter_(1, router_indices, router_top_value)
-                        .transpose(0, 1)
-                    )
-                    router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)
-
-                out = self.shared_expert(hidden_states)
-                for i in range(self.num_experts):
-                    out += self.experts[i](hidden_states) * router_scores[i].reshape(-1, 1)
-
-                return out, router_logits
-
-        return SequentialLlama4TextMoe, config.get_text_config(), "Llama4TextMoe"
-
-    raise ValueError(f"Currently moe converter only supports llama4 model_type, but get {config.model_type}")
+    # Dispatch table for model_type to replacement_info functions
+    moe_converters = {
+        "gpt_oss": LazyImport("auto_round.modelling.gpt_oss.get_replacement_info"),
+        "llama4": LazyImport("auto_round.modelling.llama4.get_replacement_info"),
+    }
+
+    # Retrieve the appropriate function based on model_type
+    if config.model_type in moe_converters:
+        return moe_converters[config.model_type](config)
+    else:
+        raise ValueError(
+            f"Unsupported model_type '{config.model_type}'. "
+            f"Currently, MoE converter only supports: {', '.join(moe_converters.keys())}."
+        )
 
 
 def _handle_special_model(model):
```
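For context, `_get_moe_converter` (and the per-model `get_replacement_info` functions it dispatches to) returns a `(replacement_class, text_config, target_class_name)` tuple. The code that consumes that tuple is not shown in this diff, so the sketch below is only an illustration of how such a swap could look; the helper name `convert_moe_layers` is made up here and is not part of auto_round's API.

```python
# Illustrative only: the PR does not show auto_round's replacement walker, so this sketch
# assumes a simple "replace every submodule whose class name matches target_name" pass.
import torch


def convert_moe_layers(model: torch.nn.Module, replacement_cls, text_config, target_name: str):
    # Collect matches first, then swap, to avoid mutating modules while walking them.
    targets = [
        (parent, child_name, child)
        for parent in model.modules()
        for child_name, child in parent.named_children()
        if child.__class__.__name__ == target_name
    ]
    for parent, child_name, child in targets:
        # e.g. SequentialGPTOSSMoE(text_config, original=<GptOssMLP instance>)
        setattr(parent, child_name, replacement_cls(text_config, child))
    return model
```

With the dispatch table above, this hypothetical helper would be invoked roughly as `convert_moe_layers(model, *_get_moe_converter(model.config))`.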
@@ -0,0 +1,72 @@

```python
import pytest
from transformers import AutoConfig, AutoTokenizer
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM

from auto_round import AutoRound


@pytest.fixture
def setup_gpt_oss():
    """Fixture to set up the GPT-OSS model and tokenizer."""
    model_name = "/tf_dataset/auto_round/models/unsloth/gpt-oss-20b-BF16"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    config.num_hidden_layers = 1  # Reduce layers for testing
    model = GptOssForCausalLM(config)
    output_dir = "/tmp/test_quantized_gpt_oss"
    return model, tokenizer, output_dir, config


def quantize_model(model, tokenizer, output_dir, scheme, iters=0):
    """Helper function to quantize the model with the given scheme."""
    autoround = AutoRound(
        model,
        tokenizer,
        scheme=scheme,
        nsamples=2,
        iters=iters,
        fp_layers="self_attn,router,lm_head,mlp.gate",
    )
    quantized_model, save_folder = autoround.quantize_and_save(format="auto_round", output_dir=output_dir)
    return quantized_model


def count_modules_by_type(model, target_module_name_or_class):
    """Helper function to count modules of a specific type in the model."""
    cnt = 0
    for name, module in model.named_modules():
        if isinstance(target_module_name_or_class, str):
            if target_module_name_or_class == module.__class__.__name__:
                cnt += 1
        else:
            if isinstance(module, target_module_name_or_class):
                cnt += 1
    return cnt


@pytest.mark.parametrize("scheme", ["MXFP4", "MXFP8"])
def test_quantization(setup_gpt_oss, scheme):
    """Test quantization with the scheme."""
    model, tokenizer, output_dir, config = setup_gpt_oss
    quantized_model = quantize_model(model, tokenizer, output_dir, scheme)

    # Ensure the quantized model is not None
    assert quantized_model is not None, "Quantized model should not be None."
    from auto_round.export.export_to_autoround.qlinear_fp import QuantLinear
    from auto_round.modelling.gpt_oss import GPTOssSingleExpert

    single_expert_cnt = count_modules_by_type(quantized_model, GPTOssSingleExpert)
    quant_linear_cnt = count_modules_by_type(quantized_model, QuantLinear)
    assert (
        single_expert_cnt == config.num_local_experts
    ), f"Expected {config.num_local_experts} GPTOssSingleExpert modules, found {single_expert_cnt}."
    assert (
        quant_linear_cnt == config.num_hidden_layers * 3 * config.num_local_experts
    ), f"Expected {config.num_hidden_layers * 3 * config.num_local_experts} QuantLinear modules, found {quant_linear_cnt}."

    print(f"[{scheme}] Total {GPTOssSingleExpert.__name__} modules: {single_expert_cnt}")
    print(f"[{scheme}] Total {QuantLinear.__name__} modules: {quant_linear_cnt}")
    # clean the output directory after test
    import shutil

    shutil.rmtree(output_dir, ignore_errors=True)
```
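For anyone trying the new path outside the test suite, a minimal usage sketch follows. It mirrors the arguments used in the test above; the checkpoint id and output directory are placeholders rather than paths from this PR.

```python
# Minimal usage sketch; checkpoint id and output directory are placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from auto_round import AutoRound

model_name = "unsloth/gpt-oss-20b-BF16"  # placeholder BF16 GPT-OSS checkpoint
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name)

autoround = AutoRound(
    model,
    tokenizer,
    scheme="MXFP4",
    # keep attention, router, lm_head, and the MoE gate in the original precision, as in the test
    fp_layers="self_attn,router,lm_head,mlp.gate",
)
autoround.quantize_and_save(format="auto_round", output_dir="./gpt-oss-20b-mxfp4")
```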
There was an error while loading. Please reload this page.