[Refactor]: Refactor DAB-DETR in MMDetection 3.x (open-mmlab#9252)
* resolve refactor conflict w/o pre-commit hooks

* fixed error and finished alignment

* support 91 cls and remove batch_first

* delete iter_update, keep return intermediate

* delete posHW

* substitute 'num_query' to 'num_queries'

* change 'gen_sineembed_for_position' to 'convert_coordinate_to_encoding'

* resolve extra comments

* fix error

* fix error

* fix data path

* support 91 cls temporarily

* resolve extra comments

* fix num_keys, num_feats

* delete reg_branches in decoder_inputs_dict

* fix docstring

* fix docstring

* commit modification in pr of DINO

* fix data format from nbc to bnc in detr and deformable-detr

* fix 'gen_encoder_output_proposals' for two-stage deformable-detr

* fix 'gen_encoder_output_proposals' for two-stage deformable-detr

* set 'batch_first' to True in deformable attention

* fix error

* fix ut

* add assert for batch_first

* remove 91 cls

* modify pre_decoder of DeformableDETR

* delete useless comments

* bnc data flow w/o merging detr and def-detr

* assert batch first flag in conditional attention, fix error

* add unit test for dab-detr

* fix doc

* disable yapf hook

* move conditional attention to trm/layers

* fix name and add doc

* fix doc

* add loss_and_predict for head

* fix doc and typehint

* fix doc and typehint

* modify batch first assert for attention

* change Dab to DAB

* rename file and function

* make dab-detr head inherit conditional detr head

* fix doc

* fix doc

Co-authored-by: QingyunLi <962537281@qq.com>
2 people authored and yumion committed Jan 31, 2024
1 parent 80b7aca commit e5afc49
Showing 17 changed files with 1,124 additions and 301 deletions.
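Many of the commit messages above revolve around tensor layout: moving the DETR-family data flow from "nbc" (num_queries, batch, channels) to "bnc" (batch, num_queries, channels) and asserting `batch_first` in the attention modules. As a rough illustration of what that flag means — plain PyTorch, not the refactored MMDetection code:

```python
# Illustration only: the "nbc -> bnc" commits concern tensor layout, which in
# plain PyTorch corresponds to MultiheadAttention's `batch_first` flag.
import torch
from torch import nn

bs, num_queries, embed_dims = 2, 300, 256

# batch_first=True consumes and produces (batch, sequence, channels): "bnc".
attn_bnc = nn.MultiheadAttention(embed_dims, num_heads=8, batch_first=True)
queries = torch.rand(bs, num_queries, embed_dims)
out, _ = attn_bnc(queries, queries, queries)
assert out.shape == (bs, num_queries, embed_dims)

# batch_first=False (the default) expects (sequence, batch, channels): "nbc".
attn_nbc = nn.MultiheadAttention(embed_dims, num_heads=8, batch_first=False)
nbc_queries = queries.transpose(0, 1)
out_nbc, _ = attn_nbc(nbc_queries, nbc_queries, nbc_queries)
assert out_nbc.shape == (num_queries, bs, embed_dims)
```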
40 changes: 40 additions & 0 deletions configs/dab_detr/README.md
@@ -0,0 +1,40 @@
# DAB-DETR

> [DAB-DETR: Dynamic Anchor Boxes are Better Queries for DETR](https://arxiv.org/abs/2201.12329)
<!-- [ALGORITHM] -->

## Abstract

We present in this paper a novel query formulation using dynamic anchor boxes for DETR (DEtection TRansformer) and offer a deeper understanding of the role of queries in DETR. This new formulation directly uses box coordinates as queries in Transformer decoders and dynamically updates them layer-by-layer. Using box coordinates not only helps using explicit positional priors to improve the query-to-feature similarity and eliminate the slow training convergence issue in DETR, but also allows us to modulate the positional attention map using the box width and height information. Such a design makes it clear that queries in DETR can be implemented as performing soft ROI pooling layer-by-layer in a cascade manner. As a result, it leads to the best performance on MS-COCO benchmark among the DETR-like detection models under the same setting, e.g., AP 45.7% using ResNet50-DC5 as backbone trained in 50 epochs. We also conducted extensive experiments to confirm our analysis and verify the effectiveness of our methods.

<div align=center>
<img src="https://github.com/IDEA-Research/DAB-DETR/blob/main/figure/arch.png?raw=true"/>
</div>
<div align=center>
<img src="https://github.com/IDEA-Research/DAB-DETR/blob/main/figure/model.png?raw=true"/>
</div>
<div align=center>
<img src="https://github.com/IDEA-Research/DAB-DETR/blob/main/figure/results.png?raw=true"/>
</div>
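The core idea above — using 4-D anchor boxes directly as decoder queries — requires turning box coordinates into positional encodings before attention (cf. the `convert_coordinate_to_encoding` rename in the commit log). Below is a simplified, hypothetical sketch of such an encoding; the function name, the even split across coordinates, and the defaults are illustrative, not the MMDetection implementation:

```python
import math

import torch


def coordinate_sine_embedding(coords: torch.Tensor,
                              num_feats: int = 128,
                              temperature: float = 20.0) -> torch.Tensor:
    """Simplified sketch: encode normalized (cx, cy, w, h) anchors into
    sine/cosine positional queries. coords: (bs, num_queries, 4) in [0, 1];
    returns (bs, num_queries, 4 * num_feats)."""
    scale = 2 * math.pi
    dim_t = torch.arange(num_feats)
    dim_t = temperature**(2 * (dim_t // 2) / num_feats)
    embeddings = []
    for i in range(coords.size(-1)):
        pos = coords[..., i] * scale  # (bs, num_queries)
        pos = pos[..., None] / dim_t  # (bs, num_queries, num_feats)
        pos = torch.stack((pos[..., 0::2].sin(), pos[..., 1::2].cos()),
                          dim=-1).flatten(-2)
        embeddings.append(pos)
    return torch.cat(embeddings, dim=-1)


print(coordinate_sine_embedding(torch.rand(2, 300, 4)).shape)
# torch.Size([2, 300, 512])
```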

## Results and Models

We provide the config files and models for DAB-DETR: [DAB-DETR: Dynamic Anchor Boxes are Better Queries for DETR](https://arxiv.org/abs/2201.12329).

| Backbone | Model | Lr schd | Mem (GB) | Inf time (fps) | box AP | Config | Download |
| :------: | :------: | :-----: | :------: | :------------: | :----: | :---------------------------------------: | :----------------------------------: |
| R-50 | DAB-DETR | 50e | 6.4 | | 42.3 | [config](./dab-detr_r50_8xb2-50e_coco.py) | \[model\](# TODO) \| \[log\](# TODO) |

## Citation

```latex
@inproceedings{
liu2022dabdetr,
title={{DAB}-{DETR}: Dynamic Anchor Boxes are Better Queries for {DETR}},
author={Shilong Liu and Feng Li and Hao Zhang and Xiao Yang and Xianbiao Qi and Hang Su and Jun Zhu and Lei Zhang},
booktitle={International Conference on Learning Representations},
year={2022},
url={https://openreview.net/forum?id=oMI9PjOb9Jl}
}
```
162 changes: 162 additions & 0 deletions configs/dab_detr/dab-detr_r50_8xb2-50e_coco.py
@@ -0,0 +1,162 @@
_base_ = [
'../_base_/datasets/coco_detection.py', '../_base_/default_runtime.py'
]
model = dict(
type='DABDETR',
num_queries=300,
with_random_refpoints=False,
num_patterns=0,
data_preprocessor=dict(
type='DetDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_size_divisor=1),
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(3, ),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='ChannelMapper',
in_channels=[2048],
kernel_size=1,
out_channels=256,
act_cfg=None,
norm_cfg=None,
num_outs=1),
encoder=dict(
num_layers=6,
layer_cfg=dict(
self_attn_cfg=dict(
embed_dims=256, num_heads=8, dropout=0., batch_first=True),
ffn_cfg=dict(
embed_dims=256,
feedforward_channels=2048,
num_fcs=2,
ffn_drop=0.,
act_cfg=dict(type='PReLU')))),
decoder=dict(
num_layers=6,
query_dim=4,
query_scale_type='cond_elewise',
with_modulated_hw_attn=True,
layer_cfg=dict(
self_attn_cfg=dict(
embed_dims=256,
num_heads=8,
attn_drop=0.,
proj_drop=0.,
cross_attn=False),
cross_attn_cfg=dict(
embed_dims=256,
num_heads=8,
attn_drop=0.,
proj_drop=0.,
cross_attn=True),
ffn_cfg=dict(
embed_dims=256,
feedforward_channels=2048,
num_fcs=2,
ffn_drop=0.,
act_cfg=dict(type='PReLU'))),
return_intermediate=True),
positional_encoding_cfg=dict(
num_feats=128, temperature=20, normalize=True),
bbox_head=dict(
type='DABDETRHead',
num_classes=80,
embed_dims=256,
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='L1Loss', loss_weight=5.0),
loss_iou=dict(type='GIoULoss', loss_weight=2.0)),
# training and testing settings
train_cfg=dict(
assigner=dict(
type='HungarianAssigner',
match_costs=[
dict(type='FocalLossCost', weight=2., eps=1e-8),
dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
dict(type='IoUCost', iou_mode='giou', weight=2.0)
])),
test_cfg=dict(max_per_img=300))

# train_pipeline, NOTE the img_scale and the Pad's size_divisor are different
# from the default setting in mmdet.
train_pipeline = [
dict(
type='LoadImageFromFile',
file_client_args={{_base_.file_client_args}}),
dict(type='LoadAnnotations', with_bbox=True),
dict(type='RandomFlip', prob=0.5),
dict(
type='RandomChoice',
transforms=[[
dict(
type='RandomChoiceResize',
scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
(608, 1333), (640, 1333), (672, 1333), (704, 1333),
(736, 1333), (768, 1333), (800, 1333)],
keep_ratio=True)
],
[
dict(
type='RandomChoiceResize',
scales=[(400, 1333), (500, 1333), (600, 1333)],
keep_ratio=True),
dict(
type='RandomCrop',
crop_type='absolute_range',
crop_size=(384, 600),
allow_negative_crop=True),
dict(
type='RandomChoiceResize',
scales=[(480, 1333), (512, 1333), (544, 1333),
(576, 1333), (608, 1333), (640, 1333),
(672, 1333), (704, 1333), (736, 1333),
(768, 1333), (800, 1333)],
keep_ratio=True)
]]),
dict(type='PackDetInputs')
]
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))

# optimizer
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=0.0001, weight_decay=0.0001),
clip_grad=dict(max_norm=0.1, norm_type=2),
paramwise_cfg=dict(
custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)}))

# learning policy
max_epochs = 50
train_cfg = dict(
type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

param_scheduler = [
dict(
type='MultiStepLR',
begin=0,
end=max_epochs,
by_epoch=True,
milestones=[40],
gamma=0.1)
]

# NOTE: `auto_scale_lr` is for automatically scaling LR,
# USER SHOULD NOT CHANGE ITS VALUES.
# base_batch_size = (8 GPUs) x (2 samples per GPU)
auto_scale_lr = dict(base_batch_size=16, enable=False)
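A minimal sketch of consuming the config above with mmengine, assuming an MMDetection 3.x checkout where the file lives at the path shown in the diff header:

```python
# Minimal sketch (assumes mmdet 3.x + mmengine installed): load the config
# added in this commit and inspect fields touched by the refactor.
from mmengine.config import Config

cfg = Config.fromfile('configs/dab_detr/dab-detr_r50_8xb2-50e_coco.py')
print(cfg.model.type)               # 'DABDETR'
print(cfg.model.num_queries)        # 300 -- renamed from 'num_query'
print(cfg.model.decoder.query_dim)  # 4 -- queries are (cx, cy, w, h) anchors
```

Training would then typically be launched with `python tools/train.py configs/dab_detr/dab-detr_r50_8xb2-50e_coco.py` from the repository root.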
4 changes: 3 additions & 1 deletion mmdet/models/dense_heads/__init__.py
@@ -11,6 +11,7 @@
from .condinst_head import CondInstBboxHead, CondInstMaskHead
from .conditional_detr_head import ConditionalDETRHead
from .corner_head import CornerHead
from .dab_detr_head import DABDETRHead
from .ddod_head import DDODHead
from .deformable_detr_head import DeformableDETRHead
from .detr_head import DETRHead
@@ -63,5 +64,6 @@
'TOODHead', 'MaskFormerHead', 'Mask2FormerHead', 'DDODHead',
'CenterNetUpdateHead', 'RTMDetHead', 'RTMDetSepBNHead', 'CondInstBboxHead',
'CondInstMaskHead', 'RTMDetInsHead', 'RTMDetInsSepBNHead',
-    'BoxInstBboxHead', 'BoxInstMaskHead', 'ConditionalDETRHead', 'DINOHead'
+    'BoxInstBboxHead', 'BoxInstMaskHead', 'ConditionalDETRHead', 'DINOHead',
+    'DABDETRHead'
]
20 changes: 10 additions & 10 deletions mmdet/models/dense_heads/conditional_detr_head.py
@@ -36,12 +36,12 @@ def forward(self, hidden_states: Tensor,
Args:
hidden_states (Tensor): Features from transformer decoder. If
-                `return_intermediate_dec` in detr.py is True output has shape
+                `return_intermediate_dec` is True output has shape
(num_decoder_layers, bs, num_queries, dim), else has shape (1,
bs, num_queries, dim) which only contains the last layer
outputs.
-            references (Tensor): References from transformer decoder,has
-                shape (bs, num_query, 2).
+            references (Tensor): References from transformer decoder, has
+                shape (bs, num_queries, 2).
Returns:
tuple[Tensor]: results of head containing the following tensor.
@@ -72,9 +72,9 @@ def loss(self, hidden_states: Tensor, references: Tensor,
head on the features of the upstream network.
Args:
-            hidden_states (Tensor): Feature from the transformer decoder, has
+            hidden_states (Tensor): Features from the transformer decoder, has
                shape (num_decoder_layers, bs, num_queries, dim).
-            references (Tensor): references from the transformer decoder, has
+            references (Tensor): References from the transformer decoder, has
                shape (num_decoder_layers, bs, num_queries, 2).
batch_data_samples (List[:obj:`DetDataSample`]): The Data
Samples. It usually includes information such as
@@ -102,16 +102,16 @@ def loss_and_predict(
img_metas are needed as inputs for bbox_head.
Args:
-            hidden_states (Tensor): Feature from the transformer decoder, has
+            hidden_states (Tensor): Features from the transformer decoder, has
                shape (num_decoder_layers, bs, num_queries, dim).
-            references (Tensor): references from the transformer decoder, has
+            references (Tensor): References from the transformer decoder, has
                shape (num_decoder_layers, bs, num_queries, 2).
batch_data_samples (list[:obj:`DetDataSample`]): Each item contains
the meta information of each image and corresponding
annotations.
Returns:
-            tuple: the return value is a tuple contains:
+            tuple: The return value is a tuple contains:
- losses: (dict[str, Tensor]): A dictionary of loss components.
- predictions (list[:obj:`InstanceData`]): Detection
@@ -141,9 +141,9 @@ def predict(self,
because img_metas are needed as inputs for bbox_head.
Args:
-            hidden_states (Tensor): Feature from the transformer decoder, has
+            hidden_states (Tensor): Features from the transformer decoder, has
                shape (num_decoder_layers, bs, num_queries, dim).
-            references (Tensor): references from the transformer decoder, has
+            references (Tensor): References from the transformer decoder, has
shape (num_decoder_layers, bs, num_queries, 2).
batch_data_samples (List[:obj:`DetDataSample`]): The Data
Samples. It usually includes information such as
106 changes: 106 additions & 0 deletions mmdet/models/dense_heads/dab_detr_head.py
@@ -0,0 +1,106 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch.nn as nn
from mmcv.cnn import Linear
from mmengine.model import bias_init_with_prob, constant_init
from torch import Tensor

from mmdet.registry import MODELS
from mmdet.structures import SampleList
from mmdet.utils import InstanceList
from ..layers import MLP, inverse_sigmoid
from .conditional_detr_head import ConditionalDETRHead


@MODELS.register_module()
class DABDETRHead(ConditionalDETRHead):
"""Head of DAB-DETR. DAB-DETR: Dynamic Anchor Boxes are Better Queries for
DETR.
More details can be found in the `paper
<https://arxiv.org/abs/2201.12329>`_ .
"""

def _init_layers(self) -> None:
"""Initialize layers of the transformer head."""
# cls branch
self.fc_cls = Linear(self.embed_dims, self.cls_out_channels)
# reg branch
self.fc_reg = MLP(self.embed_dims, self.embed_dims, 4, 3)

def init_weights(self) -> None:
"""initialize weights."""
if self.loss_cls.use_sigmoid:
bias_init = bias_init_with_prob(0.01)
nn.init.constant_(self.fc_cls.bias, bias_init)
constant_init(self.fc_reg.layers[-1], 0., bias=0.)

def forward(self, hidden_states: Tensor,
references: Tensor) -> Tuple[Tensor, Tensor]:
""""Forward function.
Args:
hidden_states (Tensor): Features from transformer decoder. If
`return_intermediate_dec` is True output has shape
(num_decoder_layers, bs, num_queries, dim), else has shape (1,
bs, num_queries, dim) which only contains the last layer
outputs.
references (Tensor): References from transformer decoder. If
`return_intermediate_dec` is True output has shape
(num_decoder_layers, bs, num_queries, 2/4), else has shape (1,
bs, num_queries, 2/4)
which only contains the last layer reference.
Returns:
tuple[Tensor]: results of head containing the following tensor.
- layers_cls_scores (Tensor): Outputs from the classification head,
shape (num_decoder_layers, bs, num_queries, cls_out_channels).
Note cls_out_channels should include background.
- layers_bbox_preds (Tensor): Sigmoid outputs from the regression
head with normalized coordinate format (cx, cy, w, h), has shape
(num_decoder_layers, bs, num_queries, 4).
"""
layers_cls_scores = self.fc_cls(hidden_states)
references_before_sigmoid = inverse_sigmoid(references, eps=1e-3)
tmp_reg_preds = self.fc_reg(hidden_states)
        ref_dim = references_before_sigmoid.size(-1)
        tmp_reg_preds[..., :ref_dim] += references_before_sigmoid
layers_bbox_preds = tmp_reg_preds.sigmoid()
return layers_cls_scores, layers_bbox_preds

def predict(self,
hidden_states: Tensor,
references: Tensor,
batch_data_samples: SampleList,
rescale: bool = True) -> InstanceList:
"""Perform forward propagation of the detection head and predict
        detection results on the features of the upstream network. Overridden
because img_metas are needed as inputs for bbox_head.
Args:
hidden_states (Tensor): Feature from the transformer decoder, has
shape (num_decoder_layers, bs, num_queries, dim).
references (Tensor): references from the transformer decoder, has
shape (num_decoder_layers, bs, num_queries, 2/4).
batch_data_samples (List[:obj:`DetDataSample`]): The Data
Samples. It usually includes information such as
`gt_instance`, `gt_panoptic_seg` and `gt_sem_seg`.
rescale (bool, optional): Whether to rescale the results.
Defaults to True.
Returns:
            list[:obj:`InstanceData`]: Detection results of each image
after the post process.
"""
batch_img_metas = [
data_samples.metainfo for data_samples in batch_data_samples
]

last_layer_hidden_state = hidden_states[-1].unsqueeze(0)
last_layer_reference = references[-1].unsqueeze(0)
outs = self(last_layer_hidden_state, last_layer_reference)

predictions = self.predict_by_feat(
*outs, batch_img_metas=batch_img_metas, rescale=rescale)
return predictions
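The heart of `forward` above is the anchor-update trick: raw regression outputs are added to the reference coordinates in inverse-sigmoid (logit) space, then squashed back to normalized boxes. Here is a standalone sketch of just that step; `refine_boxes` is a hypothetical helper for illustration, not part of the registered head:

```python
# Standalone sketch of the reference-offset update in DABDETRHead.forward.
import torch


def refine_boxes(reg_preds: torch.Tensor, references: torch.Tensor,
                 eps: float = 1e-3) -> torch.Tensor:
    """reg_preds: (..., 4) raw offsets; references: (..., 2) or (..., 4)
    normalized anchors in (0, 1). Returns normalized (cx, cy, w, h)."""
    ref = references.clamp(min=eps, max=1 - eps)
    ref_logits = torch.log(ref / (1 - ref))         # inverse_sigmoid
    preds = reg_preds.clone()
    preds[..., :ref_logits.size(-1)] += ref_logits  # update in logit space
    return preds.sigmoid()                          # back to (0, 1)


boxes = refine_boxes(torch.randn(6, 2, 300, 4), torch.rand(6, 2, 300, 4))
assert boxes.shape == (6, 2, 300, 4)
assert bool((boxes > 0).all()) and bool((boxes < 1).all())
```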