Skip to content

Commit

Permalink
add rtdetr final (#8094)
Browse files Browse the repository at this point in the history
* [exp] add r50vd in dino

add yoloe reader

alter reference points to unsigmoid

fix amp training

alter usage in paddle-inference

update new base

alter ext_ops

add hybrid encoder

* add pp rt-detr

---------

Co-authored-by: ghostxsl <451323469@qq.com>
  • Loading branch information
lyuwenyu and ghostxsl authored Apr 18, 2023
1 parent 92752b0 commit 5d1f888
Show file tree
Hide file tree
Showing 15 changed files with 1,189 additions and 40 deletions.
41 changes: 41 additions & 0 deletions configs/rtdetr/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# DETRs Beat YOLOs on Real-time Object Detection

## Introduction
We propose a **R**eal-**T**ime **DE**tection **TR**ansformer (RT-DETR), the first real-time end-to-end object detector to our best knowledge. Specifically, we design an efficient hybrid encoder to efficiently process multi-scale features by decoupling the intra-scale interaction and cross-scale fusion, and propose IoU-aware query selection to improve the initialization of object queries. In addition, our proposed detector supports flexibly adjustment of the inference speed by using different decoder layers without the need for retraining, which facilitates the practical application of real-time object detectors. Our RT-DETR-L achieves 53.0% AP on COCO val2017 and 114 FPS on T4 GPU, while RT-DETR-X achieves 54.8% AP and 74 FPS, outperforming all YOLO detectors of the same scale in both speed and accuracy. Furthermore, our RT-DETR-R50 achieves 53.1% AP and 108 FPS, outperforming DINO-Deformable-DETR-R50 by 2.2% AP in accuracy and by about 21 times in FPS. For more details, please refer to our [paper](https://arxiv.org/abs/2304.08069).

<div align="center">
<img src="https://user-images.githubusercontent.com/17582080/232390925-54e58fe6-1c17-4610-90b9-7e5525577d80.png" width=500 />
</div>


## Model Zoo

### Model Zoo on COCO

| Model | Epoch | backbone | input shape | $AP^{val}$ | $AP^{val}_{50}$| Params(M) | FLOPs(G) | T4 TensorRT FP16(FPS) | Pretrained Model | config |
|:--------------:|:-----:|:----------:| :-------:|:--------------------------:|:---------------------------:|:---------:|:--------:| :---------------------: |:------------------------------------------------------------------------------------:|:-------------------------------------------:|
| RT-DETR-R50 | 80 | ResNet-50 | 640 | 53.1 | 71.3 | 42 | 136 | 108 | [download](https://bj.bcebos.com/v1/paddledet/models/rtdetr_r50vd_6x_coco.pdparams) | [config](./rtdetr_r50vd_6x_coco.yml)
| RT-DETR-R101 | 80 | ResNet-101 | 640 | 54.3 | 72.7 | 76 | 259 | 74 | [download](https://bj.bcebos.com/v1/paddledet/models/rtdetr_r101vd_6x_coco.pdparams) | [config](./rtdetr_r101vd_6x_coco.yml)
| RT-DETR-L | 80 | HGNetv2 | 640 | 53.0 | 71.6 | 32 | 110 | 114 | [download](https://bj.bcebos.com/v1/paddledet/models/rtdetr_hgnetv2_l_6x_coco.pdparams) | [comming soon](rtdetr_hgnetv2_l_6x_coco.yml)
| RT-DETR-X | 80 | HGNetv2 | 640 | 54.8 | 73.1 | 67 | 234 | 74 | [download](https://bj.bcebos.com/v1/paddledet/models/rtdetr_hgnetv2_x_6x_coco.pdparams) | [comming soon](rtdetr_hgnetv2_x_6x_coco.yml)

**Notes:**
- RT-DETR uses 4GPU to train.
- RT-DETR is trained on COCO train2017 dataset and evaluated on val2017 results of `mAP(IoU=0.5:0.95)`.

GPU multi-card training
```bash
python -m paddle.distributed.launch --gpus 0,1,2,3 tools/train.py -c configs/rtdetr/rtdetr_r50vd_6x_coco.yml --fleet --eval
```

## Citations
```
@misc{lv2023detrs,
title={DETRs Beat YOLOs on Real-time Object Detection},
author={Wenyu Lv and Shangliang Xu and Yian Zhao and Guanzhong Wang and Jinman Wei and Cheng Cui and Yuning Du and Qingqing Dang and Yi Liu},
year={2023},
eprint={2304.08069},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
19 changes: 19 additions & 0 deletions configs/rtdetr/_base_/optimizer_6x.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
epoch: 72

LearningRate:
base_lr: 0.0001
schedulers:
- !PiecewiseDecay
gamma: 1.0
milestones: [100]
use_warmup: true
- !LinearWarmup
start_factor: 0.001
steps: 2000

OptimizerBuilder:
clip_grad_by_norm: 0.1
regularizer: false
optimizer:
type: AdamW
weight_decay: 0.0001
71 changes: 71 additions & 0 deletions configs/rtdetr/_base_/rtdetr_r50vd.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
architecture: DETR
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vd_ssld_v2_pretrained.pdparams
norm_type: sync_bn
use_ema: True
ema_decay: 0.9999
ema_decay_type: "exponential"
ema_filter_no_grad: True
hidden_dim: 256
use_focal_loss: True
eval_size: [640, 640]


DETR:
backbone: ResNet
neck: HybridEncoder
transformer: RTDETRTransformer
detr_head: DINOHead
post_process: DETRPostProcess

ResNet:
# index 0 stands for res2
depth: 50
variant: d
norm_type: bn
freeze_at: 0
return_idx: [1, 2, 3]
lr_mult_list: [0.1, 0.1, 0.1, 0.1]
num_stages: 4
freeze_stem_only: True

HybridEncoder:
hidden_dim: 256
use_encoder_idx: [2]
num_encoder_layers: 1
encoder_layer:
name: TransformerLayer
d_model: 256
nhead: 8
dim_feedforward: 1024
dropout: 0.
activation: 'gelu'
expansion: 1.0


RTDETRTransformer:
num_queries: 300
position_embed_type: sine
feat_strides: [8, 16, 32]
num_levels: 3
nhead: 8
num_decoder_layers: 6
dim_feedforward: 1024
dropout: 0.0
activation: relu
num_denoising: 100
label_noise_ratio: 0.5
box_noise_scale: 1.0
learnt_init_query: False

DINOHead:
loss:
name: DINOLoss
loss_coeff: {class: 1, bbox: 5, giou: 2}
aux_loss: True
use_vfl: True
matcher:
name: HungarianMatcher
matcher_coeff: {class: 2, bbox: 5, giou: 2}

DETRPostProcess:
num_top_queries: 300
43 changes: 43 additions & 0 deletions configs/rtdetr/_base_/rtdetr_reader.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
worker_num: 4
TrainReader:
sample_transforms:
- Decode: {}
- RandomDistort: {prob: 0.8}
- RandomExpand: {fill_value: [123.675, 116.28, 103.53]}
- RandomCrop: {prob: 0.8}
- RandomFlip: {}
batch_transforms:
- BatchRandomResize: {target_size: [480, 512, 544, 576, 608, 640, 640, 640, 672, 704, 736, 768, 800], random_size: True, random_interp: True, keep_ratio: False}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- NormalizeBox: {}
- BboxXYXY2XYWH: {}
- Permute: {}
batch_size: 4
shuffle: true
drop_last: true
collate_batch: false
use_shared_memory: false


EvalReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- Permute: {}
batch_size: 4
shuffle: false
drop_last: false


TestReader:
inputs_def:
image_shape: [3, 640, 640]
sample_transforms:
- Decode: {}
- Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
- NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none}
- Permute: {}
batch_size: 1
shuffle: false
drop_last: false
37 changes: 37 additions & 0 deletions configs/rtdetr/rtdetr_r101vd_6x_coco.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/optimizer_6x.yml',
'_base_/rtdetr_r50vd.yml',
'_base_/rtdetr_reader.yml',
]

weights: output/rtdetr_r101vd_6x_coco/model_final
find_unused_parameters: True
log_iter: 200

pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet101_vd_ssld_pretrained.pdparams

ResNet:
# index 0 stands for res2
depth: 101
variant: d
norm_type: bn
freeze_at: 0
return_idx: [1, 2, 3]
lr_mult_list: [0.01, 0.01, 0.01, 0.01]
num_stages: 4
freeze_stem_only: True

HybridEncoder:
hidden_dim: 384
use_encoder_idx: [2]
num_encoder_layers: 1
encoder_layer:
name: TransformerLayer
d_model: 384
nhead: 8
dim_feedforward: 2048
dropout: 0.
activation: 'gelu'
expansion: 1.0
11 changes: 11 additions & 0 deletions configs/rtdetr/rtdetr_r50vd_6x_coco.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/optimizer_6x.yml',
'_base_/rtdetr_r50vd.yml',
'_base_/rtdetr_reader.yml',
]

weights: output/rtdetr_r50vd_6x_coco/model_final
find_unused_parameters: True
log_iter: 200
30 changes: 16 additions & 14 deletions ppdet/data/transform/batch_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -950,7 +950,7 @@ def __call__(self, samples, context=None):
@register_op
class PadMaskBatch(BaseOperator):
"""
Pad a batch of samples so they can be divisible by a stride.
Pad a batch of samples so that they can be divisible by a stride.
The layout of each image should be 'CHW'.
Args:
pad_to_stride (int): If `pad_to_stride > 0`, pad zeros to ensure
Expand All @@ -959,7 +959,7 @@ class PadMaskBatch(BaseOperator):
`pad_mask` for transformer.
"""

def __init__(self, pad_to_stride=0, return_pad_mask=False):
def __init__(self, pad_to_stride=0, return_pad_mask=True):
super(PadMaskBatch, self).__init__()
self.pad_to_stride = pad_to_stride
self.return_pad_mask = return_pad_mask
Expand All @@ -984,7 +984,7 @@ def __call__(self, samples, context=None):
im_c, im_h, im_w = im.shape[:]
padding_im = np.zeros(
(im_c, max_shape[1], max_shape[2]), dtype=np.float32)
padding_im[:, :im_h, :im_w] = im
padding_im[:, :im_h, :im_w] = im.astype(np.float32)
data['image'] = padding_im
if 'semantic' in data and data['semantic'] is not None:
semantic = data['semantic']
Expand Down Expand Up @@ -1108,12 +1108,13 @@ def __init__(self, return_gt_mask=True, pad_img=False, minimum_gtnum=0):
self.pad_img = pad_img
self.minimum_gtnum = minimum_gtnum

def _impad(self, img: np.ndarray,
*,
shape = None,
padding = None,
pad_val = 0,
padding_mode = 'constant') -> np.ndarray:
def _impad(self,
img: np.ndarray,
*,
shape=None,
padding=None,
pad_val=0,
padding_mode='constant') -> np.ndarray:
"""Pad the given image to a certain shape or pad on all sides with
specified padding mode and padding value.
Expand Down Expand Up @@ -1169,7 +1170,7 @@ def _impad(self, img: np.ndarray,
padding = (padding, padding, padding, padding)
else:
raise ValueError('Padding must be a int or a 2, or 4 element tuple.'
f'But received {padding}')
f'But received {padding}')

# check padding mode
assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric']
Expand All @@ -1194,10 +1195,10 @@ def _impad(self, img: np.ndarray,
def checkmaxshape(self, samples):
maxh, maxw = 0, 0
for sample in samples:
h,w = sample['im_shape']
if h>maxh:
h, w = sample['im_shape']
if h > maxh:
maxh = h
if w>maxw:
if w > maxw:
maxw = w
return (maxh, maxw)

Expand Down Expand Up @@ -1246,7 +1247,8 @@ def __call__(self, samples, context=None):
sample['difficult'] = pad_diff
if 'gt_joints' in sample:
num_joints = sample['gt_joints'].shape[1]
pad_gt_joints = np.zeros((num_max_boxes, num_joints, 3), dtype=np.float32)
pad_gt_joints = np.zeros(
(num_max_boxes, num_joints, 3), dtype=np.float32)
if num_gt > 0:
pad_gt_joints[:num_gt] = sample['gt_joints']
sample['gt_joints'] = pad_gt_joints
Expand Down
22 changes: 16 additions & 6 deletions ppdet/data/transform/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,8 @@ def __init__(self,
brightness=[0.5, 1.5, 0.5],
random_apply=True,
count=4,
random_channel=False):
random_channel=False,
prob=1.0):
super(RandomDistort, self).__init__()
self.hue = hue
self.saturation = saturation
Expand All @@ -510,6 +511,7 @@ def __init__(self,
self.random_apply = random_apply
self.count = count
self.random_channel = random_channel
self.prob = prob

def apply_hue(self, img):
low, high, prob = self.hue
Expand Down Expand Up @@ -563,6 +565,8 @@ def apply_brightness(self, img):
return img

def apply(self, sample, context=None):
if random.random() > self.prob:
return sample
img = sample['image']
if self.random_apply:
functions = [
Expand Down Expand Up @@ -1488,7 +1492,8 @@ def __init__(self,
allow_no_crop=True,
cover_all_box=False,
is_mask_crop=False,
ioumode="iou"):
ioumode="iou",
prob=1.0):
super(RandomCrop, self).__init__()
self.aspect_ratio = aspect_ratio
self.thresholds = thresholds
Expand All @@ -1498,6 +1503,7 @@ def __init__(self,
self.cover_all_box = cover_all_box
self.is_mask_crop = is_mask_crop
self.ioumode = ioumode
self.prob = prob

def crop_segms(self, segms, valid_ids, crop, height, width):
def _crop_poly(segm, crop):
Expand Down Expand Up @@ -1588,6 +1594,9 @@ def set_fake_bboxes(self, sample):
return sample

def apply(self, sample, context=None):
if random.random() > self.prob:
return sample

if 'gt_bbox' not in sample:
# only used in semi-det as unsup data
sample = self.set_fake_bboxes(sample)
Expand Down Expand Up @@ -2829,22 +2838,23 @@ def __init__(self,

def get_size_with_aspect_ratio(self, image_shape, size, max_size=None):
h, w = image_shape
max_clip = False
if max_size is not None:
min_original_size = float(min((w, h)))
max_original_size = float(max((w, h)))
if max_original_size / min_original_size * size > max_size:
size = int(
round(max_size * min_original_size / max_original_size))
size = int(max_size * min_original_size / max_original_size)
max_clip = True

if (w <= h and w == size) or (h <= w and h == size):
return (w, h)

if w < h:
ow = size
oh = int(round(size * h / w))
oh = int(round(size * h / w)) if not max_clip else max_size
else:
oh = size
ow = int(round(size * w / h))
ow = int(round(size * w / h)) if not max_clip else max_size

return (ow, oh)

Expand Down
Loading

0 comments on commit 5d1f888

Please sign in to comment.