Skip to content

Commit ee755a4

Browse files
authored
[Qwen3 Next] Update qwen3_next to use moe_calibration_context (#1984)
# SUMMARY: - Update the qwen3_next_moe definition to use the moe context - Update example to use the correct arguement - Add tests - Update qwen3_moe model definition to implement `restore` # Testing: - All modeling tests pass
1 parent c254c19 commit ee755a4

File tree

6 files changed

+140
-14
lines changed

6 files changed

+140
-14
lines changed

examples/quantization_w4a4_fp4/qwen3_next_example.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,18 +68,22 @@ def tokenize(sample):
6868
)
6969

7070
# Apply quantization.
71-
# We see `calibrate_moe_context` to True to update all `Qwen3MoeSparseMoeBlock`
72-
# during calibration.
71+
# MoE calibration is now handled automatically by the pipeline.
72+
# We set `moe_calibrate_all_experts` to True to ensure all experts receive
73+
# calibration data. This temporarily updates the model definition to use
74+
# `CalibrationQwen3NextSparseMoeBlock` (from `llmcompressor.modeling.qwen3_next_moe`)
75+
# which replaces the original `Qwen3NextSparseMoeBlock` class.
76+
# This updates how the forward pass is handled in the MoE block during calibration.
7377
# Feel free to update the definition under
74-
# llm-compressor/src/llmcompressor/modeling/qwen3_moe.py` to play around with
75-
# this behaviour and evaluate its impact on quantization performance
78+
# llm-compressor/src/llmcompressor/modeling/qwen3_next_moe.py to play around with
79+
# this behavior and evaluate its impact on quantization performance.
7680
oneshot(
7781
model=model,
7882
dataset=ds,
7983
recipe=recipe,
8084
max_seq_length=MAX_SEQUENCE_LENGTH,
8185
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
82-
calibrate_moe_context=True,
86+
moe_calibrate_all_experts=True,
8387
)
8488

8589

src/llmcompressor/modeling/moe_context.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class MoECalibrationModule(ABC, torch.nn.Module):
4545

4646
is_permanent: bool = False
4747

