From 0932ab787d58eead15b5f823fbcca5351ceb90f7 Mon Sep 17 00:00:00 2001
From: Cedric Luo
Date: Fri, 25 Mar 2022 17:13:55 +0800
Subject: [PATCH] add msdeformattn pixel decoder (#7466)

* fix typo

* rm img_metas

* rename in pixel_decoder

* update comments

* rename

* fix typo

* generate points with MlvlPointGenerator
---
 mmdet/models/plugins/__init__.py              |   6 +-
 .../plugins/msdeformattn_pixel_decoder.py     | 269 ++++++++++++++++++
 mmdet/models/plugins/pixel_decoder.py         |  20 +-
 tests/test_models/test_plugins.py             |  60 +++-
 4 files changed, 342 insertions(+), 13 deletions(-)
 create mode 100644 mmdet/models/plugins/msdeformattn_pixel_decoder.py

diff --git a/mmdet/models/plugins/__init__.py b/mmdet/models/plugins/__init__.py
index 940d94e884a..a455c07bb99 100644
--- a/mmdet/models/plugins/__init__.py
+++ b/mmdet/models/plugins/__init__.py
@@ -1,5 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .dropblock import DropBlock
+from .msdeformattn_pixel_decoder import MSDeformAttnPixelDecoder
 from .pixel_decoder import PixelDecoder, TransformerEncoderPixelDecoder
 
-__all__ = ['DropBlock', 'PixelDecoder', 'TransformerEncoderPixelDecoder']
+__all__ = [
+    'DropBlock', 'PixelDecoder', 'TransformerEncoderPixelDecoder',
+    'MSDeformAttnPixelDecoder'
+]
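With the export above in place, the decorator in the new file (shown next)
makes the class reachable through MMCV's PLUGIN_LAYERS registry. A minimal
sketch of the registry mechanics, nothing beyond what the patch itself sets up:

    import mmdet.models  # noqa: F401 -- importing runs the register_module decorator
    from mmcv.cnn import PLUGIN_LAYERS

    cls = PLUGIN_LAYERS.get('MSDeformAttnPixelDecoder')
    # Configs can now refer to the class via type='MSDeformAttnPixelDecoder',
    # e.g. through build_plugin_layer as in the unit test at the end of
    # this patch (build_plugin_layer returns a (name, module) tuple, which
    # is why the test indexes [1]).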
diff --git a/mmdet/models/plugins/msdeformattn_pixel_decoder.py b/mmdet/models/plugins/msdeformattn_pixel_decoder.py
new file mode 100644
index 00000000000..d553582baef
--- /dev/null
+++ b/mmdet/models/plugins/msdeformattn_pixel_decoder.py
@@ -0,0 +1,269 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import (PLUGIN_LAYERS, Conv2d, ConvModule, caffe2_xavier_init,
+                      normal_init, xavier_init)
+from mmcv.cnn.bricks.transformer import (build_positional_encoding,
+                                         build_transformer_layer_sequence)
+from mmcv.runner import BaseModule, ModuleList
+
+from mmdet.core.anchor import MlvlPointGenerator
+from mmdet.models.utils.transformer import MultiScaleDeformableAttention
+
+
+@PLUGIN_LAYERS.register_module()
+class MSDeformAttnPixelDecoder(BaseModule):
+    """Pixel decoder with multi-scale deformable attention.
+
+    Args:
+        in_channels (list[int] | tuple[int]): Number of channels in the
+            input feature maps.
+        strides (list[int] | tuple[int]): Output strides of the feature maps
+            from the backbone.
+        feat_channels (int): Number of channels for the intermediate
+            feature maps.
+        out_channels (int): Number of channels for the output mask feature.
+        num_outs (int): Number of output scales.
+        norm_cfg (:obj:`mmcv.ConfigDict` | dict): Config for normalization.
+            Defaults to dict(type='GN', num_groups=32).
+        act_cfg (:obj:`mmcv.ConfigDict` | dict): Config for activation.
+            Defaults to dict(type='ReLU').
+        encoder (:obj:`mmcv.ConfigDict` | dict): Config for transformer
+            encoder. Defaults to `DetrTransformerEncoder`.
+        positional_encoding (:obj:`mmcv.ConfigDict` | dict): Config for
+            transformer encoder position encoding. Defaults to
+            dict(type='SinePositionalEncoding', num_feats=128,
+            normalize=True).
+        init_cfg (:obj:`mmcv.ConfigDict` | dict): Initialization config dict.
+    """
+
+    def __init__(self,
+                 in_channels=[256, 512, 1024, 2048],
+                 strides=[4, 8, 16, 32],
+                 feat_channels=256,
+                 out_channels=256,
+                 num_outs=3,
+                 norm_cfg=dict(type='GN', num_groups=32),
+                 act_cfg=dict(type='ReLU'),
+                 encoder=dict(
+                     type='DetrTransformerEncoder',
+                     num_layers=6,
+                     transformerlayers=dict(
+                         type='BaseTransformerLayer',
+                         attn_cfgs=dict(
+                             type='MultiScaleDeformableAttention',
+                             embed_dims=256,
+                             num_heads=8,
+                             num_levels=3,
+                             num_points=4,
+                             im2col_step=64,
+                             dropout=0.0,
+                             batch_first=False,
+                             norm_cfg=None,
+                             init_cfg=None),
+                         feedforward_channels=1024,
+                         ffn_dropout=0.0,
+                         operation_order=('self_attn', 'norm', 'ffn',
+                                          'norm')),
+                     init_cfg=None),
+                 positional_encoding=dict(
+                     type='SinePositionalEncoding',
+                     num_feats=128,
+                     normalize=True),
+                 init_cfg=None):
+        super().__init__(init_cfg=init_cfg)
+        self.strides = strides
+        self.num_input_levels = len(in_channels)
+        self.num_encoder_levels = \
+            encoder.transformerlayers.attn_cfgs.num_levels
+        assert self.num_encoder_levels >= 1, \
+            'num_levels in attn_cfgs must be at least one'
+        input_conv_list = []
+        # from top to bottom (low to high resolution)
+        for i in range(self.num_input_levels - 1,
+                       self.num_input_levels - self.num_encoder_levels - 1,
+                       -1):
+            input_conv = ConvModule(
+                in_channels[i],
+                feat_channels,
+                kernel_size=1,
+                norm_cfg=norm_cfg,
+                act_cfg=None,
+                bias=True)
+            input_conv_list.append(input_conv)
+        self.input_convs = ModuleList(input_conv_list)
+
+        self.encoder = build_transformer_layer_sequence(encoder)
+        self.positional_encoding = build_positional_encoding(
+            positional_encoding)
+        # high resolution to low resolution
+        self.level_encoding = nn.Embedding(self.num_encoder_levels,
+                                           feat_channels)
+
+        # fpn-like structure
+        self.lateral_convs = ModuleList()
+        self.output_convs = ModuleList()
+        self.use_bias = norm_cfg is None
+        # from top to bottom (low to high resolution)
+        # fpn for the remaining features that were not fed into the encoder
+        for i in range(self.num_input_levels - self.num_encoder_levels - 1,
+                       -1, -1):
+            lateral_conv = ConvModule(
+                in_channels[i],
+                feat_channels,
+                kernel_size=1,
+                bias=self.use_bias,
+                norm_cfg=norm_cfg,
+                act_cfg=None)
+            output_conv = ConvModule(
+                feat_channels,
+                feat_channels,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+                bias=self.use_bias,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg)
+            self.lateral_convs.append(lateral_conv)
+            self.output_convs.append(output_conv)
+
+        self.mask_feature = Conv2d(
+            feat_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+        self.num_outs = num_outs
+        self.point_generator = MlvlPointGenerator(strides)
+
+    def init_weights(self):
+        """Initialize weights."""
+        for i in range(0, self.num_encoder_levels):
+            xavier_init(
+                self.input_convs[i].conv,
+                gain=1,
+                bias=0,
+                distribution='uniform')
+
+        for i in range(0, self.num_input_levels - self.num_encoder_levels):
+            caffe2_xavier_init(self.lateral_convs[i].conv, bias=0)
+            caffe2_xavier_init(self.output_convs[i].conv, bias=0)
+
+        caffe2_xavier_init(self.mask_feature, bias=0)
+
+        normal_init(self.level_encoding, mean=0, std=1)
+        for p in self.encoder.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_normal_(p)
+
+        # init_weights defined in MultiScaleDeformableAttention
+        for layer in self.encoder.layers:
+            for attn in layer.attentions:
+                if isinstance(attn, MultiScaleDeformableAttention):
+                    attn.init_weights()
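+    # With the default config above, num_encoder_levels is 3: the three
+    # lowest-resolution maps (strides 8, 16 and 32) are flattened and fed
+    # through the deformable-attention encoder, while the stride-4 map is
+    # only fused afterwards by the FPN-style branch in forward().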
+    def forward(self, feats):
+        """
+        Args:
+            feats (list[Tensor]): Feature maps of each level. Each has
+                shape of (batch_size, c, h, w).
+
+        Returns:
+            tuple: A tuple containing the following:
+
+                - mask_feature (Tensor): shape (batch_size, c, h, w).
+                - multi_scale_features (list[Tensor]): Multi-scale \
+                    features, each in shape (batch_size, c, h, w).
+        """
+        # generate padding mask for each level, for each image
+        batch_size = feats[0].shape[0]
+        encoder_input_list = []
+        padding_mask_list = []
+        level_positional_encoding_list = []
+        spatial_shapes = []
+        reference_points_list = []
+        for i in range(self.num_encoder_levels):
+            level_idx = self.num_input_levels - i - 1
+            feat = feats[level_idx]
+            feat_projected = self.input_convs[i](feat)
+            h, w = feat.shape[-2:]
+
+            # no padding
+            padding_mask_resized = feat.new_zeros(
+                (batch_size, ) + feat.shape[-2:], dtype=torch.bool)
+            pos_embed = self.positional_encoding(padding_mask_resized)
+            level_embed = self.level_encoding.weight[i]
+            level_pos_embed = level_embed.view(1, -1, 1, 1) + pos_embed
+            # (h_i * w_i, 2)
+            reference_points = self.point_generator.single_level_grid_priors(
+                feat.shape[-2:], level_idx, device=feat.device)
+            # normalize to range [0, 1]
+            factor = feat.new_tensor([[w, h]]) * self.strides[level_idx]
+            reference_points = reference_points / factor
+
+            # shape (batch_size, c, h_i, w_i) -> (h_i * w_i, batch_size, c)
+            feat_projected = feat_projected.flatten(2).permute(2, 0, 1)
+            level_pos_embed = level_pos_embed.flatten(2).permute(2, 0, 1)
+            padding_mask_resized = padding_mask_resized.flatten(1)
+
+            encoder_input_list.append(feat_projected)
+            padding_mask_list.append(padding_mask_resized)
+            level_positional_encoding_list.append(level_pos_embed)
+            spatial_shapes.append(feat.shape[-2:])
+            reference_points_list.append(reference_points)
+        # shape (batch_size, total_num_query),
+        # total_num_query = sum of h_i * w_i over all encoder levels
+        padding_masks = torch.cat(padding_mask_list, dim=1)
+        # shape (total_num_query, batch_size, c)
+        encoder_inputs = torch.cat(encoder_input_list, dim=0)
+        level_positional_encodings = torch.cat(
+            level_positional_encoding_list, dim=0)
+        device = encoder_inputs.device
+        # shape (num_encoder_levels, 2), from low
+        # resolution to high resolution
+        spatial_shapes = torch.as_tensor(
+            spatial_shapes, dtype=torch.long, device=device)
+        # shape (num_encoder_levels, ), values
+        # (0, h_0*w_0, h_0*w_0+h_1*w_1, ...)
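+        # e.g. with spatial_shapes [[4, 5], [8, 10], [16, 20]] the
+        # per-level query counts are [20, 80, 320], so level_start_index
+        # becomes (0, 20, 100): the offset of each level's first query
+        # in the flattened query sequence.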
+        level_start_index = torch.cat((spatial_shapes.new_zeros(
+            (1, )), spatial_shapes.prod(1).cumsum(0)[:-1]))
+        reference_points = torch.cat(reference_points_list, dim=0)
+        reference_points = reference_points[None, :, None].repeat(
+            batch_size, 1, self.num_encoder_levels, 1)
+        valid_ratios = reference_points.new_ones(
+            (batch_size, self.num_encoder_levels, 2))
+        # shape (num_total_query, batch_size, c)
+        memory = self.encoder(
+            query=encoder_inputs,
+            key=None,
+            value=None,
+            query_pos=level_positional_encodings,
+            key_pos=None,
+            attn_masks=None,
+            key_padding_mask=None,
+            query_key_padding_mask=padding_masks,
+            spatial_shapes=spatial_shapes,
+            reference_points=reference_points,
+            level_start_index=level_start_index,
+            valid_ratios=valid_ratios)
+        # (num_total_query, batch_size, c) -> (batch_size, c, num_total_query)
+        memory = memory.permute(1, 2, 0)
+
+        # from low resolution to high resolution
+        num_query_per_level = [e[0] * e[1] for e in spatial_shapes]
+        outs = torch.split(memory, num_query_per_level, dim=-1)
+        outs = [
+            x.reshape(batch_size, -1, spatial_shapes[i][0],
+                      spatial_shapes[i][1]) for i, x in enumerate(outs)
+        ]
+
+        for i in range(self.num_input_levels - self.num_encoder_levels - 1,
+                       -1, -1):
+            x = feats[i]
+            cur_feat = self.lateral_convs[i](x)
+            y = cur_feat + F.interpolate(
+                outs[-1],
+                size=cur_feat.shape[-2:],
+                mode='bilinear',
+                align_corners=False)
+            y = self.output_convs[i](y)
+            outs.append(y)
+        multi_scale_features = outs[:self.num_outs]
+
+        mask_feature = self.mask_feature(outs[-1])
+        return mask_feature, multi_scale_features
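For reference, a standalone sketch of the reference-point generation used in
forward() above (hypothetical feature size; single_level_grid_priors yields
points at (x + 0.5) * stride, so dividing by (w * stride, h * stride) maps
them into the unit square):

    import torch
    from mmdet.core.anchor import MlvlPointGenerator

    strides = [4, 8, 16, 32]
    generator = MlvlPointGenerator(strides)
    h, w = 16, 20  # hypothetical size of the level_idx=2 (stride 16) map
    # (h * w, 2) point centers in image coordinates
    points = generator.single_level_grid_priors((h, w), level_idx=2,
                                                device='cpu')
    factor = points.new_tensor([[w, h]]) * strides[2]
    normalized = points / factor  # every coordinate now lies in (0, 1)
    assert 0 < normalized.min() and normalized.max() < 1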
diff --git a/mmdet/models/plugins/pixel_decoder.py b/mmdet/models/plugins/pixel_decoder.py
index d1193551ddd..537a187dc5c 100644
--- a/mmdet/models/plugins/pixel_decoder.py
+++ b/mmdet/models/plugins/pixel_decoder.py
@@ -45,14 +45,14 @@ def __init__(self,
         self.output_convs = ModuleList()
         self.use_bias = norm_cfg is None
         for i in range(0, self.num_inputs - 1):
-            l_conv = ConvModule(
+            lateral_conv = ConvModule(
                 in_channels[i],
                 feat_channels,
                 kernel_size=1,
                 bias=self.use_bias,
                 norm_cfg=norm_cfg,
                 act_cfg=None)
-            o_conv = ConvModule(
+            output_conv = ConvModule(
                 feat_channels,
                 feat_channels,
                 kernel_size=3,
@@ -61,8 +61,8 @@ def __init__(self,
                 bias=self.use_bias,
                 norm_cfg=norm_cfg,
                 act_cfg=act_cfg)
-            self.lateral_convs.append(l_conv)
-            self.output_convs.append(o_conv)
+            self.lateral_convs.append(lateral_conv)
+            self.output_convs.append(output_conv)
 
         self.last_feat_conv = ConvModule(
             in_channels[-1],
@@ -102,9 +102,9 @@ def forward(self, feats, img_metas):
         y = self.last_feat_conv(feats[-1])
         for i in range(self.num_inputs - 2, -1, -1):
             x = feats[i]
-            cur_fpn = self.lateral_convs[i](x)
-            y = cur_fpn + \
-                F.interpolate(y, size=cur_fpn.shape[-2:], mode='nearest')
+            cur_feat = self.lateral_convs[i](x)
+            y = cur_feat + \
+                F.interpolate(y, size=cur_feat.shape[-2:], mode='nearest')
             y = self.output_convs[i](y)
 
         mask_feature = self.mask_feature(y)
@@ -234,9 +234,9 @@ def forward(self, feats, img_metas):
         y = self.encoder_out_proj(memory)
         for i in range(self.num_inputs - 2, -1, -1):
             x = feats[i]
-            cur_fpn = self.lateral_convs[i](x)
-            y = cur_fpn + \
-                F.interpolate(y, size=cur_fpn.shape[-2:], mode='nearest')
+            cur_feat = self.lateral_convs[i](x)
+            y = cur_feat + \
+                F.interpolate(y, size=cur_feat.shape[-2:], mode='nearest')
             y = self.output_convs[i](y)
 
         mask_feature = self.mask_feature(y)
diff --git a/tests/test_models/test_plugins.py b/tests/test_models/test_plugins.py
index b115fbd73f2..8afd1f9403a 100644
--- a/tests/test_models/test_plugins.py
+++ b/tests/test_models/test_plugins.py
@@ -31,7 +31,7 @@ def test_dropblock():
         DropBlock(0.5, 3, -1)
 
 
-def test_pixeldecoder():
+def test_pixel_decoder():
     base_channels = 64
     pixel_decoder_cfg = ConfigDict(
         dict(
@@ -53,7 +53,7 @@ def test_pixeldecoder():
     assert mask_feature.shape == feats[0].shape
 
 
-def test_transformerencoderpixeldecoer():
+def test_transformer_encoder_pixel_decoder():
     base_channels = 64
     pixel_decoder_cfg = ConfigDict(
         dict(
@@ -109,3 +109,59 @@ def test_transformerencoderpixeldecoer():
     assert memory.shape[-2:] == feats[-1].shape[-2:]
 
     assert mask_feature.shape == feats[0].shape
+
+
+def test_msdeformattn_pixel_decoder():
+    base_channels = 64
+    pixel_decoder_cfg = ConfigDict(
+        dict(
+            type='MSDeformAttnPixelDecoder',
+            in_channels=[base_channels * 2**i for i in range(4)],
+            strides=[4, 8, 16, 32],
+            feat_channels=base_channels,
+            out_channels=base_channels,
+            num_outs=3,
+            norm_cfg=dict(type='GN', num_groups=32),
+            act_cfg=dict(type='ReLU'),
+            encoder=dict(
+                type='DetrTransformerEncoder',
+                num_layers=6,
+                transformerlayers=dict(
+                    type='BaseTransformerLayer',
+                    attn_cfgs=dict(
+                        type='MultiScaleDeformableAttention',
+                        embed_dims=base_channels,
+                        num_heads=8,
+                        num_levels=3,
+                        num_points=4,
+                        im2col_step=64,
+                        dropout=0.0,
+                        batch_first=False,
+                        norm_cfg=None,
+                        init_cfg=None),
+                    ffn_cfgs=dict(
+                        type='FFN',
+                        embed_dims=base_channels,
+                        feedforward_channels=base_channels * 4,
+                        num_fcs=2,
+                        ffn_drop=0.0,
+                        act_cfg=dict(type='ReLU', inplace=True)),
+                    operation_order=('self_attn', 'norm', 'ffn', 'norm')),
+                init_cfg=None),
+            positional_encoding=dict(
+                type='SinePositionalEncoding',
+                num_feats=base_channels // 2,
+                normalize=True),
+            init_cfg=None), )
+    self = build_plugin_layer(pixel_decoder_cfg)[1]
+    feats = [
+        torch.rand((2, base_channels * 2**i, 4 * 2**(3 - i), 5 * 2**(3 - i)))
+        for i in range(4)
+    ]
+    mask_feature, multi_scale_features = self(feats)
+
+    assert mask_feature.shape == feats[0].shape
+    assert len(multi_scale_features) == 3
+    multi_scale_features = multi_scale_features[::-1]
+    for i in range(3):
+        assert multi_scale_features[i].shape[-2:] == feats[i + 1].shape[-2:]
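For orientation, the shapes exercised by the new test work out as follows
(derived entirely from the config above, nothing beyond it):

    import torch

    base_channels, batch = 64, 2
    # inputs, high to low resolution (strides 4, 8, 16, 32):
    feats = [
        torch.rand(batch, base_channels * 2**i, 4 * 2**(3 - i),
                   5 * 2**(3 - i)) for i in range(4)
    ]
    print([tuple(f.shape) for f in feats])
    # [(2, 64, 32, 40), (2, 128, 16, 20), (2, 256, 8, 10), (2, 512, 4, 5)]
    # mask_feature keeps the highest resolution: (2, 64, 32, 40).
    # multi_scale_features come back low -> high resolution:
    # (2, 64, 4, 5), (2, 64, 8, 10), (2, 64, 16, 20) -- hence the [::-1]
    # before comparing against feats[1:] in the test.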