Skip to content

Commit e6024c9

Browse files
committed
Qualcomm AI Engine Direct - Reduce redundant observers
1 parent ad0e5e8 commit e6024c9

File tree

2 files changed

+28
-10
lines changed

2 files changed

+28
-10
lines changed

backends/qualcomm/quantizer/utils.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -184,14 +184,29 @@ def get_default_8bit_qnn_ptq_config(
184184
) -> QuantizationConfig:
185185
extra_args: Dict[str, Any] = {"eps": 2**-12}
186186

187-
act_quantization_spec = QuantizationSpec(
188-
dtype=torch.uint8,
189-
qscheme=(
190-
torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
191-
),
192-
ch_axis=0,
193-
observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
194-
)
187+
if act_symmetric:
188+
# If zero_point is 128, HTP can apply optimizations.
189+
# If we leave quant_min and quant_max as None, the observer will default to using 128 as zero_point.
190+
# If we provide the uint8 quant_min/quant_max, it will use 127 as zero_point, which is undesirable.
191+
act_quantization_spec = QuantizationSpec(
192+
dtype=torch.uint8,
193+
qscheme=torch.per_tensor_symmetric,
194+
ch_axis=0,
195+
observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
196+
)
197+
else:
198+
# PyTorch will remove redundant observers based on attributes such as:
199+
# dtype, quant_min, quant_max, ch_axis, etc.
200+
# Providing values like quant_min and quant_max can help observers compare
201+
# and further reduce the number of observers.
202+
act_quantization_spec = QuantizationSpec(
203+
dtype=torch.uint8,
204+
quant_min=torch.iinfo(torch.uint8).min,
205+
quant_max=torch.iinfo(torch.uint8).max,
206+
qscheme=torch.per_tensor_affine,
207+
ch_axis=0,
208+
observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
209+
)
195210

196211
weight_quantization_spec = QuantizationSpec(
197212
dtype=torch.int8,
@@ -364,6 +379,7 @@ def get_ptq_per_channel_quant_config(
364379
quant_min=torch.iinfo(act_dtype).min,
365380
quant_max=torch.iinfo(act_dtype).max,
366381
qscheme=torch.per_tensor_affine,
382+
ch_axis=0,
367383
observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(**extra_args),
368384
)
369385

examples/qualcomm/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,9 @@ def histogram(golden, predict):
348348
return (pa, mpa, miou, cls_iou)
349349

350350

351-
def get_imagenet_dataset(dataset_path, data_size, image_shape, crop_size=None):
351+
def get_imagenet_dataset(
352+
dataset_path, data_size, image_shape, crop_size=None, shuffle=True
353+
):
352354
from torchvision import datasets, transforms
353355

354356
def get_data_loader():
@@ -365,7 +367,7 @@ def get_data_loader():
365367
imagenet_data = datasets.ImageFolder(dataset_path, transform=preprocess)
366368
return torch.utils.data.DataLoader(
367369
imagenet_data,
368-
shuffle=True,
370+
shuffle=shuffle,
369371
)
370372

371373
# prepare input data

0 commit comments

Comments
 (0)