Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add group detr for dino #7865

Merged
merged 1 commit into from
Mar 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions configs/group_detr/_base_/dino_2000_reader.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
worker_num: 2
TrainReader:
sample_transforms:
- Decode: {}
- RandomFlip: {prob: 0.5}
- RandomSelect: { transforms1: [ RandomShortSideResize: { short_side_sizes: [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800, 832, 864, 896, 928, 960, 992, 1024, 1056, 1088, 1120, 1152, 1184], max_size: 2000 } ],
transforms2: [
RandomShortSideResize: { short_side_sizes: [400, 500, 600, 700, 800, 900] },
RandomSizeCrop: { min_size: 384, max_size: 900 },
RandomShortSideResize: { short_side_sizes: [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800, 832, 864, 896, 928, 960, 992, 1024, 1056, 1088, 1120, 1152, 1184], max_size: 2000 } ]
}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- NormalizeBox: {}
- BboxXYXY2XYWH: {}
- Permute: {}
batch_transforms:
- PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true}
batch_size: 2
shuffle: true
drop_last: true
collate_batch: false
use_shared_memory: false


EvalReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: [1184, 2000], keep_ratio: True}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true}
batch_size: 1
shuffle: false
drop_last: false


TestReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: [1184, 2000], keep_ratio: True}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true}
batch_size: 1
shuffle: false
drop_last: false
48 changes: 48 additions & 0 deletions configs/group_detr/_base_/dino_reader.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
worker_num: 2
TrainReader:
sample_transforms:
- Decode: {}
- RandomFlip: {prob: 0.5}
- RandomSelect: { transforms1: [ RandomShortSideResize: { short_side_sizes: [ 480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800 ], max_size: 1333 } ],
transforms2: [
RandomShortSideResize: { short_side_sizes: [ 400, 500, 600 ] },
RandomSizeCrop: { min_size: 384, max_size: 600 },
RandomShortSideResize: { short_side_sizes: [ 480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800 ], max_size: 1333 } ]
}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- NormalizeBox: {}
- BboxXYXY2XYWH: {}
- Permute: {}
batch_transforms:
- PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true}
batch_size: 2
shuffle: true
drop_last: true
collate_batch: false
use_shared_memory: false


EvalReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: [800, 1333], keep_ratio: True}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true}
batch_size: 1
shuffle: false
drop_last: false


TestReader:
sample_transforms:
- Decode: {}
- Resize: {target_size: [800, 1333], keep_ratio: True}
- NormalizeImage: {is_scale: true, mean: [0.485,0.456,0.406], std: [0.229, 0.224,0.225]}
- Permute: {}
batch_transforms:
- PadMaskBatch: {pad_to_stride: -1, return_pad_mask: true}
batch_size: 1
shuffle: false
drop_last: false
53 changes: 53 additions & 0 deletions configs/group_detr/_base_/group_dino_r50.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
architecture: DETR
pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_cos_pretrained.pdparams
hidden_dim: 256
use_focal_loss: True


DETR:
backbone: ResNet
transformer: GroupDINOTransformer
detr_head: DINOHead
post_process: DETRBBoxPostProcess

ResNet:
# index 0 stands for res2
depth: 50
norm_type: bn
freeze_at: 0
return_idx: [1, 2, 3]
lr_mult_list: [0.0, 0.1, 0.1, 0.1]
num_stages: 4

GroupDINOTransformer:
num_queries: 900
position_embed_type: sine
num_levels: 4
nhead: 8
num_encoder_layers: 6
num_decoder_layers: 6
dim_feedforward: 2048
dropout: 0.0
activation: relu
pe_temperature: 20
pe_offset: 0.0
num_denoising: 100
label_noise_ratio: 0.5
box_noise_scale: 1.0
learnt_init_query: True
dual_queries: True
dual_groups: 10

DINOHead:
loss:
name: DINOLoss
loss_coeff: {class: 1, bbox: 5, giou: 2}
aux_loss: True
matcher:
name: HungarianMatcher
matcher_coeff: {class: 2, bbox: 5, giou: 2}

DETRBBoxPostProcess:
num_top_queries: 300
dual_queries: True
dual_groups: 10
68 changes: 68 additions & 0 deletions configs/group_detr/_base_/group_dino_vit_huge.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
architecture: DETR
pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/vit_huge_mae_patch14_dec512d8b_pretrained.pdparams
hidden_dim: 256
use_focal_loss: True

DETR:
backbone: VisionTransformer2D
neck: SimpleFeaturePyramid
transformer: GroupDINOTransformer
detr_head: DINOHead
post_process: DETRBBoxPostProcess

VisionTransformer2D:
patch_size: 16
embed_dim: 1280
depth: 32
num_heads: 16
mlp_ratio: 4
attn_bias: True
drop_rate: 0.0
drop_path_rate: 0.1
lr_decay_rate: 0.7
global_attn_indexes: [7, 15, 23, 31]
use_abs_pos: False
use_rel_pos: True
rel_pos_zero_init: True
window_size: 14
out_indices: [ 31, ]

SimpleFeaturePyramid:
out_channels: 256
num_levels: 4

GroupDINOTransformer:
num_queries: 900
position_embed_type: sine
pe_temperature: 20
pe_offset: 0.0
num_levels: 4
nhead: 8
num_encoder_layers: 6
num_decoder_layers: 6
dim_feedforward: 2048
use_input_proj: False
dropout: 0.0
activation: relu
num_denoising: 100
label_noise_ratio: 0.5
box_noise_scale: 1.0
learnt_init_query: True
dual_queries: True
dual_groups: 10


