forked from open-mmlab/mmdetection
[Refactor]: Refactor DAB-DETR in MMDetection 3.x (open-mmlab#9252)
* resolve refactor conflict w/o pre-commit hooks
* fixed error and finished alignment
* support 91 cls and remove batch_first
* delete iter_update, keep return_intermediate
* delete posHW
* substitute 'num_query' with 'num_queries'
* change 'gen_sineembed_for_position' to 'convert_coordinate_to_encoding'
* resolve extra comments
* fix error
* fix error
* fix data path
* support 91 cls temporarily
* resolve extra comments
* fix num_keys, num_feats
* delete reg_branches in decoder_inputs_dict
* fix docstring
* fix docstring
* commit modification in pr of DINO
* fix data format from nbc to bnc in detr and deformable-detr
* fix 'gen_encoder_output_proposals' for two-stage deformable-detr
* fix 'gen_encoder_output_proposals' for two-stage deformable-detr
* set 'batch_first' to True in deformable attention
* fix error
* fix ut
* add assert for batch_first
* remove 91 cls
* modify pre_decoder of DeformableDETR
* delete useless comments
* bnc data flow w/o merging detr and def-detr
* assert batch_first flag in conditional attention, fix error
* add unit test for dab-detr
* fix doc
* disable yapf hook
* move conditional attention to trm/layers
* fix name and add doc
* fix doc
* add loss_and_predict for head
* fix doc and typehint
* fix doc and typehint
* modify batch_first assert for attention
* change Dab to DAB
* rename file and function
* make dab-detr head inherit conditional detr head
* fix doc
* fix doc

Co-authored-by: QingyunLi <962537281@qq.com>
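Several entries above concern switching the transformer data layout from "nbc" (num_queries, bs, channels) to "bnc" (bs, num_queries, channels), i.e. setting `batch_first=True` throughout. A minimal illustration of what that flag means for the tensors involved (shapes here are hypothetical, chosen to match this config):

```python
import torch

bs, num_queries, c = 2, 300, 256
queries_bnc = torch.rand(bs, num_queries, c)  # batch_first=True ("bnc") layout
queries_nbc = queries_bnc.transpose(0, 1)     # legacy ("nbc") layout
assert queries_nbc.shape == (num_queries, bs, c)
```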
Showing 17 changed files with 1,124 additions and 301 deletions.
New file (40 lines): DAB-DETR README
# DAB-DETR

> [DAB-DETR: Dynamic Anchor Boxes are Better Queries for DETR](https://arxiv.org/abs/2201.12329)

<!-- [ALGORITHM] -->

## Abstract

We present in this paper a novel query formulation using dynamic anchor boxes for DETR (DEtection TRansformer) and offer a deeper understanding of the role of queries in DETR. This new formulation directly uses box coordinates as queries in Transformer decoders and dynamically updates them layer-by-layer. Using box coordinates not only helps using explicit positional priors to improve the query-to-feature similarity and eliminate the slow training convergence issue in DETR, but also allows us to modulate the positional attention map using the box width and height information. Such a design makes it clear that queries in DETR can be implemented as performing soft ROI pooling layer-by-layer in a cascade manner. As a result, it leads to the best performance on MS-COCO benchmark among the DETR-like detection models under the same setting, e.g., AP 45.7% using ResNet50-DC5 as backbone trained in 50 epochs. We also conducted extensive experiments to confirm our analysis and verify the effectiveness of our methods.

<div align=center>
<img src="https://github.com/IDEA-Research/DAB-DETR/blob/main/figure/arch.png?raw=true"/>
</div>
<div align=center>
<img src="https://github.com/IDEA-Research/DAB-DETR/blob/main/figure/model.png?raw=true"/>
</div>
<div align=center>
<img src="https://github.com/IDEA-Research/DAB-DETR/blob/main/figure/results.png?raw=true"/>
</div>
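Each query in DAB-DETR carries an explicit anchor box (cx, cy, w, h) that is converted into a sinusoidal positional encoding and refined layer by layer. The snippet below is a minimal sketch of that conversion, in the spirit of what this refactor renames to `convert_coordinate_to_encoding`; `num_feats=128` and `temperature=20` follow the config in this commit, while the function name and exact layout are illustrative simplifications, not the library code.

```python
import math
import torch

def coord_to_sine_encoding(coords: torch.Tensor, num_feats: int = 128,
                           temperature: int = 20) -> torch.Tensor:
    """Sketch: map normalized box coords (..., 4) to sine/cosine encodings.

    Each of cx, cy, w, h gets `num_feats` channels, half sine and half
    cosine, analogous to the fixed Transformer positional encoding.
    """
    scale = 2 * math.pi
    dim_t = torch.arange(num_feats, dtype=torch.float32)
    dim_t = temperature ** (2 * (dim_t // 2) / num_feats)
    embeddings = []
    for i in range(coords.size(-1)):          # cx, cy, w, h in turn
        pos = coords[..., i] * scale
        pos = pos[..., None] / dim_t          # (..., num_feats)
        pos = torch.stack((pos[..., 0::2].sin(), pos[..., 1::2].cos()),
                          dim=-1).flatten(-2)
        embeddings.append(pos)
    return torch.cat(embeddings, dim=-1)      # (..., 4 * num_feats)

anchors = torch.rand(2, 300, 4)               # (bs, num_queries, cxcywh)
print(coord_to_sine_encoding(anchors).shape)  # torch.Size([2, 300, 512])
```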

## Results and Models

We provide the config files and models for DAB-DETR: [DAB-DETR: Dynamic Anchor Boxes are Better Queries for DETR](https://arxiv.org/abs/2201.12329).

| Backbone |  Model   | Lr schd | Mem (GB) | Inf time (fps) | box AP |                  Config                   |               Download               |
| :------: | :------: | :-----: | :------: | :------------: | :----: | :---------------------------------------: | :----------------------------------: |
|   R-50   | DAB-DETR |   50e   |   6.4    |                |  42.3  | [config](./dab-detr_r50_8xb2-50e_coco.py) | \[model\](# TODO) \| \[log\](# TODO) |
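Once a checkpoint is available, the model can be used through the standard MMDetection 3.x inference API. The paths below are placeholders (the checkpoint link above is still marked TODO), so treat this as a hedged usage sketch rather than a verified command:

```python
from mmdet.apis import init_detector, inference_detector

# Placeholder paths: the config is assumed to live under configs/dab_detr/;
# the checkpoint name is hypothetical since the download link is still TODO.
config = 'configs/dab_detr/dab-detr_r50_8xb2-50e_coco.py'
checkpoint = 'dab-detr_r50_8xb2-50e_coco.pth'

model = init_detector(config, checkpoint, device='cuda:0')
result = inference_detector(model, 'demo/demo.jpg')  # DetDataSample with pred_instances
```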

## Citation

```latex
@inproceedings{
  liu2022dabdetr,
  title={{DAB}-{DETR}: Dynamic Anchor Boxes are Better Queries for {DETR}},
  author={Shilong Liu and Feng Li and Hao Zhang and Xiao Yang and Xianbiao Qi and Hang Su and Jun Zhu and Lei Zhang},
  booktitle={International Conference on Learning Representations},
  year={2022},
  url={https://openreview.net/forum?id=oMI9PjOb9Jl}
}
```
New file (162 lines): `dab-detr_r50_8xb2-50e_coco.py`
_base_ = [
    '../_base_/datasets/coco_detection.py', '../_base_/default_runtime.py'
]
model = dict(
    type='DABDETR',
    num_queries=300,
    with_random_refpoints=False,
    num_patterns=0,
    data_preprocessor=dict(
        type='DetDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        pad_size_divisor=1),
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(3, ),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=False),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    neck=dict(
        type='ChannelMapper',
        in_channels=[2048],
        kernel_size=1,
        out_channels=256,
        act_cfg=None,
        norm_cfg=None,
        num_outs=1),
    encoder=dict(
        num_layers=6,
        layer_cfg=dict(
            self_attn_cfg=dict(
                embed_dims=256, num_heads=8, dropout=0., batch_first=True),
            ffn_cfg=dict(
                embed_dims=256,
                feedforward_channels=2048,
                num_fcs=2,
                ffn_drop=0.,
                act_cfg=dict(type='PReLU')))),
    decoder=dict(
        num_layers=6,
        query_dim=4,
        query_scale_type='cond_elewise',
        with_modulated_hw_attn=True,
        layer_cfg=dict(
            self_attn_cfg=dict(
                embed_dims=256,
                num_heads=8,
                attn_drop=0.,
                proj_drop=0.,
                cross_attn=False),
            cross_attn_cfg=dict(
                embed_dims=256,
                num_heads=8,
                attn_drop=0.,
                proj_drop=0.,
                cross_attn=True),
            ffn_cfg=dict(
                embed_dims=256,
                feedforward_channels=2048,
                num_fcs=2,
                ffn_drop=0.,
                act_cfg=dict(type='PReLU'))),
        return_intermediate=True),
    positional_encoding_cfg=dict(
        num_feats=128, temperature=20, normalize=True),
    bbox_head=dict(
        type='DABDETRHead',
        num_classes=80,
        embed_dims=256,
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=5.0),
        loss_iou=dict(type='GIoULoss', loss_weight=2.0)),
    # training and testing settings
    train_cfg=dict(
        assigner=dict(
            type='HungarianAssigner',
            match_costs=[
                dict(type='FocalLossCost', weight=2., eps=1e-8),
                dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
                dict(type='IoUCost', iou_mode='giou', weight=2.0)
            ])),
    test_cfg=dict(max_per_img=300))

# train_pipeline, NOTE the img_scale and the Pad's size_divisor are different
# from the default settings in mmdet.
train_pipeline = [
    dict(
        type='LoadImageFromFile',
        file_client_args={{_base_.file_client_args}}),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='RandomFlip', prob=0.5),
    dict(
        type='RandomChoice',
        transforms=[[
            dict(
                type='RandomChoiceResize',
                scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
                        (608, 1333), (640, 1333), (672, 1333), (704, 1333),
                        (736, 1333), (768, 1333), (800, 1333)],
                keep_ratio=True)
        ],
                    [
                        dict(
                            type='RandomChoiceResize',
                            scales=[(400, 1333), (500, 1333), (600, 1333)],
                            keep_ratio=True),
                        dict(
                            type='RandomCrop',
                            crop_type='absolute_range',
                            crop_size=(384, 600),
                            allow_negative_crop=True),
                        dict(
                            type='RandomChoiceResize',
                            scales=[(480, 1333), (512, 1333), (544, 1333),
                                    (576, 1333), (608, 1333), (640, 1333),
                                    (672, 1333), (704, 1333), (736, 1333),
                                    (768, 1333), (800, 1333)],
                            keep_ratio=True)
                    ]]),
    dict(type='PackDetInputs')
]
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))

# optimizer
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=0.0001, weight_decay=0.0001),
    clip_grad=dict(max_norm=0.1, norm_type=2),
    paramwise_cfg=dict(
        custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)}))

# learning policy
max_epochs = 50
train_cfg = dict(
    type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

param_scheduler = [
    dict(
        type='MultiStepLR',
        begin=0,
        end=max_epochs,
        by_epoch=True,
        milestones=[40],
        gamma=0.1)
]

# NOTE: `auto_scale_lr` is for automatically scaling LR,
# USER SHOULD NOT CHANGE ITS VALUES.
# base_batch_size = (8 GPUs) x (2 samples per GPU)
auto_scale_lr = dict(base_batch_size=16, enable=False)
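For reference, the sketch below shows how a config like this is typically materialized into a model in MMDetection 3.x. It is a hedged example: the config path assumes the file sits under `configs/dab_detr/` as the README's relative link suggests, and `register_all_modules()` is the usual 3.x step that populates the registries before building.

```python
# Minimal sketch: build the detector described by this config (assumed path).
from mmengine.config import Config
from mmdet.registry import MODELS
from mmdet.utils import register_all_modules

register_all_modules()  # register mmdet models, transforms, etc.

cfg = Config.fromfile('configs/dab_detr/dab-detr_r50_8xb2-50e_coco.py')
model = MODELS.build(cfg.model)  # DABDETR with a ResNet-50 backbone
print(type(model).__name__)      # expected: 'DABDETR'
```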
New file (106 lines): DAB-DETR head (`DABDETRHead`)
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch.nn as nn
from mmcv.cnn import Linear
from mmengine.model import bias_init_with_prob, constant_init
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.utils import InstanceList
from ..layers import MLP, inverse_sigmoid
from .conditional_detr_head import ConditionalDETRHead


@MODELS.register_module()
class DABDETRHead(ConditionalDETRHead):
    """Head of DAB-DETR: Dynamic Anchor Boxes are Better Queries for DETR.

    More details can be found in the `paper
    <https://arxiv.org/abs/2201.12329>`_ .
    """

    def _init_layers(self) -> None:
        """Initialize layers of the transformer head."""
        # cls branch
        self.fc_cls = Linear(self.embed_dims, self.cls_out_channels)
        # reg branch: a 3-layer MLP predicting 4 box parameters
        self.fc_reg = MLP(self.embed_dims, self.embed_dims, 4, 3)

    def init_weights(self) -> None:
        """Initialize weights."""
        if self.loss_cls.use_sigmoid:
            bias_init = bias_init_with_prob(0.01)
            nn.init.constant_(self.fc_cls.bias, bias_init)
        constant_init(self.fc_reg.layers[-1], 0., bias=0.)

    def forward(self, hidden_states: Tensor,
                references: Tensor) -> Tuple[Tensor, Tensor]:
        """Forward function.

        Args:
            hidden_states (Tensor): Features from the transformer decoder. If
                `return_intermediate_dec` is True, it has shape
                (num_decoder_layers, bs, num_queries, dim); otherwise it has
                shape (1, bs, num_queries, dim), which only contains the
                last-layer outputs.
            references (Tensor): References from the transformer decoder. If
                `return_intermediate_dec` is True, it has shape
                (num_decoder_layers, bs, num_queries, 2/4); otherwise it has
                shape (1, bs, num_queries, 2/4), which only contains the
                last-layer reference.

        Returns:
            tuple[Tensor]: Results of the head containing the following
            tensors.

            - layers_cls_scores (Tensor): Outputs from the classification
              head, with shape (num_decoder_layers, bs, num_queries,
              cls_out_channels). Note that cls_out_channels should include
              the background.
            - layers_bbox_preds (Tensor): Sigmoid outputs from the regression
              head in normalized coordinate format (cx, cy, w, h), with shape
              (num_decoder_layers, bs, num_queries, 4).
        """
        layers_cls_scores = self.fc_cls(hidden_states)
        references_before_sigmoid = inverse_sigmoid(references, eps=1e-3)
        tmp_reg_preds = self.fc_reg(hidden_states)
        tmp_reg_preds[..., :references_before_sigmoid.
                      size(-1)] += references_before_sigmoid
        layers_bbox_preds = tmp_reg_preds.sigmoid()
        return layers_cls_scores, layers_bbox_preds

    def predict(self,
                hidden_states: Tensor,
                references: Tensor,
                batch_data_samples: SampleList,
                rescale: bool = True) -> InstanceList:
        """Perform forward propagation of the detection head and predict
        detection results on the features of the upstream network. Overridden
        because `img_metas` are needed as inputs for `bbox_head`.

        Args:
            hidden_states (Tensor): Features from the transformer decoder,
                with shape (num_decoder_layers, bs, num_queries, dim).
            references (Tensor): References from the transformer decoder,
                with shape (num_decoder_layers, bs, num_queries, 2/4).
            batch_data_samples (List[:obj:`DetDataSample`]): The data
                samples. They usually include information such as
                `gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
            rescale (bool, optional): Whether to rescale the results.
                Defaults to True.

        Returns:
            list[obj:`InstanceData`]: Detection results of each image
            after the post process.
        """
        batch_img_metas = [
            data_samples.metainfo for data_samples in batch_data_samples
        ]

        last_layer_hidden_state = hidden_states[-1].unsqueeze(0)
        last_layer_reference = references[-1].unsqueeze(0)
        outs = self(last_layer_hidden_state, last_layer_reference)

        predictions = self.predict_by_feat(
            *outs, batch_img_metas=batch_img_metas, rescale=rescale)
        return predictions
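The heart of `forward` is the reference-correction step: the regression MLP predicts offsets in inverse-sigmoid (logit) space, adds them to the logit of the decoder's anchor references, and squashes the sum back to normalized (cx, cy, w, h). Below is a minimal, self-contained sketch of that arithmetic with random stand-in tensors (shapes follow the docstring above); it illustrates the math only and is not the head itself.

```python
import torch
from mmdet.models.layers import inverse_sigmoid

num_layers, bs, num_queries = 6, 2, 300
references = torch.rand(num_layers, bs, num_queries, 4)  # normalized anchors
reg_preds = torch.randn(num_layers, bs, num_queries, 4)  # stand-in for fc_reg output

# Offsets are applied in logit space so the final sigmoid keeps the
# predictions inside the [0, 1] normalized-coordinate range.
reg_preds[..., :references.size(-1)] += inverse_sigmoid(references, eps=1e-3)
bbox_preds = reg_preds.sigmoid()                         # (6, 2, 300, 4)
```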