[Feature] Segformer backbone re-implementation #594

Merged
xvjiarui merged 31 commits into open-mmlab:master on Jul 19, 2021

Conversation

@clownrat6 (Contributor) commented on Jun 9, 2021

- [ ] Add backbone MixVisionTransformer;

- [ ] Add head SegFormerHead;

- [ ] Add dataset transform pipeline AlignedResize;

- [ ] Add some config for segformer;

codecov bot commented on Jun 9, 2021

Codecov Report

Merging #594 (e7039ff) into master (e610ed1) will increase coverage by 0.09%.
The diff coverage is 71.42%.

@@            Coverage Diff             @@
##           master     #594      +/-   ##
==========================================
+ Coverage   85.18%   85.28%   +0.09%     
==========================================
  Files         105      107       +2     
  Lines        5671     5817     +146     
  Branches      923      951      +28     
==========================================
+ Hits         4831     4961     +130     
- Misses        662      673      +11     
- Partials      178      183       +5     
Flag Coverage Δ
unittests 85.26% <71.42%> (+0.09%) ⬆️

Flags with carried forward coverage won't be shown.

Impacted Files Coverage Δ
mmseg/models/backbones/swin.py 86.89% <ø> (ø)
mmseg/models/backbones/vit.py 84.84% <ø> (ø)
mmseg/models/utils/ckpt_convert.py 4.42% <4.65%> (+0.13%) ⬆️
mmseg/models/backbones/mit.py 88.80% <88.80%> (ø)
mmseg/models/backbones/__init__.py 100.00% <100.00%> (ø)
mmseg/models/utils/__init__.py 100.00% <100.00%> (ø)
mmseg/models/utils/embed.py 81.57% <100.00%> (+1.02%) ⬆️
mmseg/models/utils/shape_convert.py 100.00% <100.00%> (ø)
... and 1 more

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data

@clownrat6 changed the title from [Feature]Segformer re-implementation to [Feature]Segformer backbone re-implementation on Jun 10, 2021
@Junjun2016 changed the title from [Feature]Segformer backbone re-implementation to [Feature] Segformer backbone re-implementation on Jun 18, 2021
self.dwconv = DWConv(feedforward_channels)
self.act = build_activation_layer(act_cfg)
self.fc2 = Linear(feedforward_channels, in_channels)
self.drop = nn.Dropout(drop_rate)
@Junjun2016 (Collaborator) commented on Jun 20, 2021:

We can replace the FC layers with 1x1 convs, so we can avoid the dimension transforms and integrate the depthwise conv into the MLP.
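A minimal sketch of that idea (the class and attribute names are illustrative, not the merged code): both FC layers become 1x1 convolutions and the 3x3 depthwise conv is folded into the same block, so the tensor can stay in [N, C, H, W] layout throughout the FFN.

```python
import torch.nn as nn


class ConvFFN(nn.Module):  # hypothetical name, for illustration only
    """FFN where 1x1 convs replace the FC layers and the depthwise conv
    is integrated, avoiding [N, L, C] <-> [N, C, H, W] reshapes."""

    def __init__(self, in_channels, feedforward_channels, drop_rate=0.):
        super().__init__()
        self.fc1 = nn.Conv2d(in_channels, feedforward_channels, 1)
        # 3x3 depthwise conv, the positional-encoding conv of SegFormer's FFN
        self.pe_conv = nn.Conv2d(
            feedforward_channels, feedforward_channels, 3,
            padding=1, groups=feedforward_channels)
        self.act = nn.GELU()
        self.fc2 = nn.Conv2d(feedforward_channels, in_channels, 1)
        self.drop = nn.Dropout(drop_rate)

    def forward(self, x):  # x: [N, C, H, W]
        x = self.drop(self.act(self.pe_conv(self.fc1(x))))
        return self.drop(self.fc2(x))
```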

return x


class Mlp(BaseModule):
A collaborator commented:

Note that variable and class names need to be consistent with the paper.
Rename Mlp to MixFFN.

return x


class Attention(BaseModule):
A collaborator commented:

Rename Attention to EfficientMultiheadAttention.
We can also inherit from MultiheadAttention in MMCV and pass different query, key, and value in the forward phase according to the spatial reduction.
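A rough sketch of that suggestion, assuming MultiheadAttention from mmcv.cnn.bricks.transformer accepts batch_first and keyword query/key/value in forward (the reduction conv and norm names below are illustrative): the subclass only adds the spatial-reduction step and shrinks key/value before delegating to the parent.

```python
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import MultiheadAttention


class EfficientMultiheadAttention(MultiheadAttention):
    """Sketch: multi-head attention with spatial reduction of key/value."""

    def __init__(self, embed_dims, num_heads, sr_ratio=1,
                 norm_cfg=dict(type='LN'), **kwargs):
        super().__init__(embed_dims, num_heads, batch_first=True, **kwargs)
        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            # strided conv downsamples the key/value tokens
            self.sr = nn.Conv2d(embed_dims, embed_dims,
                                kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]

    def forward(self, x, hw_shape):  # x: [N, L, C]
        H, W = hw_shape
        kv = x
        if self.sr_ratio > 1:
            kv = x.transpose(1, 2).reshape(x.size(0), -1, H, W)  # NLC -> NCHW
            kv = self.sr(kv).flatten(2).transpose(1, 2)          # NCHW -> NLC
            kv = self.norm(kv)
        # pass distinct query and key/value to the inherited attention
        return super().forward(query=x, key=kv, value=kv, identity=x)
```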

Comment on lines 89 to 96
self.sr_ratio = sr_ratio
if sr_ratio > 1:
    self.sr = ConvModule(
        in_channels=dim,
        out_channels=dim,
        kernel_size=sr_ratio,
        stride=sr_ratio)
    _, self.norm = build_norm_layer(norm_cfg, dim)
A collaborator commented:

If we inherit from MultiheadAttention in MMCV, we only need to add these lines.

return x


class Block(BaseModule):
A collaborator commented:

Rename it.

proj_drop=drop_rate,
sr_ratio=sr_ratio)
# NOTE: drop path for stochastic depth, we shall see if this is better
# than dropout here
A collaborator commented:

Is this comment necessary?

Comment on lines 180 to 187
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)

self.img_size = img_size
self.patch_size = patch_size
num_rows, num_cols = img_size[0] // patch_size[0], img_size[
1] // patch_size[1]
self.num_patches = num_rows * num_cols
A collaborator commented:

Are these lines necessary?

num_rows, num_cols = img_size[0] // patch_size[0], img_size[
1] // patch_size[1]
self.num_patches = num_rows * num_cols
self.proj = nn.Conv2d(
A collaborator commented:

Use ConvModule
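For reference, a minimal form of that change (channel and kernel values are placeholders); with norm_cfg=None and act_cfg=None, ConvModule behaves like a plain conv while keeping the MMCV-style config interface.

```python
from mmcv.cnn import ConvModule

# Drop-in replacement for nn.Conv2d in the patch projection; values are
# illustrative for the first stage of an overlapping patch embedding.
proj = ConvModule(
    in_channels=3,
    out_channels=64,
    kernel_size=7,
    stride=4,
    padding=3,
    norm_cfg=None,
    act_cfg=None)
```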

self.pretrained = pretrained
self.depths = depths
# patch_embed
self.patch_embed1 = OverlapPatchEmbed(
A collaborator commented:

Use ModuleList
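A small sketch of that refactor (per-stage values are placeholders, and plain convs stand in for the overlapping patch embedding): the four patch_embedN attributes become one nn.ModuleList, which forward can simply iterate over.

```python
import torch.nn as nn

embed_dims = [64, 128, 320, 512]   # hypothetical per-stage settings
patch_sizes = [7, 3, 3, 3]
strides = [4, 2, 2, 2]

patch_embeds = nn.ModuleList([
    nn.Conv2d(
        3 if i == 0 else embed_dims[i - 1],
        embed_dims[i],
        kernel_size=patch_sizes[i],
        stride=strides[i],
        padding=patch_sizes[i] // 2) for i in range(4)
])
```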

strict=False,
logger=logger)

def reset_drop_path(self, drop_path_rate):
A collaborator commented:

Where is this used?

sennnnn added 2 commits July 10, 2021 01:36
@@ -0,0 +1,10 @@
def nlc_to_nchw(tensor, H, W):
A collaborator commented:

Docstring.
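A possible docstring'd version of the two shape helpers (a sketch; it follows the hw_shape signature suggested later in the review rather than the original H, W arguments):

```python
def nlc_to_nchw(x, hw_shape):
    """Convert a [N, L, C] tensor to [N, C, H, W].

    Args:
        x (Tensor): The input tensor of shape [N, L, C].
        hw_shape (Sequence[int]): The height and width of the feature map.

    Returns:
        Tensor: The output tensor of shape [N, C, H, W].
    """
    H, W = hw_shape
    assert x.dim() == 3 and x.shape[1] == H * W, 'L does not match H * W'
    return x.transpose(1, 2).reshape(x.shape[0], -1, H, W).contiguous()


def nchw_to_nlc(x):
    """Flatten a [N, C, H, W] tensor to [N, L, C].

    Args:
        x (Tensor): The input tensor of shape [N, C, H, W].

    Returns:
        Tensor: The output tensor of shape [N, L, C], where L = H * W.
    """
    assert x.dim() == 4, 'input tensor must be 4-dimensional'
    return x.flatten(2).transpose(1, 2).contiguous()
```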

Comment on lines 305 to 343
A PyTorch implement of : `SegFormer: Simple and Efficient Design for
Semantic Segmentation with Transformers` -
https://arxiv.org/pdf/2105.15203.pdf

in_channels (int): Number of input channels. Default: 3.
embed_dims (int): Embedding dimension. Default: 768.
num_stags (int): The num of stages. Default: 4.
num_layers (list[int]): The layer number of each transformer encode
layer. Default: [3, 4, 6, 3].
num_heads (list[int]): The attention heads of each transformer
encode layer. Default: [1, 2, 4, 8].
patch_sizes (list[int]): The patch_size of each overlapped patch embedding.
Default: [7, 3, 3, 3].
strides (list[int]): The stride of each overlapped patch embedding.
Default: [4, 2, 2, 2].
sr_ratios (list[int]): The spatial reduction rate of each transformer
encode layer. Default: [8, 4, 2, 1].
out_indices (list[int] | tuple[int] | int): Output from which stages.
Default: (0, 1, 2, 3).
mlp_ratio (int): ratio of mlp hidden dim to embedding dim.
Default: 4.
out_indices (list | tuple | int): Output from which stages.
Default: -1.
qkv_bias (bool): Enable bias for qkv if True. Default: True.
drop_rate (float): Probability of an element to be zeroed.
Default 0.0
attn_drop_rate (float): The drop out rate for attention layer.
Default 0.0
drop_path_rate (float): stochastic depth rate. Default 0.0
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='LN')
act_cfg (dict): The activation config for FFNs.
Defalut: dict(type='GELU').
pretrain_style (str): Choose to use official or mmcls pretrain weights.
Default: official.
pretrained (str, optional): model pretrained path. Default: None.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
"""
A collaborator commented:

The format of the docstring is not correct.
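For reference, a possible corrected opening in the usual Google/MMSegmentation style (only the first few entries are shown; the num_stags/Defalut typos and the duplicated out_indices entry would also be fixed):

```python
from mmcv.runner import BaseModule


class MixVisionTransformer(BaseModule):
    """The backbone of SegFormer.

    A PyTorch implementation of `SegFormer: Simple and Efficient Design
    for Semantic Segmentation with Transformers
    <https://arxiv.org/pdf/2105.15203.pdf>`_.

    Args:
        in_channels (int): Number of input channels. Default: 3.
        embed_dims (int): Embedding dimension. Default: 768.
        num_stages (int): Number of stages. Default: 4.
        num_layers (list[int]): Number of encode layers in each stage.
            Default: [3, 4, 6, 3].
        ...
    """
```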

dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg)

def forward(self, x, H, W):
A collaborator commented:

Suggested change
def forward(self, x, H, W):
def forward(self, x, hw_shape):

Comment on lines 99 to 105
conv1x1 = partial(
ConvModule,
kernel_size=1,
stride=1,
bias=True,
norm_cfg=None,
act_cfg=None)
A collaborator commented:

No need to use partial.

from ..utils import PatchEmbed, mit_convert, nchw_to_nlc, nlc_to_nchw


class PEConv(BaseModule):
@xvjiarui (Collaborator) commented on Jul 14, 2021:

Why do we need to use PEConv to wrap conv?

ffn_drop=0.,
pe_index=1,
dropout_layer=None,
add_identity=True,
@xvjiarui (Collaborator) commented on Jul 14, 2021:

Is add_identity argument necessary?

Comment on lines 109 to 134
# first position of MixFFN
if pe_index == 0:
    layers.append(PEConv(in_channels))
for idx in range(num_fcs - 1):
    container = []
    container.append(
        conv1x1(
            in_channels=in_channels,
            out_channels=feedforward_channels))
    # middle position of MixFFN
    if pe_index == idx + 1:
        container.append(PEConv(feedforward_channels))
    container.append(self.activate)
    container.append(nn.Dropout(ffn_drop))
    layers.append(Sequential(*container))
layers.append(
    conv1x1(
        in_channels=feedforward_channels, out_channels=in_channels))
# Last position of MixFFN
if pe_index == num_fcs:
    layers.append(PEConv(feedforward_channels))
layers.append(nn.Dropout(ffn_drop))
self.layers = Sequential(*layers)
self.dropout_layer = build_dropout(
    dropout_layer) if dropout_layer else torch.nn.Identity()
self.add_identity = add_identity
A collaborator commented:

This part is too long.
Consider simplifying it.

num_fcs=2,
act_cfg=dict(type='GELU'),
ffn_drop=0.,
pe_index=1,
A collaborator commented:

If it is too complicated, we may remove pe_index.

A collaborator commented:

Just fix num_fcs to 2, and insert the PE in the middle.
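A sketch of that simplification (the values and the nn.Sequential form are illustrative, not the merged code): with num_fcs fixed at 2 and the positional-encoding conv always in the middle, the pe_index branching collapses into a single sequential block.

```python
import torch.nn as nn
from mmcv.cnn import ConvModule, build_activation_layer

in_channels, feedforward_channels = 64, 256   # placeholder values
ffn_drop, act_cfg = 0., dict(type='GELU')

layers = nn.Sequential(
    ConvModule(in_channels, feedforward_channels, kernel_size=1,
               norm_cfg=None, act_cfg=None),
    # 3x3 depthwise conv: the PE, fixed in the middle of the FFN
    ConvModule(feedforward_channels, feedforward_channels, kernel_size=3,
               padding=1, groups=feedforward_channels,
               norm_cfg=None, act_cfg=None),
    build_activation_layer(act_cfg),
    nn.Dropout(ffn_drop),
    ConvModule(feedforward_channels, in_channels, kernel_size=1,
               norm_cfg=None, act_cfg=None),
    nn.Dropout(ffn_drop))
```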

Comment on lines 6 to 7
hw_shape (Sequence[int]): The height and width of output feature map.
"""
A collaborator commented:

Add a Returns section to the docstring.

"""Flatten [N, C, H, W] shape tensor to [N, L, C] shape tensor.

Args:
x (Tensor): The input tensor for conversion.
A collaborator commented:

Add a Returns section to the docstring.

(0, self.patch_size[1] - W % self.patch_size[1], 0, 0))

# TODO: Process overlapping op
if not self.overlapping:
A collaborator commented:

The name overlapping is not precise for this.
We may consider something like auto_pad or pad_to_patch_size.

We also need to make it an argument and add a docstring for it.
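A sketch of the suggested padding behaviour (the free-function form and names are assumptions; in the module it would be guarded by a documented pad_to_patch_size argument): pad on the right/bottom so H and W become multiples of the patch size before projection.

```python
import torch.nn.functional as F


def pad_to_patch_size(x, patch_size):
    """Pad a [N, C, H, W] tensor so H and W are multiples of patch_size."""
    ph, pw = patch_size
    H, W = x.shape[2], x.shape[3]
    if H % ph != 0:
        # F.pad on a 4-D tensor pads (left, right, top, bottom)
        x = F.pad(x, (0, 0, 0, ph - H % ph))
    if W % pw != 0:
        x = F.pad(x, (0, pw - W % pw, 0, 0))
    return x
```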

Comment on lines 22 to 23
pad_to_patch_size (bool, optional): Whether to pad feature map shape
to multiple patch size. Default: False.
A collaborator commented:

We may make it True by default.

@xvjiarui merged commit 4d34581 into open-mmlab:master on Jul 19, 2021
bowenroom pushed a commit to bowenroom/mmsegmentation that referenced this pull request Feb 25, 2022
* [Feature]Segformer re-implementation

* Using act_cfg and norm_cfg to control activation and normalization

* Split this PR into several little PRs

* Fix lint error

* Remove SegFormerHead

* parameters init refactor

* 1. Refactor segformer backbone parameters init;

2. Remove redundant functions and unit tests;

* Remove redundant code

* 1. Remove redundant code;

2. Modify module name;

* Refactor the backbone of segformer using mmcv.cnn.bricks.transformer.py

* Fix some code logic bugs.

* Add mit_convert.py to match pretrain keys of segformer.

* Resolve some comments.

* 1. Add some assert to ensure right params;

2. Support flexible peconv position;

* Add pe_index assert and fix unit test.

* 1. Add doc string for MixVisionTransformer;

2. Add some unit tests for MixVisionTransformer;

* Use hw_shape to pass shape of feature map.

* 1. Fix doc string of MixVisionTransformer;

2. Simplify MixFFN;

3. Modify H, W to hw_shape;

* Add more unit tests.

* Add doc string for shape conversion functions.

* Add some unit tests to improve code coverage.

* Fix Segformer backbone pretrain weights match bug.

* Resolve the shape conversion functions' doc string.

* Add pad_to_patch_size arg.

* Modify default value of pad_to_patch_size arg.
aravind-h-v pushed a commit to aravind-h-v/mmsegmentation that referenced this pull request Mar 27, 2023
* [Flax] Fix unet and ddim scheduler

* correct

* finish
sibozhang pushed a commit to sibozhang/mmsegmentation that referenced this pull request Mar 22, 2024
* Update README_cn.md

* Update README_cn.md