13 changes: 13 additions & 0 deletions auto_round/modelling/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
127 changes: 127 additions & 0 deletions auto_round/modelling/gpt_oss.py
@@ -0,0 +1,127 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
from torch import nn
from transformers.modeling_utils import no_init_weights as skip_weights_initialize
from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssMLP

__all__ = ["get_replacement_info"]


def _update_parameter(
module: torch.nn.Module,
name: str,
data: torch.Tensor,
) -> None:
param = getattr(module, name)
param.data.copy_(data)


class GPTOssSingleExpert(nn.Module):
def __init__(self, hidden_size: int, intermediate_size: int, dtype: torch.dtype | None = None):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.alpha = 1.702
self.limit = 7.0
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=True, dtype=dtype)
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=True, dtype=dtype)
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=True, dtype=dtype)

def forward(self, x: torch.Tensor) -> torch.Tensor:
gate = self.gate_proj(x)
up = self.up_proj(x)
gate = gate.clamp(max=self.limit)
up = up.clamp(min=-self.limit, max=self.limit)
glu = gate * torch.sigmoid(gate * self.alpha)
act = (up + 1) * glu
return self.down_proj(act)


class SequentialGPTOSSMoE(nn.Module):
"""
Replaces GPT-OSS fused-expert MoE with per-expert `GPTOssSingleExpert` modules.
Copies weights from fused tensors and reuses the original router and optional shared_expert.
"""

def __init__(self, config: GptOssConfig, original: GptOssMLP):
super().__init__()
hidden_size = config.hidden_size
intermediate_size = config.intermediate_size
dtype_str = getattr(config, "torch_dtype", None) or getattr(config, "dtype", None)
dtype = torch.bfloat16 if str(dtype_str).endswith("bfloat16") else torch.float32
top_k = config.num_experts_per_tok
self.hidden_size = hidden_size
self.intermediate = intermediate_size
self.top_k = top_k
self.router = original.router
self.shared_expert = getattr(original, "shared_expert", None)

# Number of experts
E = original.experts.gate_up_proj.shape[0]
self.num_experts = E

# Build per-expert MLPs
self.experts = nn.ModuleList()
target_device = next(original.experts.parameters()).device
with skip_weights_initialize(), torch.device(target_device):
for _ in range(E):
self.experts.append(GPTOssSingleExpert(hidden_size, intermediate_size, dtype=dtype))

gup = original.experts.gate_up_proj # [E, H, 2I]
gup_b = original.experts.gate_up_proj_bias # [E, 2I]
dwn = original.experts.down_proj # [E, I, H]
dwn_b = original.experts.down_proj_bias # [E, H]

for i, mlp in enumerate(self.experts):
            _update_parameter(mlp.gate_proj, "weight", gup[i, :, ::2].T)
            _update_parameter(mlp.up_proj, "weight", gup[i, :, 1::2].T)
Comment on lines +91 to +92
Copilot AI (Oct 14, 2025):
The magic numbers ::2 and 1::2 for tensor slicing should be replaced with named constants like GATE_STRIDE = 2 and GATE_OFFSET = 0, UP_OFFSET = 1 to improve code readability and maintainability.
            _update_parameter(mlp.down_proj, "weight", dwn[i].T)

            _update_parameter(mlp.gate_proj, "bias", gup_b[i, ::2])
            _update_parameter(mlp.up_proj, "bias", gup_b[i, 1::2])
            _update_parameter(mlp.down_proj, "bias", dwn_b[i])  # [H]

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
B, T, H = hidden_states.shape
x = hidden_states.reshape(-1, H)

# Use the original router (it returns scores and indices already softmaxed over top-k)
router_scores, router_indices = self.router(x) # scores: [tokens, E], indices: [tokens, k]

out = self.shared_expert(x) if self.shared_expert is not None else torch.zeros_like(x)

# Accumulate expert outputs for chosen experts only
for j in range(self.top_k):
idx = router_indices[:, j]
w = router_scores[torch.arange(idx.size(0), device=idx.device), idx].unsqueeze(-1)
unique_experts = torch.unique(idx)
for e in unique_experts:
mask = idx == e
out[mask] += self.experts[e](x[mask]) * w[mask]

out = out.view(B, T, H)
router_scores = router_scores.view(B * T, -1) # shape doesn't matter much; it’s ignored by the decoder
return out, router_scores


def get_replacement_info(config):
return (
SequentialGPTOSSMoE,
config.get_text_config(),
GptOssMLP.__name__,
)
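A note on the ::2 / 1::2 slicing flagged by the review comment above: the per-expert copy assumes the fused gate_up_proj stores gate and up columns interleaved along the last dimension (even columns are gate, odd columns are up). The minimal sketch below, with made-up shapes, checks that de-interleaving a fused projection this way reproduces separate gate/up linear layers; it illustrates the assumed layout and is not part of the PR:

import torch

E, H, I = 2, 8, 16                        # hypothetical experts / hidden size / intermediate size
gate_up_proj = torch.randn(E, H, 2 * I)   # fused weights, assumed interleaved layout
gate_up_bias = torch.randn(E, 2 * I)
x = torch.randn(4, H)                     # a few token activations
e = 0                                     # pick one expert

fused = x @ gate_up_proj[e] + gate_up_bias[e]          # fused projection: [tokens, 2I]

gate_w, up_w = gate_up_proj[e, :, ::2].T, gate_up_proj[e, :, 1::2].T   # [I, H] each
gate_b, up_b = gate_up_bias[e, ::2], gate_up_bias[e, 1::2]             # [I] each
gate = torch.nn.functional.linear(x, gate_w, gate_b)
up = torch.nn.functional.linear(x, up_w, up_b)

# De-interleaved linears match the even/odd columns of the fused output.
assert torch.allclose(gate, fused[:, ::2], atol=1e-5)
assert torch.allclose(up, fused[:, 1::2], atol=1e-5)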
77 changes: 77 additions & 0 deletions auto_round/modelling/llama4.py
@@ -0,0 +1,77 @@
# Copyright (c) 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Note: adapted from # https://github.com/vllm-project/llm-compressor/blob/main/src/llmcompressor/modeling/llama4.py

__all__ = ["get_replacement_info"]


import torch
from transformers.modeling_utils import no_init_weights
from transformers.models.llama4.modeling_llama4 import Llama4TextMLP


class SequentialLlama4TextExperts(torch.nn.ModuleList):
def __init__(self, config, original):
self.num_experts = original.gate_up_proj.shape[0]
with no_init_weights():
super().__init__([Llama4TextMLP(config) for _ in range(self.num_experts)])
intermediate_size = original.down_proj.shape[1]

for i in range(self.num_experts):
gate_up = original.gate_up_proj[i]
down = original.down_proj[i]
gate_proj = gate_up[:, :intermediate_size]
up_proj = gate_up[:, intermediate_size:]

self[i].gate_proj.weight.data = gate_proj.t().contiguous()
self[i].up_proj.weight.data = up_proj.t().contiguous()
self[i].down_proj.weight.data = down.t().contiguous()


class SequentialLlama4TextMoe(torch.nn.Module):
def __init__(self, config, original):
super().__init__()
self.top_k = config.num_experts_per_tok
self.hidden_dim = config.hidden_size
self.num_experts = config.num_local_experts
self.experts = SequentialLlama4TextExperts(config, original.experts)
self.router = original.router
self.shared_expert = original.shared_expert

def forward(self, hidden_states: torch.Tensor):
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
router_logits = self.router(hidden_states)
if isinstance(router_logits, tuple):
router_scores, router_logits = router_logits
router_scores = router_scores.t()
else:
# transformers < 4.54.0 only returns router_logits
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)

router_scores = (
torch.full_like(router_logits, float("-inf"))
.scatter_(1, router_indices, router_top_value)
.transpose(0, 1)
)
router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)

out = self.shared_expert(hidden_states)
for i in range(self.num_experts):
out += self.experts[i](hidden_states) * router_scores[i].reshape(-1, 1)

return out, router_logits


def get_replacement_info(config):
return SequentialLlama4TextMoe, config.get_text_config(), "Llama4TextMoe"
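For the transformers < 4.54.0 branch above, the routing weights are rebuilt from the raw router logits: keep each token's top-k logits, fill the rest with -inf, transpose to [num_experts, tokens], and apply a sigmoid so non-selected experts get exactly zero weight. A minimal sketch with hypothetical sizes:

import torch

tokens, num_experts, top_k = 4, 8, 2      # hypothetical sizes
router_logits = torch.randn(tokens, num_experts)

top_val, top_idx = torch.topk(router_logits, top_k, dim=1)
router_scores = (
    torch.full_like(router_logits, float("-inf"))
    .scatter_(1, top_idx, top_val)
    .transpose(0, 1)                      # [num_experts, tokens]
)
router_scores = torch.sigmoid(router_scores.float())

# sigmoid(-inf) == 0, so the dense loop over all experts in forward()
# still behaves as a sparse top-k mixture.
assert int((router_scores > 0).sum()) == tokens * top_k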
78 changes: 16 additions & 62 deletions auto_round/special_model_handler.py
@@ -11,7 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_round.utils import logger
import auto_round.modelling as auto_round_modelling
from auto_round.utils import LazyImport, logger

mllms_with_limited_bs = ("llava", "qwen2_vl", "phi3_v", "mllama") # Limitations on batch_size

@@ -36,71 +37,24 @@
}
SPECIAL_SHARED_CACHE_KEYS["MiniMaxText01ForCausalLM"] = ("slope_rate",)

CONVERT_EXPERT_TO_LINEAR_MODELS = ["llama4"]
CONVERT_EXPERT_TO_LINEAR_MODELS = ["llama4", "gpt_oss"]
Contributor: It would be better not to categorize it into too many detailed types. A single flag like model_need_to_convert, or a similar name, should be sufficient, since some models may require conversion even if they don't have expert layers. We provide a converter function for each model if needed, regardless of which parts need to be converted.

Contributor (Author): Yeah, I agree, the replacement code could be organized better. Once we support more model replacements, we can refactor that part as needed. For now, how about leaving it as is, since we have some higher-priority tasks to focus on?

Contributor: I don't think it will take much effort to change. You could also finish the higher-priority tasks first.

Contributor (Author): Opened an issue to track #899.



def _get_moe_converter(config):
import torch
from transformers.modeling_utils import no_init_weights

# https://github.com/vllm-project/llm-compressor/blob/main/src/llmcompressor/modeling/llama4.py
if config.model_type == "llama4":
from transformers.models.llama4.modeling_llama4 import Llama4TextMLP

class SequentialLlama4TextExperts(torch.nn.ModuleList):
def __init__(self, config, original):
self.num_experts = original.gate_up_proj.shape[0]
with no_init_weights():
super().__init__([Llama4TextMLP(config) for _ in range(self.num_experts)])
intermediate_size = original.down_proj.shape[1]

for i in range(self.num_experts):
gate_up = original.gate_up_proj[i]
down = original.down_proj[i]
gate_proj = gate_up[:, :intermediate_size]
up_proj = gate_up[:, intermediate_size:]

self[i].gate_proj.weight.data = gate_proj.t().contiguous()
self[i].up_proj.weight.data = up_proj.t().contiguous()
self[i].down_proj.weight.data = down.t().contiguous()

class SequentialLlama4TextMoe(torch.nn.Module):
def __init__(self, config, original):
super().__init__()
self.top_k = config.num_experts_per_tok
self.hidden_dim = config.hidden_size
self.num_experts = config.num_local_experts
self.experts = SequentialLlama4TextExperts(config, original.experts)
self.router = original.router
self.shared_expert = original.shared_expert

def forward(self, hidden_states: torch.Tensor):
hidden_states = hidden_states.reshape(-1, self.hidden_dim)
router_logits = self.router(hidden_states)
if isinstance(router_logits, tuple):
router_scores, router_logits = router_logits
router_scores = router_scores.t()
else:
# transformers < 4.54.0 only returns router_logits
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1)

router_scores = (
torch.full_like(router_logits, float("-inf"))
.scatter_(1, router_indices, router_top_value)
.transpose(0, 1)
)
router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype)

out = self.shared_expert(hidden_states)
for i in range(self.num_experts):
out += self.experts[i](hidden_states) * router_scores[i].reshape(-1, 1)

return out, router_logits

return SequentialLlama4TextMoe, config.get_text_config(), "Llama4TextMoe"

# Dispatch table for model_type to replacement_info functions
moe_converters = {
"gpt_oss": LazyImport("auto_round.modelling.gpt_oss.get_replacement_info"),
"llama4": LazyImport("auto_round.modelling.llama4.get_replacement_info"),
}

# Retrieve the appropriate function based on model_type
if config.model_type in moe_converters:
return moe_converters[config.model_type](config)
else:
raise ValueError(f"Currently moe converter only supports llama4 model_type, but get {config.model_type}")
raise ValueError(
f"Unsupported model_type '{config.model_type}'. "
f"Currently, MoE converter only supports: {', '.join(moe_converters.keys())}."
)


def _handle_special_model(model):
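The dispatch table in _get_moe_converter returns a (replacement_cls, text_config, original_class_name) tuple; the code that consumes it is not part of this diff. As a rough sketch of how such a tuple could be applied — an illustrative assumption, not auto-round's actual API — a caller might walk the model and swap matching modules in place:

import torch

def apply_replacement_info(model: torch.nn.Module, replacement_info) -> int:
    # Illustrative only: replace every module whose class name matches the
    # third tuple element with the sequential replacement class shown above.
    replacement_cls, text_config, original_cls_name = replacement_info
    replaced = 0
    for parent in model.modules():
        for child_name, child in list(parent.named_children()):
            if child.__class__.__name__ == original_cls_name:
                setattr(parent, child_name, replacement_cls(text_config, child))
                replaced += 1
    return replaced

# Hypothetical usage: apply_replacement_info(model, _get_moe_converter(model.config))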
2 changes: 1 addition & 1 deletion auto_round/utils.py
@@ -1100,7 +1100,7 @@ def get_fp_layer_names(model, fp_layers):
for name in all_layer_names:
if fp_layer in name:
not_to_quantized_layers.append(name)

logger.trace(f"not_to_quantized_layers: {not_to_quantized_layers}")
return not_to_quantized_layers


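get_fp_layer_names (partially shown above) selects higher-precision layers by simple substring matching against layer names, which is why the new test below can pass fp_layers="self_attn,router,lm_head,mlp.gate". A minimal sketch of that matching, assuming the comma-separated patterns are split before the inner loop shown in the diff; the layer names are hypothetical:

fp_layers = "self_attn,router,lm_head,mlp.gate"
all_layer_names = [                      # hypothetical layer names
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.mlp.router",
    "model.layers.0.mlp.experts.0.gate_proj",
    "lm_head",
]

not_to_quantized_layers = []
for fp_layer in (p.strip() for p in fp_layers.split(",") if p.strip()):
    for name in all_layer_names:
        if fp_layer in name:
            not_to_quantized_layers.append(name)

# "mlp.gate" matches nothing here because the expert projection is named
# "...experts.0.gate_proj"; the other three patterns each match one layer.
print(not_to_quantized_layers)
# ['model.layers.0.self_attn.q_proj', 'model.layers.0.mlp.router', 'lm_head']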
72 changes: 72 additions & 0 deletions test/test_cpu/test_gpt_oss.py
@@ -0,0 +1,72 @@
import pytest
from transformers import AutoConfig, AutoTokenizer
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM

from auto_round import AutoRound


@pytest.fixture
def setup_gpt_oss():
"""Fixture to set up the GPT-OSS model and tokenizer."""
model_name = "/tf_dataset/auto_round/models/unsloth/gpt-oss-20b-BF16"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
config.num_hidden_layers = 1 # Reduce layers for testing
model = GptOssForCausalLM(config)
output_dir = "/tmp/test_quantized_gpt_oss"
return model, tokenizer, output_dir, config


def quantize_model(model, tokenizer, output_dir, scheme, iters=0):
"""Helper function to quantize the model with the given scheme."""
autoround = AutoRound(
model,
tokenizer,
scheme=scheme,
nsamples=2,
iters=iters,
fp_layers="self_attn,router,lm_head,mlp.gate",
)
quantized_model, save_folder = autoround.quantize_and_save(format="auto_round", output_dir=output_dir)
return quantized_model


def count_modules_by_type(model, target_module_name_or_class):
"""Helper function to count modules of a specific type in the model."""
cnt = 0
for name, module in model.named_modules():
if isinstance(target_module_name_or_class, str):
if target_module_name_or_class == module.__class__.__name__:
cnt += 1
else:
if isinstance(module, target_module_name_or_class):
cnt += 1
return cnt


@pytest.mark.parametrize("scheme", ["MXFP4", "MXFP8"])
def test_quantization(setup_gpt_oss, scheme):
"""Test quantization with the scheme."""
model, tokenizer, output_dir, config = setup_gpt_oss
quantized_model = quantize_model(model, tokenizer, output_dir, scheme)

# Ensure the quantized model is not None
assert quantized_model is not None, "Quantized model should not be None."
from auto_round.export.export_to_autoround.qlinear_fp import QuantLinear
from auto_round.modelling.gpt_oss import GPTOssSingleExpert

single_expert_cnt = count_modules_by_type(quantized_model, GPTOssSingleExpert)
quant_linear_cnt = count_modules_by_type(quantized_model, QuantLinear)
assert (
single_expert_cnt == config.num_local_experts
), f"Expected {config.num_local_experts} GPTOssSingleExpert modules, found {single_expert_cnt}."
assert (
quant_linear_cnt == config.num_hidden_layers * 3 * config.num_local_experts
), f"Expected {config.num_hidden_layers * 3 * config.num_local_experts} QuantLinear modules, found {quant_linear_cnt}."

print(f"[{scheme}] Total {GPTOssSingleExpert.__name__} modules: {single_expert_cnt}")
print(f"[{scheme}] Total {QuantLinear.__name__} modules: {quant_linear_cnt}")
# clean the output directory after test
import shutil

shutil.rmtree(output_dir, ignore_errors=True)
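As a quick sanity check of count_modules_by_type — which accepts either a class object or a class name string — the following toy snippet, assuming the helper defined in this test module is in scope, exercises both match modes:

import torch.nn as nn

toy = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))  # toy model, illustration only

assert count_modules_by_type(toy, nn.Linear) == 2   # match by class
assert count_modules_by_type(toy, "Linear") == 2    # match by class name string
assert count_modules_by_type(toy, nn.Conv2d) == 0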