DINOHead:
loss:
name: DINOLoss
loss_coeff: {class: 1, bbox: 5, giou: 2}
aux_loss: True
matcher:
name: HungarianMatcher
matcher_coeff: {class: 2, bbox: 5, giou: 2}


DETRBBoxPostProcess:
num_top_queries: 300
dual_queries: True
dual_groups: 10
16 changes: 16 additions & 0 deletions configs/group_detr/_base_/optimizer_1x.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
epoch: 12

LearningRate:
base_lr: 0.0001
schedulers:
- !PiecewiseDecay
gamma: 0.1
milestones: [11]
use_warmup: false

OptimizerBuilder:
clip_grad_by_norm: 0.1
regularizer: false
optimizer:
type: AdamW
weight_decay: 0.0001
11 changes: 11 additions & 0 deletions configs/group_detr/group_dino_r50_4scale_1x_coco.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/optimizer_1x.yml',
'_base_/group_dino_r50.yml',
'_base_/dino_reader.yml',
]

weights: output/group_dino_r50_4scale_1x_coco/model_final
find_unused_parameters: True
log_iter: 100
11 changes: 11 additions & 0 deletions configs/group_detr/group_dino_vit_huge_4scale_1x_coco.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
_BASE_: [
'../datasets/coco_detection.yml',
'../runtime.yml',
'_base_/optimizer_1x.yml',
'_base_/group_dino_vit_huge.yml',
'_base_/dino_2000_reader.yml',
]

weights: output/group_dino_vit_huge_4scale_1x_coco/model_final
find_unused_parameters: True
log_iter: 100
13 changes: 12 additions & 1 deletion ppdet/modeling/architectures/detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@ def __init__(self,
backbone,
transformer='DETRTransformer',
detr_head='DETRHead',
neck=None,
post_process='DETRBBoxPostProcess',
exclude_post_process=False):
super(DETR, self).__init__()
self.backbone = backbone
self.neck = neck
self.transformer = transformer
self.detr_head = detr_head
self.post_process = post_process
Expand All @@ -47,8 +49,12 @@ def __init__(self,
def from_config(cls, cfg, *args, **kwargs):
# backbone
backbone = create(cfg['backbone'])
# transformer
# neck
kwargs = {'input_shape': backbone.out_shape}
neck = create(cfg['neck'], **kwargs) if cfg['neck'] else None
# transformer
if neck is not None:
kwargs = {'input_shape': neck.out_shape}
transformer = create(cfg['transformer'], **kwargs)
# head
kwargs = {
Expand All @@ -62,12 +68,17 @@ def from_config(cls, cfg, *args, **kwargs):
'backbone': backbone,
'transformer': transformer,
"detr_head": detr_head,
"neck": neck
}

def _forward(self):
# Backbone
body_feats = self.backbone(self.inputs)

# Neck
if self.neck is not None:
body_feats = self.neck(body_feats)

# Transformer
pad_mask = self.inputs.get('pad_mask', None)
out_transformer = self.transformer(body_feats, pad_mask, self.inputs)
Expand Down
2 changes: 2 additions & 0 deletions ppdet/modeling/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from . import mobileone
from . import trans_encoder
from . import focalnet
from . import vit_mae

from .vgg import *
from .resnet import *
Expand All @@ -61,3 +62,4 @@
from .mobileone import *
from .trans_encoder import *
from .focalnet import *
from .vit_mae import *
50 changes: 50 additions & 0 deletions ppdet/modeling/backbones/transformer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from paddle.nn.initializer import TruncatedNormal, Constant, Assign

Expand Down Expand Up @@ -72,3 +73,52 @@ def add_parameter(layer, datas, name=None):
if name:
layer.add_parameter(name, parameter)
return parameter


def window_partition(x, window_size):
"""
Partition into non-overlapping windows with padding if needed.
Args:
x (tensor): input tokens with [B, H, W, C].
window_size (int): window size.
Returns:
windows: windows after partition with [B * num_windows, window_size, window_size, C].
(Hp, Wp): padded height and width before partition
"""
B, H, W, C = paddle.shape(x)

pad_h = (window_size - H % window_size) % window_size
pad_w = (window_size - W % window_size) % window_size
x = F.pad(x.transpose([0, 3, 1, 2]),
paddle.to_tensor(
[0, int(pad_w), 0, int(pad_h)],
dtype='int32')).transpose([0, 2, 3, 1])
Hp, Wp = H + pad_h, W + pad_w

num_h, num_w = Hp // window_size, Wp // window_size

x = x.reshape([B, num_h, window_size, num_w, window_size, C])
windows = x.transpose([0, 1, 3, 2, 4, 5]).reshape(
[-1, window_size, window_size, C])
return windows, (Hp, Wp), (num_h, num_w)


def window_unpartition(x, pad_hw, num_hw, hw):
"""
Window unpartition into original sequences and removing padding.
Args:
x (tensor): input tokens with [B * num_windows, window_size, window_size, C].
pad_hw (Tuple): padded height and width (Hp, Wp).
hw (Tuple): original height and width (H, W) before padding.
Returns:
x: unpartitioned sequences with [B, H, W, C].
"""
Hp, Wp = pad_hw
num_h, num_w = num_hw
H, W = hw
B, window_size, _, C = paddle.shape(x)
B = B // (num_h * num_w)
x = x.reshape([B, num_h, num_w, window_size, window_size, C])
x = x.transpose([0, 1, 3, 2, 4, 5]).reshape([B, Hp, Wp, C])

return x[:, :H, :W, :]
Loading