Skip to content

Commit

Permalink
Point rend
Browse files Browse the repository at this point in the history
  • Loading branch information
chhluo authored and ZwwWayne committed Jul 19, 2022
1 parent 727d867 commit 16e47a4
Show file tree
Hide file tree
Showing 14 changed files with 705 additions and 461 deletions.
11 changes: 6 additions & 5 deletions configs/mask_rcnn/mask_rcnn_r50_caffe_fpn_mstrain_1x_coco.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
_base_ = './mask_rcnn_r50_fpn_1x_coco.py'
preprocess_cfg = dict(
data_preprocessor = dict(
mean=[103.530, 116.280, 123.675],
std=[1.0, 1.0, 1.0],
to_rgb=False,
bgr_to_rgb=False,
pad_size_divisor=32)
model = dict(
# use caffe img_norm
preprocess_cfg=preprocess_cfg,
data_preprocessor=data_preprocessor,
backbone=dict(
norm_cfg=dict(requires_grad=False),
style='caffe',
Expand All @@ -18,8 +18,9 @@
dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
dict(
type='RandomChoiceResize',
img_scale=[(1333, 640), (1333, 672), (1333, 704), (1333, 736),
(1333, 768), (1333, 800)]),
scales=[(1333, 640), (1333, 672), (1333, 704), (1333, 736),
(1333, 768), (1333, 800)],
keep_ratio=True),
dict(type='RandomFlip', prob=0.5),
dict(type='PackDetInputs'),
]
Expand Down
18 changes: 16 additions & 2 deletions configs/point_rend/point_rend_r50_caffe_fpn_mstrain_3x_coco.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,18 @@
_base_ = './point_rend_r50_caffe_fpn_mstrain_1x_coco.py'

max_epochs = 36

# learning policy
lr_config = dict(step=[28, 34])
runner = dict(type='EpochBasedRunner', max_epochs=36)
param_scheduler = [
dict(
type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, end=500),
dict(
type='MultiStepLR',
begin=0,
end=max_epochs,
by_epoch=True,
milestones=[28, 34],
gamma=0.1)
]

train_cfg = dict(max_epochs=max_epochs)
25 changes: 14 additions & 11 deletions mmdet/models/detectors/point_rend.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.config import ConfigDict

from mmdet.core.utils import OptConfigType, OptMultiConfig
from mmdet.registry import MODELS
from .two_stage import TwoStageDetector

Expand All @@ -13,20 +16,20 @@ class PointRend(TwoStageDetector):
"""

def __init__(self,
backbone,
rpn_head,
roi_head,
train_cfg,
test_cfg,
neck=None,
pretrained=None,
init_cfg=None):
super(PointRend, self).__init__(
backbone: ConfigDict,
rpn_head: ConfigDict,
roi_head: ConfigDict,
train_cfg: ConfigDict,
test_cfg: ConfigDict,
neck: OptConfigType = None,
data_preprocessor: OptConfigType = None,
init_cfg: OptMultiConfig = None) -> None:
super().__init__(
backbone=backbone,
neck=neck,
rpn_head=rpn_head,
roi_head=roi_head,
train_cfg=train_cfg,
test_cfg=test_cfg,
pretrained=pretrained,
init_cfg=init_cfg)
init_cfg=init_cfg,
data_preprocessor=data_preprocessor)
40 changes: 25 additions & 15 deletions mmdet/models/roi_heads/mask_heads/coarse_mask_head.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.cnn import ConvModule, Linear
from mmcv.runner import ModuleList, auto_fp16
from mmengine.model import ModuleList
from torch import Tensor

from mmdet.core.utils import MultiConfig
from mmdet.registry import MODELS
from .fcn_mask_head import FCNMaskHead

Expand All @@ -14,29 +16,29 @@ class CoarseMaskHead(FCNMaskHead):
the input feature map instead of upsample it.
Args:
num_convs (int): Number of conv layers in the head. Default: 0.
num_fcs (int): Number of fc layers in the head. Default: 2.
num_convs (int): Number of conv layers in the head. Defaults to 0.
num_fcs (int): Number of fc layers in the head. Defaults to 2.
fc_out_channels (int): Number of output channels of fc layer.
Default: 1024.
Defaults to 1024.
downsample_factor (int): The factor that feature map is downsampled by.
Default: 2.
Defaults to 2.
init_cfg (dict or list[dict], optional): Initialization config dict.
"""

def __init__(self,
num_convs=0,
num_fcs=2,
fc_out_channels=1024,
downsample_factor=2,
init_cfg=dict(
num_convs: int = 0,
num_fcs: int = 2,
fc_out_channels: int = 1024,
downsample_factor: int = 2,
init_cfg: MultiConfig = dict(
type='Xavier',
override=[
dict(name='fcs'),
dict(type='Constant', val=0.001, name='fc_logits')
]),
*arg,
**kwarg):
super(CoarseMaskHead, self).__init__(
**kwarg) -> None:
super().__init__(
*arg,
num_convs=num_convs,
upsample_cfg=dict(type=None),
Expand Down Expand Up @@ -81,11 +83,19 @@ def __init__(self,
output_channels = self.num_classes * self.output_area
self.fc_logits = Linear(last_layer_dim, output_channels)

def init_weights(self):
def init_weights(self) -> None:
"""Initialize weights."""
super(FCNMaskHead, self).init_weights()

@auto_fp16()
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
"""Forward features from the upstream network.
Args:
x (Tensor): Extract mask RoI features.
Returns:
Tensor: Predicted foreground masks.
"""
for conv in self.convs:
x = conv(x)

Expand Down
Loading

0 comments on commit 16e47a4

Please sign in to comment.