
Commit 4a70ee0

WIP

Squashed history: WIP; can test version; modify for dump onnx; ready version; update readme; add detr qat example; fix bugs for qat; reformat dirs; rectify sparsebit; move qat to another branch; modify aciq observer and use aciq laplace for last 2 detr bbox embed weights; modifications for hugging DETR; simplify MR; add detr as submodule; rm qdropout; rm redundant; clean-up; rebase modifications; not finished yet; wrong version; low_acc version; finished version
1 parent fcd2abb commit 4a70ee0

22 files changed: +4946 −29 lines

.gitmodules

Lines changed: 3 additions & 0 deletions
@@ -1,3 +1,6 @@
 [submodule "examples/quantization_aware_training/imagenet1k/deit/deit"]
 	path = examples/quantization_aware_training/imagenet1k/deit/deit
 	url = https://github.com/facebookresearch/deit.git
+[submodule "examples/post_training_quantization/coco2017/DETR/detr"]
+	path = examples/post_training_quantization/coco2017/DETR/detr
+	url = https://github.com/facebookresearch/detr.git

examples/post_training_quantization/coco2017/DETR/DETR_8w8f_visualization_mAP0399.svg

Lines changed: 4001 additions & 0 deletions
examples/post_training_quantization/coco2017/DETR/README.md

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
# DETR PTQ example

## Preparation

The `DETR` pretrained model is the checkpoint from https://github.com/facebookresearch/detr. The example downloads it automatically via `torch.hub.load`.
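In code, this is the same call that `main.py` issues:

```python
import torch

# Downloads the official DETR-R50 checkpoint on first use and caches it.
model = torch.hub.load('facebookresearch/detr:main', 'detr_resnet50', pretrained=True)
```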
The example calibrates on the COCO2017 train set and evaluates on the COCO2017 validation set; both can be downloaded from http://cocodataset.org. The COCO API (e.g. via `pip install pycocotools`) must also be installed.
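The data loading follows upstream DETR, so `--coco_path` is expected to point at the standard COCO layout (a reasonable assumption, sketched below):

```
/path/to/coco
├── annotations/   # instances_train2017.json, instances_val2017.json
├── train2017/
└── val2017/
```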
## Usage

```shell
python3 main.py qconfig.yaml --coco_path /path/to/coco
```

Since masks are not well supported by ONNX, we removed the mask-related code and fixed the batch size to 1. `dynamic_axes` for ONNX export is not supported yet either.
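For illustration, a fixed-shape export looks like the sketch below (the example itself calls `qmodel.export_onnx` in `main.py`; the tensor names here are assumptions):

```python
import torch

# Minimal sketch: batch size pinned to 1 and no dynamic_axes, so every
# shape is frozen at export time. `model` stands for the mask-free DETR.
dummy = torch.randn(1, 3, 800, 1200)
torch.onnx.export(
    model,
    dummy,
    "qDETR.onnx",
    input_names=["input"],             # assumed name
    output_names=["logits", "boxes"],  # assumed names
    # dynamic_axes intentionally omitted: shapes stay fixed
)
```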
## Metrics

| DETR-R50 | mAP | AP50 | AP75 | Remarks |
|-|-|-|-|-|
| float | 0.421 | 0.623 | 0.443 | baseline |
| 8w8f | 0.332 | 0.588 | 0.320 | minmax observer |
| 8w8f | 0.404 | 0.612 | 0.421 | minmax observer, float w&f for last 2 bbox-embed layers |
| 8w8f | 0.384 | 0.598 | 0.402 | minmax observer, ACIQ Laplace observer for last bbox-embed layer |
| 8w8f | 0.398 | 0.609 | 0.420 | minmax observer, ACIQ Laplace observer for last 2 bbox-embed layers |
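Unlike minmax, the ACIQ Laplace observer clips the tensor range before computing quantization parameters, which tames the outlier-heavy last bbox-embed weights. A minimal sketch of the idea (not sparsebit's implementation; the 8-bit clipping constant is the value reported in the ACIQ paper and should be treated as an assumption here):

```python
import torch

def aciq_laplace_range(x: torch.Tensor, alpha_const: float = 9.89):
    """Clip range for roughly Laplace-distributed values (8-bit constant)."""
    mu = x.mean()
    b = (x - mu).abs().mean()   # maximum-likelihood estimate of Laplace scale
    alpha = alpha_const * b     # optimal clip threshold grows with bit width
    return (mu - alpha).item(), (mu + alpha).item()
```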
With TensorRT, the quantized DETR with a fixed input shape (int8 & fp16 enabled) reaches 118.334 QPS on an NVIDIA 2080Ti. For a detailed visualization, please refer to

```shell
examples/post_training_quantization/coco2017/DETR/DETR_8w8f_visualization_mAP0399.svg
```
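A sketch of how such an engine could be built with the TensorRT Python API (an assumption, not part of this commit; a `trtexec` run with int8/fp16 flags works equally well):

```python
import tensorrt as trt

# Build an int8+fp16 engine from the exported Q/DQ ONNX graph.
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, logger)
with open("qDETR.onnx", "rb") as f:
    assert parser.parse(f.read()), parser.get_error(0)
config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.INT8)   # honor Q/DQ scales from the ONNX
config.set_flag(trt.BuilderFlag.FP16)   # fp16 for the non-quantized rest
engine_bytes = builder.build_serialized_network(network, config)
with open("qDETR.engine", "wb") as f:
    f.write(engine_bytes)
```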
examples/post_training_quantization/coco2017/DETR/detr

Submodule detr added at 8a144f8

examples/post_training_quantization/coco2017/DETR/evaluation.py

Lines changed: 95 additions & 0 deletions

@@ -0,0 +1,95 @@
import torch
import os

import util.misc as utils
from datasets.coco_eval import CocoEvaluator
from datasets.panoptic_eval import PanopticEvaluator


@torch.no_grad()
def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, output_dir):
    model.eval()
    criterion.eval()

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Test:'

    iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys())
    coco_evaluator = CocoEvaluator(base_ds, iou_types)
    # coco_evaluator.coco_eval[iou_types[0]].params.iouThrs = [0, 0.1, 0.5, 0.75]

    panoptic_evaluator = None
    if 'panoptic' in postprocessors.keys():
        panoptic_evaluator = PanopticEvaluator(
            data_loader.dataset.ann_file,
            data_loader.dataset.ann_folder,
            output_dir=os.path.join(output_dir, "panoptic_eval"),
        )

    for samples, targets in metric_logger.log_every(data_loader, 10, header):
        sample = samples.tensors.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        outputs = model(sample)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_scaled = {k: v * weight_dict[k]
                                    for k, v in loss_dict_reduced.items() if k in weight_dict}
        loss_dict_reduced_unscaled = {f'{k}_unscaled': v
                                      for k, v in loss_dict_reduced.items()}
        metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()),
                             **loss_dict_reduced_scaled,
                             **loss_dict_reduced_unscaled)
        metric_logger.update(class_error=loss_dict_reduced['class_error'])

        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        results = postprocessors['bbox'](outputs, orig_target_sizes)
        if 'segm' in postprocessors.keys():
            target_sizes = torch.stack([t["size"] for t in targets], dim=0)
            results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes)
        res = {target['image_id'].item(): output for target, output in zip(targets, results)}
        if coco_evaluator is not None:
            coco_evaluator.update(res)

        if panoptic_evaluator is not None:
            res_pano = postprocessors["panoptic"](outputs, target_sizes, orig_target_sizes)
            for i, target in enumerate(targets):
                image_id = target["image_id"].item()
                file_name = f"{image_id:012d}.png"
                res_pano[i]["image_id"] = image_id
                res_pano[i]["file_name"] = file_name

            panoptic_evaluator.update(res_pano)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    if coco_evaluator is not None:
        coco_evaluator.synchronize_between_processes()
    if panoptic_evaluator is not None:
        panoptic_evaluator.synchronize_between_processes()

    # accumulate predictions from all images
    if coco_evaluator is not None:
        coco_evaluator.accumulate()
        coco_evaluator.summarize()
    panoptic_res = None
    if panoptic_evaluator is not None:
        panoptic_res = panoptic_evaluator.summarize()
    stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    if coco_evaluator is not None:
        if 'bbox' in postprocessors.keys():
            stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist()
        if 'segm' in postprocessors.keys():
            stats['coco_eval_masks'] = coco_evaluator.coco_eval['segm'].stats.tolist()
    if panoptic_res is not None:
        stats['PQ_all'] = panoptic_res["All"]
        stats['PQ_th'] = panoptic_res["Things"]
        stats['PQ_st'] = panoptic_res["Stuff"]
    return stats, coco_evaluator
examples/post_training_quantization/coco2017/DETR/main.py

Lines changed: 177 additions & 0 deletions

@@ -0,0 +1,177 @@
import argparse
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import detr.util.misc as utils
import sys
sys.path.append("./detr")
from detr.datasets import get_coco_api_from_dataset
from val_transform_datasets import build_dataset
from model import build
import onnx
import onnx_graphsurgeon as gs

from sparsebit.quantization import QuantModel, parse_qconfig

from evaluation import evaluate

parser = argparse.ArgumentParser(description="DETR PTQ example on COCO2017")
parser.add_argument("qconfig", help="the path of the quant config")
parser.add_argument(
    "-a",
    "--arch",
    metavar="ARCH",
    default="deit_tiny_patch16_224",
    help="ViT model architecture (default: deit_tiny); note: unused in this example",
)
parser.add_argument(
    "-j",
    "--num_workers",
    default=2,
    type=int,
    metavar="N",
    help="number of data loading workers (default: 2)",
)
parser.add_argument(
    "-b",
    "--batch-size",
    default=1,
    type=int,
    metavar="N",
    help="mini-batch size (default: 1), this is the total "
    "batch size of all GPUs on the current node when "
    "using Data Parallel or Distributed Data Parallel",
)
parser.add_argument(
    "-p",
    "--print-freq",
    default=10,
    type=int,
    metavar="N",
    help="print frequency (default: 10)",
)

# * Backbone
parser.add_argument('--backbone', default='resnet50', type=str,
                    help="Name of the convolutional backbone to use")
parser.add_argument('--dilation', action='store_true',
                    help="If true, we replace stride with dilation in the last convolutional block (DC5)")
parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
                    help="Type of positional embedding to use on top of the image features")

# * Transformer
parser.add_argument('--enc_layers', default=6, type=int,
                    help="Number of encoding layers in the transformer")
parser.add_argument('--dec_layers', default=6, type=int,
                    help="Number of decoding layers in the transformer")
parser.add_argument('--dim_feedforward', default=2048, type=int,
                    help="Intermediate size of the feedforward layers in the transformer blocks")
parser.add_argument('--hidden_dim', default=256, type=int,
                    help="Size of the embeddings (dimension of the transformer)")
parser.add_argument('--dropout', default=0.1, type=float,
                    help="Dropout applied in the transformer")
parser.add_argument('--nheads', default=8, type=int,
                    help="Number of attention heads inside the transformer's attentions")
parser.add_argument('--num_queries', default=100, type=int,
                    help="Number of query slots")
parser.add_argument('--pre_norm', action='store_true')

# * Loss
parser.add_argument('--aux_loss', dest='aux_loss', action='store_true',
                    help="Enables auxiliary decoding losses (loss at each layer)")
# * Matcher
parser.add_argument('--set_cost_class', default=1, type=float,
                    help="Class coefficient in the matching cost")
parser.add_argument('--set_cost_bbox', default=5, type=float,
                    help="L1 box coefficient in the matching cost")
parser.add_argument('--set_cost_giou', default=2, type=float,
                    help="giou box coefficient in the matching cost")
# * Loss coefficients
parser.add_argument('--mask_loss_coef', default=1, type=float)
parser.add_argument('--dice_loss_coef', default=1, type=float)
parser.add_argument('--bbox_loss_coef', default=5, type=float)
parser.add_argument('--giou_loss_coef', default=2, type=float)
parser.add_argument('--eos_coef', default=0.1, type=float,
                    help="Relative classification weight of the no-object class")

# configs for the COCO dataset
parser.add_argument('--dataset_file', default='coco')
parser.add_argument('--coco_path', type=str)
parser.add_argument('--masks', action='store_true',
                    help="Train segmentation head if the flag is provided")
parser.add_argument('--output_dir', default='',
                    help='path where to save, empty for no saving')

parser.add_argument('--device', default='cuda',
                    help='device to use for training / testing')


def main():
    args = parser.parse_args()
    device = args.device

    # get the pretrained model from https://github.com/facebookresearch/detr
    model = torch.hub.load('facebookresearch/detr:main', 'detr_resnet50', pretrained=True)
    model, criterion, postprocessors = build(args, model)

    qconfig = parse_qconfig(args.qconfig)
    qmodel = QuantModel(model, config=qconfig).to(device)

    cudnn.benchmark = True

    dataset_val = build_dataset(image_set='val', args=args)
    sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    data_loader_val = torch.utils.data.DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
                                                  drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers)
    base_ds = get_coco_api_from_dataset(dataset_val)

    dataset_calib = build_dataset(image_set='train', args=args)
    sampler_calib = torch.utils.data.RandomSampler(dataset_calib)
    data_loader_calib = torch.utils.data.DataLoader(dataset_calib, args.batch_size, sampler=sampler_calib,
                                                    drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers)

    qmodel.eval()
    with torch.no_grad():
        qmodel.prepare_calibration()
        # forward the calibration set
        calibration_size = 16
        cur_size = 0
        for samples, _ in data_loader_calib:
            sample = samples.tensors.to(device)
            qmodel(sample)
            cur_size += args.batch_size
            if cur_size >= calibration_size:
                break
        qmodel.calc_qparams()
    qmodel.set_quant(w_quant=True, a_quant=True)

    test_stats, coco_evaluator = evaluate(qmodel, criterion, postprocessors,
                                          data_loader_val, base_ds, device, args.output_dir)

    qmodel.export_onnx(torch.randn(1, 3, 800, 1200), name="qDETR.onnx")

    # Optional post-processing of the exported graph: patch the constant
    # target shapes of Reshape nodes so the graph matches the fixed input
    # resolution before feeding it to TensorRT.
    # graph = gs.import_onnx(onnx.load("qDETR.onnx"))
    # Reshapes = [node for node in graph.nodes if node.op == "Reshape"]
    # for node in Reshapes:
    #     if isinstance(node.inputs[1], gs.Constant):
    #         if node.inputs[1].values[1] == 7600:
    #             node.inputs[1].values[1] = 8
    #         elif node.inputs[1].values[1] == 950:
    #             node.inputs[1].values[1] = 1
    #         elif node.inputs[1].values[1] == 100:
    #             node.inputs[1].values[1] = 1
    #         elif node.inputs[1].values[1] == 800:
    #             node.inputs[1].values[1] = 8
    # onnx.save(gs.export_onnx(graph), "qDETR.onnx")


if __name__ == "__main__":
    main()
