@@ -88,14 +88,23 @@ def __init__(
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
             "input_activations")
 
-        if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR
-                and self.input_quant.strategy == QuantizationStrategy.TENSOR):
+        per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR
+                      and self.input_quant.strategy
+                      == QuantizationStrategy.TENSOR)
+        per_channel = (
+            self.weight_quant.strategy == QuantizationStrategy.CHANNEL
+            and self.input_quant.strategy == QuantizationStrategy.TOKEN)
+        if not (per_tensor or per_channel):
             raise ValueError(
-                "For FP8 Fused MoE layers, only per-tensor scales "
-                "for weights and activations are supported. Found "
+                "For FP8 Fused MoE layers, we require per tensor "
+                "or channelwise, dynamic per token quantization. Found "
                 f"{self.weight_quant}, {self.input_quant}")
 
         self.static_input_scales = not self.input_quant.dynamic
+        if self.static_input_scales and per_channel:
+            raise ValueError(
+                "For FP8 Fused MoE layer, we require either per tensor or "
+                "channelwise, dynamic per token quantization.")
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
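
For reference, a minimal standalone sketch of the pairing rule the new checks enforce (a sketch only: `QuantizationStrategy` is the compressed-tensors enum used above, while the helper name and its arguments are illustrative and not part of this change):

```python
from compressed_tensors.quantization import QuantizationStrategy


def check_fp8_moe_scheme(weight_strategy, input_strategy, input_dynamic):
    # Per-tensor weights must pair with per-tensor activations.
    per_tensor = (weight_strategy == QuantizationStrategy.TENSOR
                  and input_strategy == QuantizationStrategy.TENSOR)
    # Channelwise weights must pair with per-token activations.
    per_channel = (weight_strategy == QuantizationStrategy.CHANNEL
                   and input_strategy == QuantizationStrategy.TOKEN)
    if not (per_tensor or per_channel):
        raise ValueError("FP8 Fused MoE supports per-tensor or "
                         "channelwise + dynamic per-token schemes only")
    # Static (pre-computed) activation scales are only valid per-tensor.
    if per_channel and not input_dynamic:
        raise ValueError("Channelwise weights require dynamic "
                         "per-token activation quantization")
    return "tensor" if per_tensor else "channel"
```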
@@ -123,24 +132,40 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
         # WEIGHT_SCALES
-        # Allocate 2 scales for w1 and w3 respectively.
-        # They will be combined to a single scale after weight loading.
-        w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                         2,
-                                                         dtype=torch.float32),
-                                              requires_grad=False)
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They are combined to a single scale after weight loading.
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, 2, dtype=torch.float32),
+                                                  requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add PER-TENSOR quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
-        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                        dtype=torch.float32),
-                                             requires_grad=False)
-        layer.register_parameter("w2_weight_scale", w2_weight_scale)
-        # Add the quantization method used (per tensor/grouped/channel)
-        # to ensure the weight scales are loaded in properly
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
-        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+        elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                1,
+                dtype=torch.float32),
+                                                  requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, hidden_size, 1, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
         # INPUT_SCALES
         if self.static_input_scales:
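
As a quick sanity check on what the two branches register, the scale shapes reduce to the following (example sizes are made up; names mirror the parameters above):

```python
import torch

num_experts, hidden_size, intermediate = 8, 4096, 14336  # example sizes

# TENSOR branch: one scale per (expert, shard) for w13, one per expert for w2.
w13_scale = torch.ones(num_experts, 2, dtype=torch.float32)
w2_scale = torch.ones(num_experts, dtype=torch.float32)

# CHANNEL branch: one scale per output row, with a trailing unit dim so it
# broadcasts against the [rows, cols] fp8 weight of each expert.
w13_scale = torch.ones(num_experts, 2 * intermediate, 1, dtype=torch.float32)
w2_scale = torch.ones(num_experts, hidden_size, 1, dtype=torch.float32)
```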
@@ -163,6 +188,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # Fp8 moe kernels require a single activation scale.
         # We take the max of all the scales in case they differ.
         if self.static_input_scales:
+            assert self.input_quant.strategy == QuantizationStrategy.TENSOR
             if (layer.w13_input_scale is None or layer.w2_input_scale is None):
                 raise ValueError(
                     "QuantConfig has static quantization, but found "
@@ -204,24 +230,25 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
             layer.w2_input_scale = torch.nn.Parameter(w2_input_scale,
                                                       requires_grad=False)
 
-        # Fp8 moe kernel needs single weight scale for w13 per expert.
-        # We take the max then dequant and requant each expert.
-        assert layer.w13_weight_scale is not None
-        shard_size = layer.intermediate_size_per_partition
-        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-        for expert_id in range(layer.local_num_experts):
-            start = 0
-            for shard_id in range(2):
-                dq_weight = per_tensor_dequantize(
-                    layer.w13_weight[expert_id][start:start + shard_size, :],
-                    layer.w13_weight_scale[expert_id][shard_id])
-                layer.w13_weight[expert_id][
-                    start:start + shard_size, :], _ = ops.scaled_fp8_quant(
-                        dq_weight, max_w13_scales[expert_id])
-                start += shard_size
-
-        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
-                                                    requires_grad=False)
+        # For Per-TENSOR case, Fp8 moe kernel needs single weight scale
+        # for w13 per expert. Use max then dequant and requant each expert.
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            assert layer.w13_weight_scale is not None
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+            for expert_id in range(layer.local_num_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        layer.w13_weight[expert_id][start:start +
                                                     shard_size, :],
+                        layer.w13_weight_scale[expert_id][shard_id])
+                    layer.w13_weight[expert_id][
+                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
+                            dq_weight, max_w13_scales[expert_id])
+                    start += shard_size
+            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
+                                                        requires_grad=False)
 
     def apply(
         self,
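
The requant loop above collapses the two w13 shard scales into one per-expert scale. A plain-PyTorch sketch of the same idea (standing in for `per_tensor_dequantize` and `ops.scaled_fp8_quant`; it assumes a PyTorch build with `torch.float8_e4m3fn`):

```python
import torch


def requant_w13_to_single_scale(w13_fp8: torch.Tensor,
                                shard_scales: torch.Tensor):
    """w13_fp8: [2 * shard, hidden] fp8 weight of one expert,
    shard_scales: [2] per-shard scales. Returns (fp8 weight, max scale)."""
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    shard = w13_fp8.shape[0] // 2
    max_scale = shard_scales.max()
    out = torch.empty_like(w13_fp8)
    for i in range(2):
        block = w13_fp8[i * shard:(i + 1) * shard].to(torch.float32)
        dq = block * shard_scales[i]          # dequantize with the old scale
        q = (dq / max_scale).clamp(-fp8_max, fp8_max)  # requantize with max
        out[i * shard:(i + 1) * shard] = q.to(torch.float8_e4m3fn)
    return out, max_scale
```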
@@ -265,6 +292,8 @@ def apply(
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input,
             use_fp8_w8a8=True,
+            per_channel_quant=self.weight_quant.strategy ==
+            QuantizationStrategy.CHANNEL,
             global_num_experts=global_num_experts,
             expert_map=expert_map,
             w1_scale=layer.w13_weight_scale,
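
Note that the CHANNEL path only works because the kernel quantizes activations on the fly, one scale per token; a rough sketch of that dynamic per-token step (illustrative, not the kernel code):

```python
import torch


def dynamic_per_token_fp8(x: torch.Tensor):
    # One scale per row (token): row amax divided by the fp8 e4m3 max (448).
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scales = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    x_q = (x / scales).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    return x_q, scales
```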