Skip to content

Commit

Permalink
[Enhance] RepVGG for YOLOX-PAI. (open-mmlab#1025)
Browse files Browse the repository at this point in the history
* repvgg add ppf for yoloxpai

* fix by review

* update stem_channels

* fix doc

Co-authored-by: Ezra-Yu <18586273+Ezra-Yu@users.noreply.github.com>
  • Loading branch information
okotaku and Ezra-Yu authored Sep 30, 2022
1 parent 0143e5f commit 8c7b7b1
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 22 deletions.
103 changes: 92 additions & 11 deletions mmcls/models/backbones/repvgg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer,
build_norm_layer)
from mmcv.runner import BaseModule, Sequential
from mmcv.utils.parrots_wrapper import _BatchNorm
from torch import nn

from ..builder import BACKBONES
from ..utils.se_layer import SELayer
Expand Down Expand Up @@ -254,6 +256,51 @@ def _norm_to_conv3x3(self, branch_norm):
return tmp_conv3x3


class MTSPPF(nn.Module):
    """MTSPPF block for YOLOX-PAI RepVGG backbone.

    A 1x1 ``ConvModule`` first maps the input to ``in_channels // 2``
    hidden channels. That tensor is then passed three times in a chain
    through one shared stride-1 max pooling layer, and the un-pooled tensor
    together with the three pooled tensors is concatenated along dim 1.
    A second 1x1 ``ConvModule`` maps the concatenation
    (``4 * (in_channels // 2)`` channels) to ``out_channels``.

    Args:
        in_channels (int): The input channels of the block.
        out_channels (int): The output channels of the block.
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        kernel_size (int): Kernel size of pooling. Must be a positive odd
            number, otherwise the pooled maps would not keep the spatial
            size of their input and could not be concatenated.
            Default: 5.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd number.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 kernel_size=5):
        super().__init__()
        # With stride 1 and padding ``kernel_size // 2`` the pooling only
        # preserves the spatial size for odd kernels. Fail early with a
        # clear message instead of a shape mismatch inside ``torch.cat``.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                'kernel_size of MTSPPF must be a positive odd number, '
                f'but got {kernel_size}.')

        hidden_features = in_channels // 2  # hidden channels
        self.conv1 = ConvModule(
            in_channels,
            hidden_features,
            1,
            stride=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # conv2 consumes the un-pooled tensor plus three pooled tensors,
        # hence four times the hidden width.
        self.conv2 = ConvModule(
            hidden_features * 4,
            out_channels,
            1,
            stride=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        self.maxpool = nn.MaxPool2d(
            kernel_size=kernel_size, stride=1, padding=kernel_size // 2)

    def forward(self, x):
        """Apply conv1, three chained max poolings, concat and conv2.

        Args:
            x (torch.Tensor): Input tensor; channels are expected on dim 1
                since the pooled features are concatenated along dim 1.

        Returns:
            torch.Tensor: Output of ``conv2`` applied to the concatenated
            features.
        """
        x = self.conv1(x)
        y1 = self.maxpool(x)
        y2 = self.maxpool(y1)
        y3 = self.maxpool(y2)
        return self.conv2(torch.cat([x, y1, y2, y3], 1))


@BACKBONES.register_module()
class RepVGG(BaseBackbone):
"""RepVGG backbone.
Expand All @@ -262,17 +309,24 @@ class RepVGG(BaseBackbone):
<https://arxiv.org/abs/2101.03697>`_
Args:
arch (str | dict): The parameter of RepVGG.
If it's a dict, it should contain the following keys:
arch (str | dict): RepVGG architecture. If a string is used,
choose from 'A0', 'A1', 'A2', 'B0', 'B1', 'B1g2', 'B1g4', 'B2',
'B2g2', 'B2g4', 'B3', 'B3g2', 'B3g4', 'D2se' or 'yolox-pai-small'.
If a dict is used, it should have the following keys:
- num_blocks (Sequence[int]): Number of blocks in each stage.
- width_factor (Sequence[float]): Width deflator in each stage.
- group_layer_map (dict | None): RepVGG Block that declares
the need to apply group convolution.
- se_cfg (dict | None): Se Layer config
- se_cfg (dict | None): Se Layer config.
- stem_channels (int, optional): The stem channels, the final
stem channels will be
``min(stem_channels, base_channels*width_factor[0])``.
If not set here, 64 is used by default in the code.
in_channels (int): Number of input image channels. Default: 3.
base_channels (int): Base channels of RepVGG backbone, work
with width_factor together. Default: 64.
base_channels (int): Base channels of RepVGG backbone, work with
width_factor together. Defaults to 64.
out_indices (Sequence[int]): Output from which stages. Default: (3, ).
strides (Sequence[int]): Strides of the first block of each stage.
Default: (2, 2, 2, 2).
Expand All @@ -292,6 +346,7 @@ class RepVGG(BaseBackbone):
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: False.
add_ppf (bool): Whether to use the MTSPPF block. Default: False.
init_cfg (dict or list[dict], optional): Initialization config dict.
"""

Expand Down Expand Up @@ -323,7 +378,8 @@ class RepVGG(BaseBackbone):
num_blocks=[4, 6, 16, 1],
width_factor=[1, 1, 1, 2.5],
group_layer_map=None,
se_cfg=None),
se_cfg=None,
stem_channels=64),
'B1':
dict(
num_blocks=[4, 6, 16, 1],
Expand Down Expand Up @@ -383,7 +439,14 @@ class RepVGG(BaseBackbone):
num_blocks=[8, 14, 24, 1],
width_factor=[2.5, 2.5, 2.5, 5],
group_layer_map=None,
se_cfg=dict(ratio=16, divisor=1))
se_cfg=dict(ratio=16, divisor=1)),
'yolox-pai-small':
dict(
num_blocks=[3, 5, 7, 3],
width_factor=[1, 1, 1, 1],
group_layer_map=None,
se_cfg=None,
stem_channels=32),
}

def __init__(self,
Expand All @@ -400,6 +463,7 @@ def __init__(self,
with_cp=False,
deploy=False,
norm_eval=False,
add_ppf=False,
init_cfg=[
dict(type='Kaiming', layer=['Conv2d']),
dict(
Expand Down Expand Up @@ -427,9 +491,9 @@ def __init__(self,
if arch['se_cfg'] is not None:
assert isinstance(arch['se_cfg'], dict)

self.base_channels = base_channels
self.arch = arch
self.in_channels = in_channels
self.base_channels = base_channels
self.out_indices = out_indices
self.strides = strides
self.dilations = dilations
Expand All @@ -441,7 +505,12 @@ def __init__(self,
self.with_cp = with_cp
self.norm_eval = norm_eval

channels = min(64, int(base_channels * self.arch['width_factor'][0]))
# Defaults to 64 to prevent a BC break if ``stem_channels`` is
# not in the arch dict;
# the stem channels should not be larger than those of stage 1.
channels = min(
arch.get('stem_channels', 64),
int(self.base_channels * self.arch['width_factor'][0]))
self.stem = RepVGGBlock(
self.in_channels,
channels,
Expand All @@ -459,7 +528,7 @@ def __init__(self,
num_blocks = self.arch['num_blocks'][i]
stride = self.strides[i]
dilation = self.dilations[i]
out_channels = int(base_channels * 2**i *
out_channels = int(self.base_channels * 2**i *
self.arch['width_factor'][i])

stage, next_create_block_idx = self._make_stage(
Expand All @@ -471,6 +540,16 @@ def __init__(self,

channels = out_channels

if add_ppf:
self.ppf = MTSPPF(
out_channels,
out_channels,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
kernel_size=5)
else:
self.ppf = None

def _make_stage(self, in_channels, out_channels, num_blocks, stride,
dilation, next_create_block_idx, init_cfg):
strides = [stride] + [1] * (num_blocks - 1)
Expand Down Expand Up @@ -507,6 +586,8 @@ def forward(self, x):
for i, stage_name in enumerate(self.stages):
stage = getattr(self, stage_name)
x = stage(x)
if i + 1 == len(self.stages) and self.ppf is not None:
x = self.ppf(x)
if i in self.out_indices:
outs.append(x)

Expand Down
78 changes: 67 additions & 11 deletions tests/test_models/test_backbones/test_repvgg.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,18 +202,36 @@ def test_repvgg_backbone():
# Test RepVGG forward with layer 3 forward
model = RepVGG('A0', out_indices=(3, ))
model.init_weights()
model.train()
model.eval()

for m in model.modules():
if is_norm(m):
assert isinstance(m, _BatchNorm)

imgs = torch.randn(1, 3, 224, 224)
imgs = torch.randn(1, 3, 32, 32)
feat = model(imgs)
assert isinstance(feat, tuple)
assert len(feat) == 1
assert isinstance(feat[0], torch.Tensor)
assert feat[0].shape == torch.Size((1, 1280, 1, 1))

# Test with custom arch
cfg = dict(
num_blocks=[3, 5, 7, 3],
width_factor=[1, 1, 1, 1],
group_layer_map=None,
se_cfg=None,
stem_channels=16)
model = RepVGG(arch=cfg, out_indices=(3, ))
model.eval()
assert model.stem.out_channels == min(16, 64 * 1)

imgs = torch.randn(1, 3, 32, 32)
feat = model(imgs)
assert isinstance(feat, tuple)
assert len(feat) == 1
assert isinstance(feat[0], torch.Tensor)
assert feat[0].shape == torch.Size((1, 1280, 7, 7))
assert feat[0].shape == torch.Size((1, 512, 1, 1))

# Test RepVGG forward
model_test_settings = [
Expand All @@ -233,31 +251,31 @@ def test_repvgg_backbone():
dict(model_name='D2se', out_sizes=(160, 320, 640, 2560))
]

choose_models = ['A0', 'B1', 'B1g2', 'D2se']
choose_models = ['A0', 'B1', 'B1g2']
# Test RepVGG model forward
for model_test_setting in model_test_settings:
if model_test_setting['model_name'] not in choose_models:
continue
model = RepVGG(
model_test_setting['model_name'], out_indices=(0, 1, 2, 3))
model.init_weights()
model.eval()

# Test Norm
for m in model.modules():
if is_norm(m):
assert isinstance(m, _BatchNorm)

model.train()
imgs = torch.randn(1, 3, 224, 224)
imgs = torch.randn(1, 3, 32, 32)
feat = model(imgs)
assert feat[0].shape == torch.Size(
(1, model_test_setting['out_sizes'][0], 56, 56))
(1, model_test_setting['out_sizes'][0], 8, 8))
assert feat[1].shape == torch.Size(
(1, model_test_setting['out_sizes'][1], 28, 28))
(1, model_test_setting['out_sizes'][1], 4, 4))
assert feat[2].shape == torch.Size(
(1, model_test_setting['out_sizes'][2], 14, 14))
(1, model_test_setting['out_sizes'][2], 2, 2))
assert feat[3].shape == torch.Size(
(1, model_test_setting['out_sizes'][3], 7, 7))
(1, model_test_setting['out_sizes'][3], 1, 1))

# Test eval of "train" mode and "deploy" mode
gap = nn.AdaptiveAvgPool2d(output_size=(1))
Expand All @@ -275,11 +293,49 @@ def test_repvgg_backbone():
torch.allclose(feat[i], feat_deploy[i])
torch.allclose(pred, pred_deploy)

# Test RepVGG forward with add_ppf
model = RepVGG('A0', out_indices=(3, ), add_ppf=True)
model.init_weights()
model.train()

for m in model.modules():
if is_norm(m):
assert isinstance(m, _BatchNorm)

imgs = torch.randn(1, 3, 64, 64)
feat = model(imgs)
assert isinstance(feat, tuple)
assert len(feat) == 1
assert isinstance(feat[0], torch.Tensor)
assert feat[0].shape == torch.Size((1, 1280, 2, 2))

# Test RepVGG forward with 'stem_channels' not in arch
arch = dict(
num_blocks=[2, 4, 14, 1],
width_factor=[0.75, 0.75, 0.75, 2.5],
group_layer_map=None,
se_cfg=None)
model = RepVGG(arch, add_ppf=True)
model.stem.in_channels = min(64, 64 * 0.75)
model.init_weights()
model.train()

for m in model.modules():
if is_norm(m):
assert isinstance(m, _BatchNorm)

imgs = torch.randn(1, 3, 64, 64)
feat = model(imgs)
assert isinstance(feat, tuple)
assert len(feat) == 1
assert isinstance(feat[0], torch.Tensor)
assert feat[0].shape == torch.Size((1, 1280, 2, 2))


def test_repvgg_load():
# Test output before and load from deploy checkpoint
model = RepVGG('A1', out_indices=(0, 1, 2, 3))
inputs = torch.randn((1, 3, 224, 224))
inputs = torch.randn((1, 3, 32, 32))
ckpt_path = os.path.join(tempfile.gettempdir(), 'ckpt.pth')
model.switch_to_deploy()
model.eval()
Expand Down

0 comments on commit 8c7b7b1

Please sign in to comment.