@@ -229,14 +229,29 @@ def get_default_8bit_qnn_ptq_config(
 ) -> QuantizationConfig:
     extra_args: Dict[str, Any] = {"eps": 2**-12}
 
-    act_quantization_spec = QuantizationSpec(
-        dtype=torch.uint8,
-        qscheme=(
-            torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
-        ),
-        ch_axis=0,
-        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
-    )
+    if act_symmetric:
+        # If zero_point is 128, HTP can apply optimizations.
+        # If quant_min and quant_max are left as None, the observer defaults to 128 as zero_point.
+        # If we provide the uint8 quant_min/quant_max, it uses 127 as zero_point, which is undesired.
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.uint8,
+            qscheme=torch.per_tensor_symmetric,
+            ch_axis=0,
+            observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
+        )
+    else:
+        # PyTorch removes redundant observers based on attributes such as
+        # dtype, quant_min, quant_max, ch_axis, etc.
+        # Providing explicit values like quant_min and quant_max helps observers compare
+        # and further reduces the number of observers.
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.uint8,
+            quant_min=torch.iinfo(torch.uint8).min,
+            quant_max=torch.iinfo(torch.uint8).max,
+            qscheme=torch.per_tensor_affine,
+            ch_axis=0,
+            observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
+        )
 
     weight_quantization_spec = QuantizationSpec(
         dtype=torch.int8,
@@ -409,6 +424,7 @@ def get_ptq_per_channel_quant_config(
         quant_min=torch.iinfo(act_dtype).min,
         quant_max=torch.iinfo(act_dtype).max,
         qscheme=torch.per_tensor_affine,
+        ch_axis=0,
         observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args),
     )
 
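As a side note on the zero_point comment in the first hunk: below is a minimal sketch, not part of the change, showing how leaving the quantization range at its default versus specifying uint8 quant_min/quant_max changes the symmetric zero_point an observer reports. It assumes a recent PyTorch where MinMaxObserver accepts plain torch.uint8 (older releases may need torch.quint8 instead).

import torch
from torch.ao.quantization.observer import MinMaxObserver

# Default range: a symmetric uint8 observer reports zero_point == 128,
# which is the HTP-friendly value the if-branch above relies on.
obs_default = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_symmetric)
obs_default(torch.randn(4, 4))
_, zp = obs_default.calculate_qparams()
print(zp)  # tensor([128])

# Customized range 0..255: the down-rounded midpoint (0 + 255) // 2 == 127
# is used instead, which is the undesired case the comment warns about.
obs_custom = MinMaxObserver(
    dtype=torch.uint8,
    qscheme=torch.per_tensor_symmetric,
    quant_min=torch.iinfo(torch.uint8).min,
    quant_max=torch.iinfo(torch.uint8).max,
)
obs_custom(torch.randn(4, 4))
_, zp = obs_custom.calculate_qparams()
print(zp)  # tensor([127])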