-
Notifications
You must be signed in to change notification settings - Fork 210
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Yolox export & blade optimize support #49
Changes from all commits
ab26049
3340067
c6669fe
518cc29
57c73af
01cd311
237d8d8
7d99c93
b7a727a
2cb69cd
26d94db
db4928f
c268c56
c4f232c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
_base_ = './yolox_s_8xb16_300e_coco.py' | ||
|
||
# s m l x | ||
img_scale = (640, 640) | ||
random_size = (14, 26) | ||
scale_ratio = (0.1, 2) | ||
|
||
# tiny nano without mixup | ||
# img_scale = (416, 416) | ||
# random_size = (10, 20) | ||
# scale_ratio = (0.5, 1.5) | ||
|
||
CLASSES = [ | ||
'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', | ||
'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', | ||
'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor' | ||
] | ||
|
||
# dataset settings | ||
data_root = 'data/voc/' | ||
img_norm_cfg = dict( | ||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) | ||
|
||
train_pipeline = [ | ||
dict(type='MMMosaic', img_scale=img_scale, pad_val=114.0), | ||
dict( | ||
type='MMRandomAffine', | ||
scaling_ratio_range=scale_ratio, | ||
border=(-img_scale[0] // 2, -img_scale[1] // 2)), | ||
dict( | ||
type='MMMixUp', # s m x l; tiny nano will detele | ||
img_scale=img_scale, | ||
ratio_range=(0.8, 1.6), | ||
pad_val=114.0), | ||
dict( | ||
type='MMPhotoMetricDistortion', | ||
brightness_delta=32, | ||
contrast_range=(0.5, 1.5), | ||
saturation_range=(0.5, 1.5), | ||
hue_delta=18), | ||
dict(type='MMRandomFlip', flip_ratio=0.5), | ||
dict(type='MMResize', keep_ratio=True), | ||
dict(type='MMPad', pad_to_square=True, pad_val=(114.0, 114.0, 114.0)), | ||
dict(type='MMNormalize', **img_norm_cfg), | ||
dict(type='DefaultFormatBundle'), | ||
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) | ||
] | ||
test_pipeline = [ | ||
dict(type='MMResize', img_scale=img_scale, keep_ratio=True), | ||
dict(type='MMPad', pad_to_square=True, pad_val=(114.0, 114.0, 114.0)), | ||
dict(type='MMNormalize', **img_norm_cfg), | ||
dict(type='DefaultFormatBundle'), | ||
dict(type='Collect', keys=['img']) | ||
] | ||
|
||
train_dataset = dict( | ||
type='DetImagesMixDataset', | ||
data_source=dict( | ||
type='DetSourceVOC', | ||
path=data_root + 'ImageSets/Main/train.txt', | ||
classes=CLASSES, | ||
cache_at_init=True), | ||
pipeline=train_pipeline, | ||
dynamic_scale=img_scale) | ||
|
||
val_dataset = dict( | ||
type='DetImagesMixDataset', | ||
imgs_per_gpu=2, | ||
data_source=dict( | ||
type='DetSourceVOC', | ||
path=data_root + 'ImageSets/Main/val.txt', | ||
classes=CLASSES, | ||
cache_at_init=True), | ||
pipeline=test_pipeline, | ||
dynamic_scale=None, | ||
label_padding=False) | ||
|
||
data = dict( | ||
imgs_per_gpu=16, workers_per_gpu=4, train=train_dataset, val=val_dataset) | ||
|
||
# # evaluation | ||
eval_pipelines = [ | ||
dict( | ||
mode='test', | ||
data=data['val'], | ||
evaluators=[dict(type='CocoDetectionEvaluator', classes=CLASSES)], | ||
) | ||
] | ||
|
||
export = dict(use_jit=True, export_blade=False) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -62,6 +62,7 @@ def __init__(self, | |
self.test_conf = test_conf | ||
self.nms_thre = nms_thre | ||
self.test_size = test_size | ||
self.traceable = False | ||
|
||
def forward_train(self, | ||
img: Tensor, | ||
|
@@ -165,3 +166,28 @@ def forward_compression(self, x): | |
outputs = self.head(fpn_outs) | ||
|
||
return outputs | ||
|
||
|
||
@MODELS.register_module | ||
class YOLOXExport(YOLOX): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove YOLOXExport There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. using yoloXExport to fit the jit trace input/output requirenments, otherwise we should change the original YOLOX and add a config for it to open the post process. |
||
|
||
def __init__(self, *args, **kwargs): | ||
super(YOLOXExport, self).__init__(*args, **kwargs) | ||
device = 'cuda' if torch.cuda.is_available() else 'cpu' | ||
self.traceable = True | ||
self.backbone = self.backbone.to(device) | ||
self.head = self.head.to(device) | ||
for param in self.backbone.parameters(): | ||
param.requires_grad = False | ||
for param in self.head.parameters(): | ||
param.requires_grad = False | ||
|
||
def forward(self, img): | ||
with torch.no_grad(): | ||
|
||
fpn_outs = self.backbone(img) | ||
outputs = self.head(fpn_outs) | ||
|
||
outputs = postprocess(outputs, self.num_classes, self.test_conf, | ||
self.nms_thre) | ||
return outputs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove YOLOXExport api, should be a tool not a model
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it may be different for different model, we use
yolox blade as name because the blacklist op for blade is specialized for yolox, whether
useful for other cnn or transformer need to be checked later
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
or support export mode in model
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
need fix