[Refactor]: Unified parameter initialization (open-mmlab#567)
* [Refactor]: Unified parameter initialization

* fixed pretrained
xvjiarui authored Jun 17, 2021
1 parent af6478d commit 9849a8d
Showing 19 changed files with 329 additions and 298 deletions.
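
The pattern repeated across all 19 files is the same: each backbone now subclasses mmcv's BaseModule and declares its default initialization through an init_cfg attribute, instead of overriding init_weights() by hand. Below is a minimal sketch of that pattern, assuming an mmcv version that provides mmcv.runner.BaseModule; the ToyBackbone class is illustrative only and not part of the commit.

import torch.nn as nn
from mmcv.runner import BaseModule


class ToyBackbone(BaseModule):
    """Illustrative module showing the init_cfg pattern (not from the diff)."""

    def __init__(self, init_cfg=None):
        super(ToyBackbone, self).__init__(init_cfg)
        if init_cfg is None:
            # Same defaults the refactored backbones declare in place of a
            # hand-written init_weights(): Kaiming for convs, constant 1 for norms.
            self.init_cfg = [
                dict(type='Kaiming', layer='Conv2d'),
                dict(type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
            ]
        self.conv = nn.Conv2d(3, 16, 3, padding=1)
        self.bn = nn.BatchNorm2d(16)


model = ToyBackbone()
model.init_weights()  # applies self.init_cfg recursively; no per-backbone loop

Calling init_weights() once on the top-level model applies every submodule's init_cfg, which is what makes the per-backbone init_weights() overrides removed below redundant.
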
60 changes: 32 additions & 28 deletions mmseg/models/backbones/cgnet.py
@@ -1,12 +1,12 @@
import warnings

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import (ConvModule, build_conv_layer, build_norm_layer,
constant_init, kaiming_init)
from mmcv.runner import load_checkpoint
from mmcv.cnn import ConvModule, build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule
from mmcv.utils.parrots_wrapper import _BatchNorm

from mmseg.utils import get_root_logger
from ..builder import BACKBONES


@@ -183,7 +183,7 @@ def forward(self, x):


@BACKBONES.register_module()
class CGNet(nn.Module):
class CGNet(BaseModule):
"""CGNet backbone.
A Light-weight Context Guided Network for Semantic Segmentation
@@ -210,6 +210,9 @@ class CGNet(nn.Module):
and its variants only. Default: False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
pretrained (str, optional): model pretrained path. Default: None
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
"""

def __init__(self,
@@ -222,9 +225,31 @@ def __init__(self,
norm_cfg=dict(type='BN', requires_grad=True),
act_cfg=dict(type='PReLU'),
norm_eval=False,
with_cp=False):
with_cp=False,
pretrained=None,
init_cfg=None):

super(CGNet, self).__init__(init_cfg)

assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be set at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer=['Conv2d', 'Linear']),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm']),
dict(type='Constant', val=0, layer='PReLU')
]
else:
raise TypeError('pretrained must be a str or None')

super(CGNet, self).__init__()
self.in_channels = in_channels
self.num_channels = num_channels
assert isinstance(self.num_channels, tuple) and len(
@@ -335,27 +360,6 @@ def forward(self, x):

return output

def init_weights(self, pretrained=None):
"""Initialize the weights in backbone.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, (nn.Conv2d, nn.Linear)):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)
elif isinstance(m, nn.PReLU):
constant_init(m, 0)
else:
raise TypeError('pretrained must be a str or None')

def train(self, mode=True):
"""Convert the model into training mode will keeping the normalization
layer freezed."""
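
For backward compatibility, the refactored constructor keeps the pretrained argument, but it only emits a deprecation warning and rewrites the value into init_cfg=dict(type='Pretrained', checkpoint=...). A sketch of the two now-equivalent call styles follows; the checkpoint path is a hypothetical placeholder, not a file shipped with the repository.

from mmseg.models import CGNet

# Deprecated style: still accepted, internally converted to a Pretrained init_cfg.
old_style = CGNet(in_channels=3, pretrained='checkpoints/cgnet.pth')

# Preferred style after this commit: declare the checkpoint via init_cfg.
new_style = CGNet(
    in_channels=3,
    init_cfg=dict(type='Pretrained', checkpoint='checkpoints/cgnet.pth'))

# Weights are loaded only when init_weights() runs, normally triggered by the
# segmentor that owns the backbone.
new_style.init_weights()
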
29 changes: 16 additions & 13 deletions mmseg/models/backbones/fast_scnn.py
@@ -1,8 +1,7 @@
import torch
import torch.nn as nn
from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, constant_init,
kaiming_init)
from torch.nn.modules.batchnorm import _BatchNorm
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
from mmcv.runner import BaseModule

from mmseg.models.decode_heads.psp_head import PPM
from mmseg.ops import resize
@@ -247,7 +246,7 @@ def forward(self, higher_res_feature, lower_res_feature):


@BACKBONES.register_module()
class FastSCNN(nn.Module):
class FastSCNN(BaseModule):
"""Fast-SCNN Backbone.
Args:
@@ -291,6 +290,8 @@ class FastSCNN(nn.Module):
dict(type='ReLU')
align_corners (bool): align_corners argument of F.interpolate.
Default: False
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
"""

def __init__(self,
@@ -307,9 +308,18 @@ def __init__(self,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
align_corners=False):
align_corners=False,
init_cfg=None):

super(FastSCNN, self).__init__(init_cfg)

if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
]

super(FastSCNN, self).__init__()
if global_in_channels != higher_in_channels:
raise AssertionError('Global Input Channels must be the same \
as Higher Input Channels!')
@@ -357,13 +367,6 @@ def __init__(self,
act_cfg=self.act_cfg,
align_corners=self.align_corners)

def init_weights(self, pretrained=None):
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)

def forward(self, x):
higher_res_features = self.learning_to_downsample(x)
lower_res_features = self.global_feature_extractor(higher_res_features)
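
FastSCNN only fills in its Kaiming/Constant defaults when init_cfg is None, so a user-supplied init_cfg replaces them entirely. A small illustration under that assumption; the Normal initializer is just an example choice, not something this commit prescribes.

from mmseg.models import FastSCNN

# Defaults declared by the diff: Kaiming for Conv2d, constant 1 for norm layers.
default_init = FastSCNN()

# Supplying init_cfg skips the defaults, so the caller owns the whole recipe.
custom_init = FastSCNN(init_cfg=[dict(type='Normal', std=0.01, layer='Conv2d')])
custom_init.init_weights()
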
119 changes: 73 additions & 46 deletions mmseg/models/backbones/hrnet.py
@@ -1,16 +1,16 @@
import warnings

import torch.nn as nn
from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
kaiming_init)
from mmcv.runner import load_checkpoint
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule, ModuleList, Sequential
from mmcv.utils.parrots_wrapper import _BatchNorm

from mmseg.ops import Upsample, resize
from mmseg.utils import get_root_logger
from ..builder import BACKBONES
from .resnet import BasicBlock, Bottleneck


class HRModule(nn.Module):
class HRModule(BaseModule):
"""High-Resolution Module for HRNet.
In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange
@@ -26,8 +26,11 @@ def __init__(self,
multiscale_output=True,
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN', requires_grad=True)):
super(HRModule, self).__init__()
norm_cfg=dict(type='BN', requires_grad=True),
block_init_cfg=None,
init_cfg=None):
super(HRModule, self).__init__(init_cfg)
self.block_init_cfg = block_init_cfg
self._check_branches(num_branches, num_blocks, in_channels,
num_channels)

@@ -92,7 +95,8 @@ def _make_one_branch(self,
downsample=downsample,
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg))
conv_cfg=self.conv_cfg,
init_cfg=self.block_init_cfg))
self.in_channels[branch_index] = \
num_channels[branch_index] * block.expansion
for i in range(1, num_blocks[branch_index]):
@@ -102,9 +106,10 @@ def _make_one_branch(self,
num_channels[branch_index],
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg))
conv_cfg=self.conv_cfg,
init_cfg=self.block_init_cfg))

return nn.Sequential(*layers)
return Sequential(*layers)

def _make_branches(self, num_branches, block, num_blocks, num_channels):
"""Build multiple branch."""
@@ -114,7 +119,7 @@ def _make_branches(self, num_branches, block, num_blocks, num_channels):
branches.append(
self._make_one_branch(i, block, num_blocks, num_channels))

return nn.ModuleList(branches)
return ModuleList(branches)

def _make_fuse_layers(self):
"""Build fuse layer."""
@@ -209,7 +214,7 @@ def forward(self, x):


@BACKBONES.register_module()
class HRNet(nn.Module):
class HRNet(BaseModule):
"""HRNet backbone.
High-Resolution Representations for Labeling Pixels and Regions
@@ -227,6 +232,9 @@ class HRNet(nn.Module):
memory while slowing down the training speed.
zero_init_residual (bool): whether to use zero init for last norm layer
in resblocks to let them behave as identity.
pretrained (str, optional): model pretrained path. Default: None
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None
Example:
>>> from mmseg.models import HRNet
@@ -277,14 +285,36 @@ def __init__(self,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=False,
with_cp=False,
zero_init_residual=False):
super(HRNet, self).__init__()
zero_init_residual=False,
pretrained=None,
init_cfg=None):
super(HRNet, self).__init__(init_cfg)

self.pretrained = pretrained
self.zero_init_residual = zero_init_residual
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be set at the same time'
if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is deprecated, '
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]
else:
raise TypeError('pretrained must be a str or None')

self.extra = extra
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.norm_eval = norm_eval
self.with_cp = with_cp
self.zero_init_residual = zero_init_residual

# stem net
self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
@@ -430,6 +460,16 @@ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
build_norm_layer(self.norm_cfg, planes * block.expansion)[1])

layers = []
block_init_cfg = None
if self.pretrained is None and not hasattr(
self, 'init_cfg') and self.zero_init_residual:
if block is BasicBlock:
block_init_cfg = dict(
type='Constant', val=0, override=dict(name='norm2'))
elif block is Bottleneck:
block_init_cfg = dict(
type='Constant', val=0, override=dict(name='norm3'))

layers.append(
block(
inplanes,
@@ -438,7 +478,8 @@ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
downsample=downsample,
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg))
conv_cfg=self.conv_cfg,
init_cfg=block_init_cfg))
inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(
@@ -447,9 +488,10 @@ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
planes,
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg))
conv_cfg=self.conv_cfg,
init_cfg=block_init_cfg))

return nn.Sequential(*layers)
return Sequential(*layers)

def _make_stage(self, layer_config, in_channels, multiscale_output=True):
"""Make each stage."""
@@ -460,6 +502,16 @@ def _make_stage(self, layer_config, in_channels, multiscale_output=True):
block = self.blocks_dict[layer_config['block']]

hr_modules = []
block_init_cfg = None
if self.pretrained is None and not hasattr(
self, 'init_cfg') and self.zero_init_residual:
if block is BasicBlock:
block_init_cfg = dict(
type='Constant', val=0, override=dict(name='norm2'))
elif block is Bottleneck:
block_init_cfg = dict(
type='Constant', val=0, override=dict(name='norm3'))

for i in range(num_modules):
# multi_scale_output is only used for the last module
if not multiscale_output and i == num_modules - 1:
@@ -477,35 +529,10 @@ def _make_stage(self, layer_config, in_channels, multiscale_output=True):
reset_multiscale_output,
with_cp=self.with_cp,
norm_cfg=self.norm_cfg,
conv_cfg=self.conv_cfg))

return nn.Sequential(*hr_modules), in_channels
conv_cfg=self.conv_cfg,
block_init_cfg=block_init_cfg))

def init_weights(self, pretrained=None):
"""Initialize the weights in backbone.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)

if self.zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
constant_init(m.norm3, 0)
elif isinstance(m, BasicBlock):
constant_init(m.norm2, 0)
else:
raise TypeError('pretrained must be a str or None')
return Sequential(*hr_modules), in_channels

def forward(self, x):
"""Forward function."""
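
In HRNet the zero_init_residual behaviour can no longer live in a hand-written init_weights() loop, so the diff pushes it into each residual block's init_cfg: a Constant initializer whose override targets the block's last norm layer by name (norm2 for BasicBlock, norm3 for Bottleneck). Below is a sketch of that override mechanism on a single Bottleneck, assuming the ResNet blocks refactored in this same commit accept init_cfg.

from mmseg.models.backbones.resnet import Bottleneck

# Constant(val=0) scoped by `override` to the child module named 'norm3'; the
# rest of the block is left to the parent model's init_cfg.
block = Bottleneck(
    inplanes=64,
    planes=16,
    init_cfg=dict(type='Constant', val=0, override=dict(name='norm3')))
block.init_weights()

# The last batch norm of the residual branch starts at zero, so the block
# initially behaves as an identity mapping.
assert float(block.norm3.weight.abs().sum()) == 0
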