
Commit d085a44

Enable PTPC FP8 for CompressedTensorsW8A8Fp8MoEMethod (triton fused_moe) (#16537)
Signed-off-by: mgoin <mgoin64@gmail.com>
1 parent f49e5af commit d085a44
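For context, PTPC FP8 here means per-token scales for activations (computed dynamically at runtime) and per-channel scales for weights, in contrast to the single per-tensor scales this MoE method previously required. A minimal sketch of the two granularities in plain PyTorch; the helper names and the epsilon clamp are illustrative, not vLLM internals:

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3

def quant_weight_per_channel(w: torch.Tensor):
    # One scale per output channel (row of w): shape [out_features, 1].
    scale = (w.abs().amax(dim=1, keepdim=True).float() / FP8_MAX).clamp(min=1e-12)
    return (w / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn), scale

def quant_activation_per_token(x: torch.Tensor):
    # One scale per token (row of x), recomputed on every forward pass.
    scale = (x.abs().amax(dim=1, keepdim=True).float() / FP8_MAX).clamp(min=1e-12)
    return (x / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn), scale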

File tree: 1 file changed (+68, -39 lines)

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 68 additions & 39 deletions
@@ -88,14 +88,23 @@ def __init__(
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
             "input_activations")
 
-        if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR
-                and self.input_quant.strategy == QuantizationStrategy.TENSOR):
+        per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR
+                      and self.input_quant.strategy
+                      == QuantizationStrategy.TENSOR)
+        per_channel = (
+            self.weight_quant.strategy == QuantizationStrategy.CHANNEL
+            and self.input_quant.strategy == QuantizationStrategy.TOKEN)
+        if not (per_tensor or per_channel):
             raise ValueError(
-                "For FP8 Fused MoE layers, only per-tensor scales "
-                "for weights and activations are supported. Found "
+                "For FP8 Fused MoE layers, we require per tensor "
+                "or channelwise, dynamic per token quantization. Found "
                 f"{self.weight_quant}, {self.input_quant}")
 
         self.static_input_scales = not self.input_quant.dynamic
+        if self.static_input_scales and per_channel:
+            raise ValueError(
+                "For FP8 Fused MoE layer, we require either per tensor or "
+                "channelwise, dynamic per token quantization.")
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
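The rewritten check accepts exactly two scheme combinations: per-tensor weights with per-tensor activations, or channelwise weights with dynamic per-token activations (static input scales are rejected on the channelwise path). A standalone restatement of that rule, using a stand-in enum rather than compressed-tensors' QuantizationStrategy:

from enum import Enum

class Strategy(Enum):  # stand-in for compressed-tensors' QuantizationStrategy
    TENSOR = "tensor"
    CHANNEL = "channel"
    TOKEN = "token"

def is_supported_fp8_moe_scheme(weight: Strategy, act: Strategy,
                                act_dynamic: bool) -> bool:
    per_tensor = weight is Strategy.TENSOR and act is Strategy.TENSOR
    per_channel = weight is Strategy.CHANNEL and act is Strategy.TOKEN
    if not (per_tensor or per_channel):
        return False
    # Static activation scales only make sense on the per-tensor path.
    return not (per_channel and not act_dynamic)

assert is_supported_fp8_moe_scheme(Strategy.TENSOR, Strategy.TENSOR, False)
assert is_supported_fp8_moe_scheme(Strategy.CHANNEL, Strategy.TOKEN, True)
assert not is_supported_fp8_moe_scheme(Strategy.CHANNEL, Strategy.TOKEN, False)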
@@ -123,24 +132,40 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
         # WEIGHT_SCALES
-        # Allocate 2 scales for w1 and w3 respectively.
-        # They will be combined to a single scale after weight loading.
-        w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                         2,
-                                                         dtype=torch.float32),
-                                              requires_grad=False)
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They are combined to a single scale after weight loading.
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, 2, dtype=torch.float32),
+                                                  requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add PER-TENSOR quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
-        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                        dtype=torch.float32),
-                                             requires_grad=False)
-        layer.register_parameter("w2_weight_scale", w2_weight_scale)
-        # Add the quantization method used (per tensor/grouped/channel)
-        # to ensure the weight scales are loaded in properly
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
-        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+        elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                1,
+                dtype=torch.float32),
+                                                  requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, hidden_size, 1, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
         # INPUT_SCALES
         if self.static_input_scales:
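On the channelwise path, create_weights allocates one scale per output channel per expert instead of one or two scalars per expert. A quick shape check with illustrative sizes (E experts, hidden size H, per-partition intermediate size INTER):

import torch

E, H, INTER = 8, 1024, 4096  # illustrative sizes

# Per-tensor path: two scalars per expert for w1/w3 (merged later), one for w2.
w13_scale_tensor = torch.ones(E, 2, dtype=torch.float32)
w2_scale_tensor = torch.ones(E, dtype=torch.float32)

# Channelwise path: one scale per output channel per expert (w1 and w3 stacked).
w13_scale_channel = torch.ones(E, 2 * INTER, 1, dtype=torch.float32)
w2_scale_channel = torch.ones(E, H, 1, dtype=torch.float32)

print(w13_scale_channel.shape)  # torch.Size([8, 8192, 1])
print(w2_scale_channel.shape)   # torch.Size([8, 1024, 1])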
@@ -163,6 +188,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # Fp8 moe kernels require a single activation scale.
         # We take the max of all the scales in case they differ.
         if self.static_input_scales:
+            assert self.input_quant.strategy == QuantizationStrategy.TENSOR
             if (layer.w13_input_scale is None or layer.w2_input_scale is None):
                 raise ValueError(
                     "QuantConfig has static quantization, but found "
@@ -204,24 +230,25 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.w2_input_scale = torch.nn.Parameter(w2_input_scale,
                                                   requires_grad=False)
 
-        # Fp8 moe kernel needs single weight scale for w13 per expert.
-        # We take the max then dequant and requant each expert.
-        assert layer.w13_weight_scale is not None
-        shard_size = layer.intermediate_size_per_partition
-        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-        for expert_id in range(layer.local_num_experts):
-            start = 0
-            for shard_id in range(2):
-                dq_weight = per_tensor_dequantize(
-                    layer.w13_weight[expert_id][start:start + shard_size, :],
-                    layer.w13_weight_scale[expert_id][shard_id])
-                layer.w13_weight[expert_id][
-                    start:start + shard_size, :], _ = ops.scaled_fp8_quant(
-                        dq_weight, max_w13_scales[expert_id])
-                start += shard_size
-
-        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
-                                                    requires_grad=False)
+        # For Per-TENSOR case, Fp8 moe kernel needs single weight scale
+        # for w13 per expert. Use max then dequant and requant each expert.
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            assert layer.w13_weight_scale is not None
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+            for expert_id in range(layer.local_num_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        layer.w13_weight[expert_id][start:start +
+                                                    shard_size, :],
+                        layer.w13_weight_scale[expert_id][shard_id])
+                    layer.w13_weight[expert_id][
+                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
+                            dq_weight, max_w13_scales[expert_id])
+                    start += shard_size
+            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
+                                                        requires_grad=False)
 
     def apply(
         self,
265292
activation=activation,
266293
apply_router_weight_on_input=apply_router_weight_on_input,
267294
use_fp8_w8a8=True,
295+
per_channel_quant=self.weight_quant.strategy ==
296+
QuantizationStrategy.CHANNEL,
268297
global_num_experts=global_num_experts,
269298
expert_map=expert_map,
270299
w1_scale=layer.w13_weight_scale,