48-
def restore(self) -> torch.nn.Module:
48+
def restore(self, original: torch.nn.Module) -> torch.nn.Module:
4949
"""
5050
Restore the original module structure.
5151
@@ -163,5 +163,5 @@ def moe_calibration_context(
163163
# Step 2: Restore non-permanent modules
164164
for name, (original, replacement) in replaced.items():
165165
if not replacement.is_permanent:
166-
restored = replacement.restore()
166+
restored = replacement.restore(original)
167167
model.set_submodule(name, restored)

src/llmcompressor/modeling/prepare.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929
from llmcompressor.modeling.qwen3_moe import ( # noqa: F401
3030
CalibrationQwen3MoeSparseMoeBlock,
3131
)
32+
from llmcompressor.modeling.qwen3_next_moe import ( # noqa: F401
33+
CalibrationQwen3NextSparseMoeBlock,
34+
)
3235
from llmcompressor.modeling.qwen3_vl_moe import (
3336
replace as replace_Qwen3VLMoE,
3437
)

src/llmcompressor/modeling/qwen3_moe.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,9 @@ def forward(self, hidden_states: torch.Tensor):
9898
)
9999
return final_hidden_states, router_logits
100100

101+
def restore(self, original: torch.nn.Module) -> torch.nn.Module:
102+
return original
103+
101104

102105
# Legacy function for backward compatibility
103106
def replace(

src/llmcompressor/modeling/qwen3_next_moe.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,34 @@
1616

1717
import torch
1818

19+
from llmcompressor.modeling.moe_context import (
20+
MoECalibrationModule,
21+
register_moe_calibration,
22+
)
23+
24+
25+
@register_moe_calibration("Qwen3NextSparseMoeBlock")
26+
class CalibrationQwen3NextSparseMoeBlock(MoECalibrationModule):
27+
from transformers import Qwen3NextConfig
28+
from transformers.models.qwen3_next.modeling_qwen3_next import (
29+
Qwen3NextSparseMoeBlock,
30+
)
31+
32+
"""
33+
Calibration version of Qwen3NextSparseMoeBlock that sends all tokens to all experts.
34+
"""
35+
36+
is_permanent = False
1937

20-
class Qwen3NextSparseMoeBlock(torch.nn.Module):
2138
def __init__(
2239
self,
23-
config,
24-
original,
25-
calibrate_all_experts: bool,
40+
original: Qwen3NextSparseMoeBlock,
41+
config: Qwen3NextConfig,
42+
calibrate_all_experts: bool = True,
2643
):
2744
super().__init__()
2845
self.num_experts = config.num_experts
29-
self.top_k = config.top_k
46+
self.top_k = config.num_experts_per_tok
3047
self.norm_topk_prob = config.norm_topk_prob
3148

3249
# gating
@@ -44,7 +61,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
4461
router_logits = self.gate(hidden_states)
4562

4663
routing_weights = torch.nn.functional.softmax(
47-
router_logits, dim=1, dtype=torch.float
64+
router_logits, dim=-1, dtype=torch.float
4865
)
4966
routing_weights, selected_experts = torch.topk(
5067
routing_weights, self.top_k, dim=-1
@@ -103,12 +120,15 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
103120
)
104121
return final_hidden_states, router_logits
105122

123+
def restore(self, original: torch.nn.Module) -> torch.nn.Module:
124+
return original
125+
106126

107127
def replace(
108128
config,
109129
module,
110130
calibrate_all_experts,
111131
):
112-
return Qwen3NextSparseMoeBlock(
132+
return CalibrationQwen3NextSparseMoeBlock(
113133
config=config, original=module, calibrate_all_experts=calibrate_all_experts
114134
)
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
import contextlib
2+
from functools import partial
3+
4+
import pytest
5+
import torch
6+
from transformers import AutoModelForCausalLM
7+
8+
from llmcompressor.modeling.moe_context import moe_calibration_context
9+
from llmcompressor.modeling.qwen3_next_moe import CalibrationQwen3NextSparseMoeBlock
10+
from llmcompressor.utils.dev import skip_weights_download
11+
from llmcompressor.utils.helpers import DisableQuantization, calibration_forward_context
12+
from tests.testing_utils import requires_cadence, requires_gpu
13+
14+
15+
@requires_cadence("weekly")
16+
@pytest.mark.parametrize("model_stub", ["Qwen/Qwen3-Next-80B-A3B-Instruct"])
17+
def test_calib_replace_qwen3moe_all_experts(model_stub):
18+
with skip_weights_download():
19+
model = AutoModelForCausalLM.from_pretrained(model_stub)
20+
21+
# Qwen3MoE layer replacement is temporary within the context
22+
with contextlib.ExitStack() as stack:
23+
stack.enter_context(calibration_forward_context(model))
24+
stack.enter_context(DisableQuantization(model))
25+
stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True))
26+
27+
# Find one MoE layer
28+
moe_layer = None
29+
for name, module in model.named_modules():
30+
if isinstance(module, CalibrationQwen3NextSparseMoeBlock):
31+
moe_layer = module
32+
break
33+
34+
assert moe_layer is not None
35+
36+
num_experts = len(moe_layer.experts)
37+
expert_triggered = [False for _ in range(num_experts)]
38+
39+
# Define the hook function
40+
def hook_fn(i, module, input, output):
41+
expert_triggered[i] = True
42+
43+
# Attach hooks using functools.partial to bind each index
44+
for i, expert in enumerate(moe_layer.experts):
45+
expert.register_forward_hook(partial(hook_fn, i))
46+
47+
# Create dummy input tensor that simulates hidden_states
48+
hidden_dim = model.config.hidden_size
49+
batch, seq_len = 4, 32
50+
sample = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float32)
51+
52+
# Forward through the MoE layer directly
53+
with torch.no_grad():
54+
_ = moe_layer(sample)
55+
56+
# Assert all experts are used
57+
assert all(
58+
expert_triggered
59+
), f"Not all experts were triggered: {expert_triggered}"
60+
61+
62+
@requires_gpu
63+
def test_calib_qwen3_moe_module():
64+
from transformers import Qwen3NextConfig
65+
from transformers.models.qwen3_next.modeling_qwen3_next import (
66+
Qwen3NextSparseMoeBlock,
67+
)
68+
69+
config = Qwen3NextConfig()
70+
with torch.device("cuda"):
71+
original = Qwen3NextSparseMoeBlock(config).eval()
72+
73+
# Create dummy input tensor that simulates hidden_states
74+
hidden_dim = config.hidden_size
75+
batch, seq_len = 4, 32
76+
sample = torch.randn(batch, seq_len, hidden_dim, device="cuda")
77+
78+
with calibration_forward_context(original):
79+
true_output = original(sample)
80+
81+
module = CalibrationQwen3NextSparseMoeBlock(
82+
original, config, calibrate_all_experts=True
83+
)
84+
85+
with calibration_forward_context(module):
86+
output = module(sample)
87+
assert torch.nn.functional.mse_loss(true_output[0], output[0]) < 1e-10
88+
assert torch.nn.functional.mse_loss(true_output[1], output[1]) < 1e-10
89+
90+
module = CalibrationQwen3NextSparseMoeBlock(
91+
original, config, calibrate_all_experts=False
92+
)
93+
with calibration_forward_context(module):
94+
output = module(sample)
95+
assert torch.nn.functional.mse_loss(true_output[0], output[0]) < 1e-10
96+
assert torch.nn.functional.mse_loss(true_output[1], output[1]) < 1e-10

0 commit comments

Comments
 (0)