[Feature] support modified resnet structure used in oCLIP #1458

Merged 4 commits on Nov 3, 2022
3 changes: 2 additions & 1 deletion mmocr/models/common/backbones/__init__.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .clip_resnet import CLIPResNet
from .unet import UNet

__all__ = ['UNet']
__all__ = ['UNet', 'CLIPResNet']
100 changes: 100 additions & 0 deletions mmocr/models/common/backbones/clip_resnet.py
@@ -0,0 +1,100 @@
# Copyright (c) OpenMMLab. All rights reserved.

import torch.nn as nn
from mmdet.models.backbones import ResNet
from mmdet.models.backbones.resnet import Bottleneck

from mmocr.registry import MODELS


class CLIPBottleneck(Bottleneck):
"""Bottleneck for CLIPResNet.

It is a Bottleneck variant used in the ResNet variant of CLIP. After the
second convolution layer, there is an additional average pooling layer with
kernel_size 2 and stride 2, which is added as a plugin when the
input stride > 1. The stride of each convolution layer is always set to 1.

Args:
**kwargs: Keyword arguments for
:class:``mmdet.models.backbones.resnet.Bottleneck``.
"""

def __init__(self, **kwargs):
stride = kwargs.get('stride', 1)
kwargs['stride'] = 1
plugins = kwargs.get('plugins', None)
if stride > 1:
if plugins is None:
plugins = []

plugins.insert(
0,
dict(
cfg=dict(type='mmocr.AvgPool2d', kernel_size=2),
position='after_conv2'))
kwargs['plugins'] = plugins
super().__init__(**kwargs)


@MODELS.register_module()
class CLIPResNet(ResNet):
"""Implement the ResNet variant used in `oCLIP.

<https://github.com/bytedance/oclip>`_.

    It is also the official structure in
    `CLIP <https://github.com/openai/CLIP>`_.

    Compared with the ResNetV1d structure, CLIPResNet replaces the
    max pooling layer with an average pooling layer at the end
    of the input stem.

    In the Bottleneck of CLIPResNet, after the second convolution
    layer, there is an additional average pooling layer with
    kernel_size 2 and stride 2, which is added as a plugin
    when the input stride > 1.
    The stride of each convolution layer is always set to 1.

    Args:
        depth (int): Depth of ResNet; only 50 is supported. Defaults to 50.
        strides (sequence(int)): Strides of the first block of each stage.
            Defaults to (1, 2, 2, 2).
        deep_stem (bool): Replace the 7x7 conv in the input stem with three
            3x3 convs. Defaults to True.
        avg_down (bool): Use AvgPool instead of stride conv at
            the downsampling stage in the bottleneck. Defaults to True.
        **kwargs: Keyword arguments for
            :class:`mmdet.models.backbones.resnet.ResNet`.
    """
    arch_settings = {
        50: (CLIPBottleneck, (3, 4, 6, 3)),
    }

    def __init__(self,
                 depth=50,
                 strides=(1, 2, 2, 2),
                 deep_stem=True,
                 avg_down=True,
                 **kwargs):
        super().__init__(
            depth=depth,
            strides=strides,
            deep_stem=deep_stem,
            avg_down=avg_down,
            **kwargs)

    def _make_stem_layer(self, in_channels: int, stem_channels: int):
        """Build stem layer for CLIPResNet used in `CLIP
        <https://github.com/openai/CLIP>`_.

        It uses an average pooling layer rather than a max pooling
        layer at the end of the input stem.

        Args:
            in_channels (int): Number of input channels.
            stem_channels (int): Number of output channels.
        """
        super()._make_stem_layer(in_channels, stem_channels)
        if self.deep_stem:
            self.maxpool = nn.AvgPool2d(kernel_size=2)
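
To make the new backbone's behaviour concrete, here is a minimal usage sketch. It is illustrative only, not an official config: it builds CLIPResNet directly with its defaults, and the input size and expected output shapes are taken from the unit test added later in this PR.

# A minimal sketch, assuming mmocr with this PR installed.
import torch

from mmocr.models.common.backbones import CLIPResNet

backbone = CLIPResNet()  # depth=50, deep_stem=True, avg_down=True by default
backbone.eval()

with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 32, 32))

# Four stages with overall strides 4, 8, 16 and 32:
# (1, 256, 8, 8), (1, 512, 4, 4), (1, 1024, 2, 2), (1, 2048, 1, 1)
for feat in feats:
    print(feat.shape)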
4 changes: 4 additions & 0 deletions mmocr/models/common/plugins/__init__.py
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .common import AvgPool2d

__all__ = ['AvgPool2d']
40 changes: 40 additions & 0 deletions mmocr/models/common/plugins/common.py
@@ -0,0 +1,40 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from mmocr.registry import MODELS


@MODELS.register_module()
class AvgPool2d(nn.Module):
"""Applies a 2D average pooling over an input signal composed of several
input planes.

It can also be used as a network plugin.

Args:
kernel_size (int or tuple(int)): the size of the window.
stride (int or tuple(int), optional): the stride of the window.
Defaults to None.
padding (int or tuple(int)): implicit zero padding. Defaults to 0.
"""

def __init__(self,
kernel_size: Union[int, Tuple[int]],
stride: Optional[Union[int, Tuple[int]]] = None,
padding: Union[int, Tuple[int]] = 0,
**kwargs) -> None:
super().__init__()
self.model = nn.AvgPool2d(kernel_size, stride, padding)

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward function.
Args:
x (Tensor): Input feature map.

Returns:
Tensor: Output tensor after Avgpooling layer.
"""
return self.model(x)
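
As a quick sanity check of the wrapper (it simply delegates to nn.AvgPool2d), the following hedged sketch reuses the shapes from the plugin unit test at the end of this PR. Inside CLIPBottleneck the same module is built through the registry under the name 'mmocr.AvgPool2d' rather than instantiated directly.

# A minimal sketch, assuming mmocr with this PR installed.
import torch

from mmocr.models.common.plugins import AvgPool2d

pool = AvgPool2d(kernel_size=2, stride=2)
x = torch.rand(1, 3, 32, 100)
print(pool(x).shape)  # torch.Size([1, 3, 16, 50]), spatial size halved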
66 changes: 66 additions & 0 deletions tests/test_models/test_common/test_backbones/test_clip_resnet.py
@@ -0,0 +1,66 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch
import torch.nn as nn
from mmcv.cnn import build_conv_layer, build_norm_layer

from mmocr.models.common.backbones import CLIPResNet
from mmocr.models.common.backbones.clip_resnet import CLIPBottleneck


class TestCLIPResNet(TestCase):

    def test_forward(self):
        model = CLIPResNet()
        model.eval()

        imgs = torch.randn(1, 3, 32, 32)
        feat = model(imgs)
        assert len(feat) == 4
        assert feat[0].shape == torch.Size([1, 256, 8, 8])
        assert feat[1].shape == torch.Size([1, 512, 4, 4])
        assert feat[2].shape == torch.Size([1, 1024, 2, 2])
        assert feat[3].shape == torch.Size([1, 2048, 1, 1])


class TestCLIPBottleneck(TestCase):

    def test_forward(self):
        stride = 2
        inplanes = 256
        planes = 128
        conv_cfg = None
        norm_cfg = {'type': 'BN', 'requires_grad': True}

        downsample = []
        downsample.append(
            nn.AvgPool2d(
                kernel_size=stride,
                stride=stride,
                ceil_mode=True,
                count_include_pad=False))
        downsample.extend([
            build_conv_layer(
                conv_cfg,
                inplanes,
                planes * CLIPBottleneck.expansion,
                kernel_size=1,
                stride=1,
                bias=False),
            build_norm_layer(norm_cfg, planes * CLIPBottleneck.expansion)[1]
        ])
        downsample = nn.Sequential(*downsample)

        model = CLIPBottleneck(
            inplanes=inplanes,
            planes=planes,
            stride=stride,
            downsample=downsample,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg)
        model.eval()

        input_feat = torch.randn(1, 256, 8, 8)
        output_feat = model(input_feat)
        assert output_feat.shape == torch.Size([1, 512, 4, 4])
16 changes: 16 additions & 0 deletions tests/test_models/test_common/test_plugins/test_avgpool.py
@@ -0,0 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch

from mmocr.models.common.plugins import AvgPool2d


class TestAvgPool2d(TestCase):

    def setUp(self) -> None:
        self.img = torch.rand(1, 3, 32, 100)

    def test_avgpool2d(self):
        avgpool2d = AvgPool2d(kernel_size=2, stride=2)
        self.assertEqual(avgpool2d(self.img).shape, torch.Size([1, 3, 16, 50]))