@@ -184,14 +184,29 @@ def get_default_8bit_qnn_ptq_config(
 ) -> QuantizationConfig:
     extra_args: Dict[str, Any] = {"eps": 2**-12}
 
-    act_quantization_spec = QuantizationSpec(
-        dtype=torch.uint8,
-        qscheme=(
-            torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
-        ),
-        ch_axis=0,
-        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
-    )
+    if act_symmetric:
+        # If zero_point is 128, HTP can apply optimizations.
+        # If quant_min and quant_max are left as None, the observer defaults to 128 as the zero_point.
+        # If we provide the uint8 quant_min/max, it will use 127 as the zero_point, which is undesired.
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.uint8,
+            qscheme=torch.per_tensor_symmetric,
+            ch_axis=0,
+            observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
+        )
+    else:
+        # PyTorch will remove redundant observers based on attributes such as:
+        # dtype, quant_min, quant_max, ch_axis, etc.
+        # Providing explicit values for quant_min and quant_max helps this comparison
+        # and further reduces the number of observers.
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.uint8,
+            quant_min=torch.iinfo(torch.uint8).min,
+            quant_max=torch.iinfo(torch.uint8).max,
+            qscheme=torch.per_tensor_affine,
+            ch_axis=0,
+            observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
+        )
 
     weight_quantization_spec = QuantizationSpec(
         dtype=torch.int8,
@@ -364,6 +379,7 @@ def get_ptq_per_channel_quant_config(
         quant_min=torch.iinfo(act_dtype).min,
         quant_max=torch.iinfo(act_dtype).max,
         qscheme=torch.per_tensor_affine,
+        ch_axis=0,
         observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args),
     )
 
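For context on the zero_point comments in the first hunk: with a symmetric uint8 scheme, PyTorch observers pick zero_point = 128 only when quant_min/quant_max are left unset; passing the full uint8 range explicitly counts as a customized range and yields the down-rounded midpoint 127. Below is a minimal sketch of that observer behavior (not part of this change, and assuming a PyTorch build whose observers accept torch.uint8):

```python
import torch
from torch.ao.quantization.observer import MovingAverageMinMaxObserver

x = torch.randn(4, 8)

# quant_min/quant_max left as None: symmetric uint8 observers hard-code
# zero_point = 128, which HTP can optimize for.
obs_default = MovingAverageMinMaxObserver(
    dtype=torch.uint8, qscheme=torch.per_tensor_symmetric
)
obs_default(x)
print(obs_default.calculate_qparams())  # zero_point is 128

# Explicit uint8 quant_min/quant_max: treated as a customized range, so the
# down-rounded midpoint (0 + 255) // 2 = 127 is used instead, which is
# undesired here.
obs_custom = MovingAverageMinMaxObserver(
    dtype=torch.uint8,
    qscheme=torch.per_tensor_symmetric,
    quant_min=0,
    quant_max=255,
)
obs_custom(x)
print(obs_custom.calculate_qparams())  # zero_point is 127
```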