
Commit cb2a0e7

Qualcomm AI Engine Direct - Reduce redundant observers (#6351)
1 parent: 4af687a

File tree: 2 files changed (+28, -10 lines)

backends/qualcomm/quantizer/utils.py

Lines changed: 24 additions & 8 deletions
@@ -229,14 +229,29 @@ def get_default_8bit_qnn_ptq_config(
 ) -> QuantizationConfig:
     extra_args: Dict[str, Any] = {"eps": 2**-12}
 
-    act_quantization_spec = QuantizationSpec(
-        dtype=torch.uint8,
-        qscheme=(
-            torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
-        ),
-        ch_axis=0,
-        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
-    )
+    if act_symmetric:
+        # If zero_point is 128, htp can do optimizations.
+        # If we keep quant_min and quant_max none, observer will default use 128 as zero_point.
+        # If we provide uint8 quant_min/max, it will use 127 as zero_point, which is undesired.
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.uint8,
+            qscheme=torch.per_tensor_symmetric,
+            ch_axis=0,
+            observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
+        )
+    else:
+        # PyTorch will remove redundant observers based on attributes such as:
+        # dtype, quant_min, quant_max, ch_axis, etc.
+        # Providing values like quant_min and quant_max can help observers compare
+        # and further reduce the number of observers.
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.uint8,
+            quant_min=torch.iinfo(torch.uint8).min,
+            quant_max=torch.iinfo(torch.uint8).max,
+            qscheme=torch.per_tensor_affine,
+            ch_axis=0,
+            observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
+        )
 
     weight_quantization_spec = QuantizationSpec(
         dtype=torch.int8,
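
The added comments above describe how the observer chooses its zero_point for symmetric uint8 activations. Below is a minimal sketch, not part of this commit, of that behavior using torch.ao.quantization's MinMaxObserver; the sample tensor is arbitrary and exact behavior may differ across PyTorch versions:

import torch
from torch.ao.quantization.observer import MinMaxObserver

# With uint8, a symmetric qscheme, and quant_min/quant_max left unset, the
# observer falls back to 128 as the zero_point (the case the comments say
# HTP can optimize).
obs = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_symmetric)
obs(torch.randn(4, 8))  # record min/max statistics from a sample tensor
print(obs.calculate_qparams())  # zero_point expected to be 128

# With an explicit uint8 range, the down-rounded midpoint (0 + 255) // 2 = 127
# is used instead, which the comments call out as undesired.
obs_custom = MinMaxObserver(
    dtype=torch.uint8,
    qscheme=torch.per_tensor_symmetric,
    quant_min=0,
    quant_max=255,
)
obs_custom(torch.randn(4, 8))
print(obs_custom.calculate_qparams())  # zero_point expected to be 127
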
@@ -409,6 +424,7 @@ def get_ptq_per_channel_quant_config(
         quant_min=torch.iinfo(act_dtype).min,
         quant_max=torch.iinfo(act_dtype).max,
         qscheme=torch.per_tensor_affine,
+        ch_axis=0,
         observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args),
     )
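
Filling in ch_axis here (and quant_min/quant_max in the affine branch above) gives the activation specs a complete set of attributes to compare on. A minimal sketch, not part of this commit, of that attribute comparison: QuantizationSpec is an equality-comparable dataclass, and the observer constructor is shared on purpose, since two separate with_args(...) wrappers would not compare equal.

import torch
from torch.ao.quantization.observer import MovingAverageMinMaxObserver
from torch.ao.quantization.quantizer import QuantizationSpec

# Shared constructor; distinct with_args(...) wrappers would break equality.
ctr = MovingAverageMinMaxObserver.with_args(eps=2**-12)

spec_a = QuantizationSpec(
    dtype=torch.uint8,
    quant_min=torch.iinfo(torch.uint8).min,  # 0
    quant_max=torch.iinfo(torch.uint8).max,  # 255
    qscheme=torch.per_tensor_affine,
    ch_axis=0,
    observer_or_fake_quant_ctr=ctr,
)
spec_b = QuantizationSpec(
    dtype=torch.uint8,
    quant_min=0,
    quant_max=255,
    qscheme=torch.per_tensor_affine,
    ch_axis=0,
    observer_or_fake_quant_ctr=ctr,
)

# Fully specified specs with matching attributes compare equal, which is the
# kind of attribute-based comparison the added comments refer to.
assert spec_a == spec_b
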

examples/qualcomm/utils.py

Lines changed: 4 additions & 2 deletions
@@ -348,7 +348,9 @@ def histogram(golden, predict):
     return (pa, mpa, miou, cls_iou)
 
 
-def get_imagenet_dataset(dataset_path, data_size, image_shape, crop_size=None):
+def get_imagenet_dataset(
+    dataset_path, data_size, image_shape, crop_size=None, shuffle=True
+):
     from torchvision import datasets, transforms
 
     def get_data_loader():
@@ -365,7 +367,7 @@ def get_data_loader():
         imagenet_data = datasets.ImageFolder(dataset_path, transform=preprocess)
         return torch.utils.data.DataLoader(
             imagenet_data,
-            shuffle=True,
+            shuffle=shuffle,
         )
 
     # prepare input data
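
A hedged usage sketch, not part of this commit, of how the new shuffle flag might be used for a deterministic evaluation order; the path and sizes below are placeholders, and the function's return value is not shown in this diff:

# Hypothetical call site; dataset_path, data_size, and shapes are placeholders.
dataset = get_imagenet_dataset(
    dataset_path="./imagenet-mini",
    data_size=100,
    image_shape=(256, 256),
    crop_size=224,
    shuffle=False,  # keep sample order deterministic across runs
)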
