Skip to content

Commit 2ae66e8

Browse files
committed
modify aciq observer and use aciq laplace for last 2 detr bbox embed weights
update readme
1 parent af750d9 commit 2ae66e8

File tree

5 files changed

+23
-55
lines changed

5 files changed

+23
-55
lines changed

examples/post_training_quantization/coco2017/DETR/README.md

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,11 @@ Since mask is not well supported by onnx, we removed mask-related codes and assi
1717

1818
|DETR-R50|mAPc|AP50|AP75| remarks|
1919
|-|-|-|-|-|
20-
|Float|0.421|0.623|0.443|baseline|
21-
|8w8f|0.332|0.588|0.320| Inputs of Add&LN not quantized.|
22-
|8w8f|0.395|0.607|0.409|Inputs of Add&LN not quantized. Float w&f for last bbox embed layer.|
23-
|8w8f|0.396|0.606|0.411|Inputs of Add&LN not quantized. Float w&f for last bbox&class embed layers.|
20+
|float|0.421|0.623|0.443|baseline|
21+
|8w8f|0.332|0.588|0.320| minmax observer|
22+
|8w8f|0.404|0.612|0.421| minmax observer, float w&f for last 2 bbox embed layers|
23+
|8w8f|0.384|0.598|0.402| minmax observer, apply aciq laplace observer for last bbox embed layer|
24+
|8w8f|0.398|0.609|0.420| minmax observer, apply aciq laplace observer for last 2 bbox embed layers|
2425

2526
TRT DETR w/ fixed input shape, enable int8&fp16 QPS: 118.334 on Nvidia 2080Ti. For detailed visualization, please refer to
2627
```shell

examples/post_training_quantization/coco2017/DETR/main.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,6 @@ def main():
151151
qmodel.calc_qparams()
152152
qmodel.set_quant(w_quant=True, a_quant=True)
153153

154-
qmodel.model.bbox_embed_layers_2.set_quant(w_quant=False, a_quant=False)
155-
156154
test_stats, coco_evaluator = evaluate(qmodel, criterion, postprocessors,
157155
data_loader_val, base_ds, device, args.output_dir)
158156

examples/post_training_quantization/coco2017/DETR/qconfig.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ W:
88
BIT: 8
99
OBSERVER:
1010
TYPE: MINMAX
11+
SPECIFIC: [{
12+
"bbox_embed_layers_1": ["OBSERVER.TYPE", "aciq", "OBSERVER.ACIQ.DISTRIBUTION", "laplace"],
13+
"bbox_embed_layers_2": ["OBSERVER.TYPE", "aciq", "OBSERVER.ACIQ.DISTRIBUTION", "laplace"]
14+
}]
1115
A:
1216
QSCHEME: per-tensor-symmetric
1317
QUANTIZER:

sparsebit/quantization/observers/aciq.py

Lines changed: 10 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import torch
22
import math
3-
from .utils import mse_loss
43
from sparsebit.quantization.observers import Observer as BaseObserver
54
from sparsebit.quantization.observers import register_observer
65
from sparsebit.quantization.quantizers.quant_tensor import STE
@@ -56,6 +55,8 @@ def __init__(self, config, qdesc):
5655
8: 11.16,
5756
}
5857
self.gaus_const = (0.5 * 0.35) * (1 + (math.pi * math.log(4)) ** 0.5)
58+
self.distribution = config.OBSERVER.ACIQ.DISTRIBUTION.lower()
59+
assert self.distribution in ["gaus", "laplace"]
5960

6061
def calc_laplace_minmax(self, data, is_half_range):
6162
if self.is_perchannel:
@@ -115,53 +116,13 @@ def calc_minmax(self):
115116
data = self.get_calibration_data(c_first=True)
116117
is_half_range = data.min() >= 0
117118

118-
laplace_min_val, laplace_max_val = self.calc_laplace_minmax(data, is_half_range)
119-
scale_laplace, zero_point_laplace = self.calc_qparams_with_minmax(
120-
laplace_min_val, laplace_max_val
121-
)
122-
mse_laplace = mse_loss(
123-
STE.apply(
124-
data, scale_laplace, zero_point_laplace, self.qdesc, self.backend
125-
),
126-
data,
127-
self.is_perchannel,
128-
)
129-
130-
gaus_min_val, gaus_max_val = self.calc_gaus_minmax(
131-
data, batch_size, is_half_range
132-
)
133-
scale_gaus, zero_point_gaus = self.calc_qparams_with_minmax(
134-
gaus_min_val, gaus_max_val
135-
)
136-
137-
mse_gaus = mse_loss(
138-
STE.apply(data, scale_gaus, zero_point_gaus, self.qdesc, self.backend),
139-
data,
140-
self.is_perchannel,
141-
)
142-
143-
naive_min_val, naive_max_val = self.calc_naive_minmax(data)
144-
scale_minmax, zero_point_minmax = self.calc_qparams_with_minmax(
145-
naive_min_val, naive_max_val
146-
)
147-
mse_minmax = mse_loss(
148-
STE.apply(data, scale_minmax, zero_point_minmax, self.qdesc, self.backend),
149-
data,
150-
self.is_perchannel,
151-
)
152-
153-
mse_gaus_laplace = torch.minimum(mse_gaus, mse_laplace)
154-
self.min_val = torch.where(
155-
mse_gaus < mse_laplace, gaus_min_val, laplace_min_val
156-
)
157-
self.min_val = torch.where(
158-
mse_minmax < mse_gaus_laplace, naive_min_val, self.min_val
159-
).to(self.device)
160-
self.max_val = torch.where(
161-
mse_gaus < mse_laplace, gaus_max_val, laplace_max_val
162-
)
163-
self.max_val = torch.where(
164-
mse_minmax < mse_gaus_laplace, naive_max_val, self.max_val
165-
).to(self.device)
119+
if self.distribution == "laplace":
120+
min_val, max_val = self.calc_laplace_minmax(data, is_half_range)
121+
else:
122+
min_val, max_val = self.calc_gaus_minmax(
123+
data, batch_size, is_half_range
124+
)
125+
self.min_val = min_val.to(self.device)
126+
self.max_val = max_val.to(self.device)
166127

167128
return self.min_val, self.max_val

sparsebit/quantization/quant_config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
_C.W.OBSERVER.TYPE = "MINMAX" # "MINMAX"/"MSE"/"PERCENTILE"/"KL_HISTOGRAM"
2121
_C.W.OBSERVER.PERCENTILE = CN()
2222
_C.W.OBSERVER.PERCENTILE.ALPHA = 0.001 # alpha for percentile observer
23+
_C.W.OBSERVER.ACIQ = CN()
24+
_C.W.OBSERVER.ACIQ.DISTRIBUTION = "GAUS" #"LAPLACE"/"GAUS"
2325
_C.W.SPECIFIC = []
2426

2527
_C.A = CN()
@@ -35,6 +37,8 @@
3537
_C.A.OBSERVER.PERCENTILE.ALPHA = 0.001 # alpha for percentile observer
3638
_C.A.OBSERVER.MOVING_AVERAGE = CN()
3739
_C.A.OBSERVER.MOVING_AVERAGE.EMA_RATIO = 0.9 # ema_ratio for moving_average observer
40+
_C.A.OBSERVER.ACIQ = CN()
41+
_C.A.OBSERVER.ACIQ.DISTRIBUTION = "GAUS" #"LAPLACE"/"GAUS"
3842
_C.A.OBSERVER.LAYOUT = "NCHW" # NCHW / NLC
3943
_C.A.SPECIFIC = []
4044

0 commit comments

Comments
 (0)