forked from open-mmlab/mmsegmentation
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Add segformer decode head and related train config (open-mm…
…lab#599) * [Feature]Segformer re-implementation * Using act_cfg and norm_cfg to control activation and normalization * Split this PR into several little PRs * Fix lint error * Remove SegFormerHead * [Feature] Add segformer decode head and related train config * Add ade20K trainval support for segformer 1. Add related train and val configs; 2. Add AlignedResize; * Set arg: find_unused_parameters = True * parameters init refactor * 1. Refactor segformer backbone parameters init; 2. Remove rebundant functions and unit tests; * Remove rebundant codes * Replace Linear Layer to 1X1 Conv * Use nn.ModuleList to refactor segformer head. * Remove local to_xtuple * 1. Remove rebundant codes; 2. Modify module name; * Refactor the backbone of segformer using mmcv.cnn.bricks.transformer.py * Fix some code logic bugs. * Add mit_convert.py to match pretrain keys of segformer. * Resolve some comments. * 1. Add some assert to ensure right params; 2. Support flexible peconv position; * Add pe_index assert and fix unit test. * 1. Add doc string for MixVisionTransformer; 2. Add some unit tests for MixVisionTransformer; * Use hw_shape to pass shape of feature map. * 1. Fix doc string of MixVisionTransformer; 2. Simplify MixFFN; 3. Modify H, W to hw_shape; * Add more unit tests. * Add doc string for shape convertion functions. * Add some unit tests to improve code coverage. * Fix Segformer backbone pretrain weights match bug. * Modify configs of segformer. * resolve the shape convertion functions doc string. * Add pad_to_patch_size arg. * Support progressive test with fewer memory cost. * Modify default value of pad_to_patch_size arg. * Temp code * Using processor to refactor evaluation workflow. * refactor eval hook. * Fix process bar. * Fix middle save argument. * Modify some variable name of dataset evaluate api. * Modify some viriable name of eval hook. * Fix some priority bugs of eval hook. * Fix some bugs about model loading and eval hook. * Add ade20k 640x640 dataset. * Fix related segformer configs. * Depreciated efficient_test. * Fix training progress blocked by eval hook. * Depreciated old test api. * Modify error patch size. * Fix pretrain of mit_b0 * Fix the test api error. * Modify dataset base config. * Fix test api error. * Modify outer api. * Build a sampler test api. * TODO: Refactor format_results. * Modify variable names. * Fix num_classes bug. * Fix sampler index bug. * Fix grammaly bug. * Add part of benchmark results. * Support batch sampler. * More readable test api. * Remove some command arg and fix eval hook bug. * Support format-only arg. * Modify format_results of datasets. * Modify tool which use test apis. * Update readme. * Update readme of segformer. * Updata readme of segformer. * Update segformer readme and fix segformer mit_b4. * Update readme of segformer. * Clean AlignedResize related config. * Clean code from pr open-mmlab#709 * Clean code from pr open-mmlab#709 * Add 512x512 segformer_mit-b5. * Fix lint. * Fix some segformer head bugs. * Add segformer unit tests. * Replace AlignedResize to ResizeToMultiple. * Modify readme of segformer. * Fix bug of ResizeToMultiple. * Add ResizeToMultiple unit tests. * Resolve conflict. * Simplify the implementation of ResizeToMultiple. * Update test results. * Fix multi-scale test error when resize_ratio=1.75 and input size=640x640. * Update segformer results. * Update Segformer results. * Fix some url bugs and pipelines bug. * Move ckpt convertion to tools. * Add segformer official pretrain weights usage. * Clean redundant codes. * Remove redundant codes. * Unfied format. * Add description for segformer converter. * Update workers.
- Loading branch information
Showing
18 changed files
with
494 additions
and
62 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# model settings | ||
norm_cfg = dict(type='SyncBN', requires_grad=True) | ||
model = dict( | ||
type='EncoderDecoder', | ||
pretrained=None, | ||
backbone=dict( | ||
type='MixVisionTransformer', | ||
in_channels=3, | ||
embed_dims=32, | ||
num_stages=4, | ||
num_layers=[2, 2, 2, 2], | ||
num_heads=[1, 2, 5, 8], | ||
patch_sizes=[7, 3, 3, 3], | ||
sr_ratios=[8, 4, 2, 1], | ||
out_indices=(0, 1, 2, 3), | ||
mlp_ratio=4, | ||
qkv_bias=True, | ||
drop_rate=0.0, | ||
attn_drop_rate=0.0, | ||
drop_path_rate=0.1), | ||
decode_head=dict( | ||
type='SegformerHead', | ||
in_channels=[32, 64, 160, 256], | ||
in_index=[0, 1, 2, 3], | ||
channels=256, | ||
dropout_ratio=0.1, | ||
num_classes=19, | ||
norm_cfg=norm_cfg, | ||
align_corners=False, | ||
loss_decode=dict( | ||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), | ||
# model training and testing settings | ||
train_cfg=dict(), | ||
test_cfg=dict(mode='whole')) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers | ||
|
||
## Introduction | ||
|
||
<!-- [ALGORITHM] --> | ||
|
||
```latex | ||
@article{xie2021segformer, | ||
title={SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers}, | ||
author={Xie, Enze and Wang, Wenhai and Yu, Zhiding and Anandkumar, Anima and Alvarez, Jose M and Luo, Ping}, | ||
journal={arXiv preprint arXiv:2105.15203}, | ||
year={2021} | ||
} | ||
``` | ||
|
||
## Results and models | ||
|
||
### ADE20k | ||
|
||
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download | | ||
| ------ | -------- | --------- | ------: | -------: | -------------- | ---: | ------------- | ------ | -------- | | ||
|Segformer | MIT-B0 | 512x512 | 160000 | 2.1 | 51.32 | 37.41 | 38.34 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b0_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b0_512x512_160k_ade20k/segformer_mit-b0_512x512_160k_ade20k_20210726_101530-8ffa8fda.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b0_512x512_160k_ade20k/segformer_mit-b0_512x512_160k_ade20k_20210726_101530.log.json) | | ||
|Segformer | MIT-B1 | 512x512 | 160000 | 2.6 | 47.66 | 40.97 | 42.54 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b1_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b1_512x512_160k_ade20k/segformer_mit-b1_512x512_160k_ade20k_20210726_112106-d70e859d.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b1_512x512_160k_ade20k/segformer_mit-b1_512x512_160k_ade20k_20210726_112106.log.json) | | ||
|Segformer | MIT-B2 | 512x512 | 160000 | 3.6 | 30.88 | 45.58 | 47.03 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b2_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b2_512x512_160k_ade20k/segformer_mit-b2_512x512_160k_ade20k_20210726_112103-cbd414ac.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b2_512x512_160k_ade20k/segformer_mit-b2_512x512_160k_ade20k_20210726_112103.log.json) | | ||
|Segformer | MIT-B3 | 512x512 | 160000 | 4.8 | 22.11 | 47.82 | 48.81 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b3_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b3_512x512_160k_ade20k/segformer_mit-b3_512x512_160k_ade20k_20210726_081410-962b98d2.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b3_512x512_160k_ade20k/segformer_mit-b3_512x512_160k_ade20k_20210726_081410.log.json) | | ||
|Segformer | MIT-B4 | 512x512 | 160000 | 6.1 | 15.45 | 48.46 | 49.76 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b4_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b4_512x512_160k_ade20k/segformer_mit-b4_512x512_160k_ade20k_20210728_183055-7f509d7d.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b4_512x512_160k_ade20k/segformer_mit-b4_512x512_160k_ade20k_20210728_183055.log.json) | | ||
|Segformer | MIT-B5 | 512x512 | 160000 | 7.2 | 11.89 | 49.13 | 50.22 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b5_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_512x512_160k_ade20k/segformer_mit-b5_512x512_160k_ade20k_20210726_145235-94cedf59.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_512x512_160k_ade20k/segformer_mit-b5_512x512_160k_ade20k_20210726_145235.log.json) | | ||
|Segformer | MIT-B5 | 640x640 | 160000 | 11.5 | 11.30 | 49.62 | 50.36 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b5_640x640_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_640x640_160k_ade20k/segformer_mit-b5_640x640_160k_ade20k_20210801_121243-41d2845b.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_640x640_160k_ade20k/segformer_mit-b5_640x640_160k_ade20k_20210801_121243.log.json) | | ||
|
||
Evaluation with AlignedResize: | ||
|
||
| Method | Backbone | Crop Size | Lr schd | mIoU | mIoU(ms+flip) | | ||
| ------ | -------- | --------- | ------: | ---: | ------------- | | ||
|Segformer | MIT-B0 | 512x512 | 160000 | 38.1 | 38.57 | | ||
|Segformer | MIT-B1 | 512x512 | 160000 | 41.64 | 42.76 | | ||
|Segformer | MIT-B2 | 512x512 | 160000 | 46.53 | 47.49 | | ||
|Segformer | MIT-B3 | 512x512 | 160000 | 48.46 | 49.14 | | ||
|Segformer | MIT-B4 | 512x512 | 160000 | 49.34 | 50.29 | | ||
|Segformer | MIT-B5 | 512x512 | 160000 | 50.08 | 50.72 | | ||
|Segformer | MIT-B5 | 640x640 | 160000 | 50.58 | 50.8 | | ||
|
||
We replace `AlignedResize` in original implementatiuon to `Resize + ResizeToMultiple`. If you want to test by | ||
using `AlignedResize`, you can change the dataset pipeline like this: | ||
|
||
```python | ||
test_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict( | ||
type='MultiScaleFlipAug', | ||
img_scale=(2048, 512), | ||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], | ||
flip=False, | ||
transforms=[ | ||
dict(type='Resize', keep_ratio=True), | ||
# resize image to multiple of 32, improve SegFormer by 0.5-1.0 mIoU. | ||
dict(type='ResizeToMultiple', size_divisor=32), | ||
dict(type='RandomFlip'), | ||
dict(type='Normalize', **img_norm_cfg), | ||
dict(type='ImageToTensor', keys=['img']), | ||
dict(type='Collect', keys=['img']), | ||
]) | ||
] | ||
``` | ||
|
||
## How to use segformer official pretrain weights | ||
|
||
We convert the backbone weights from the official repo (https://github.com/NVlabs/SegFormer) with `tools/model_converters/mit_convert.py`. | ||
|
||
You may follow below steps to start segformer training preparation: | ||
|
||
1. Download segformer pretrain weights (Suggest put in `pretrain/`); | ||
2. Run convert script to convert official pretrain weights: `python tools/model_converters/mit_convert.py pretrain/mit_b0.pth pretrain/mit_b0.pth`; | ||
3. Modify `pretrained` of segformer model config, for example, `pretrained` of `segformer_mit-b0_512x512_160k_ade20k.py` is set to `pretrain/mit_b0.pth`; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
_base_ = [ | ||
'../_base_/models/segformer_mit-b0.py', '../_base_/datasets/ade20k.py', | ||
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' | ||
] | ||
|
||
model = dict( | ||
pretrained='pretrain/mit_b0.pth', decode_head=dict(num_classes=150)) | ||
|
||
# optimizer | ||
optimizer = dict( | ||
_delete_=True, | ||
type='AdamW', | ||
lr=0.00006, | ||
betas=(0.9, 0.999), | ||
weight_decay=0.01, | ||
paramwise_cfg=dict( | ||
custom_keys={ | ||
'pos_block': dict(decay_mult=0.), | ||
'norm': dict(decay_mult=0.), | ||
'head': dict(lr_mult=10.) | ||
})) | ||
|
||
lr_config = dict( | ||
_delete_=True, | ||
policy='poly', | ||
warmup='linear', | ||
warmup_iters=1500, | ||
warmup_ratio=1e-6, | ||
power=1.0, | ||
min_lr=0.0, | ||
by_epoch=False) | ||
|
||
data = dict(samples_per_gpu=2, workers_per_gpu=2) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
_base_ = ['./segformer_mit-b0_512x512_160k_ade20k.py'] | ||
|
||
# model settings | ||
model = dict( | ||
pretrained='pretrain/mit_b1.pth', | ||
backbone=dict( | ||
embed_dims=64, num_heads=[1, 2, 5, 8], num_layers=[2, 2, 2, 2]), | ||
decode_head=dict(in_channels=[64, 128, 320, 512])) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
_base_ = ['./segformer_mit-b0_512x512_160k_ade20k.py'] | ||
|
||
# model settings | ||
model = dict( | ||
pretrained='pretrain/mit_b2.pth', | ||
backbone=dict( | ||
embed_dims=64, num_heads=[1, 2, 5, 8], num_layers=[3, 4, 6, 3]), | ||
decode_head=dict(in_channels=[64, 128, 320, 512])) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
_base_ = ['./segformer_mit-b0_512x512_160k_ade20k.py'] | ||
|
||
# model settings | ||
model = dict( | ||
pretrained='pretrain/mit_b3.pth', | ||
backbone=dict( | ||
embed_dims=64, num_heads=[1, 2, 5, 8], num_layers=[3, 4, 18, 3]), | ||
decode_head=dict(in_channels=[64, 128, 320, 512])) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
_base_ = ['./segformer_mit-b0_512x512_160k_ade20k.py'] | ||
|
||
# model settings | ||
model = dict( | ||
pretrained='pretrain/mit_b4.pth', | ||
backbone=dict( | ||
embed_dims=64, num_heads=[1, 2, 5, 8], num_layers=[3, 8, 27, 3]), | ||
decode_head=dict(in_channels=[64, 128, 320, 512])) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
_base_ = ['./segformer_mit-b0_512x512_160k_ade20k.py'] | ||
|
||
# model settings | ||
model = dict( | ||
pretrained='pretrain/mit_b5.pth', | ||
backbone=dict( | ||
embed_dims=64, num_heads=[1, 2, 5, 8], num_layers=[3, 6, 40, 3]), | ||
decode_head=dict(in_channels=[64, 128, 320, 512])) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
_base_ = ['./segformer_mit-b0_512x512_160k_ade20k.py'] | ||
|
||
# dataset settings | ||
img_norm_cfg = dict( | ||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) | ||
crop_size = (640, 640) | ||
train_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict(type='LoadAnnotations', reduce_zero_label=True), | ||
dict(type='Resize', img_scale=(2048, 640), ratio_range=(0.5, 2.0)), | ||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), | ||
dict(type='RandomFlip', prob=0.5), | ||
dict(type='PhotoMetricDistortion'), | ||
dict(type='Normalize', **img_norm_cfg), | ||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), | ||
dict(type='DefaultFormatBundle'), | ||
dict(type='Collect', keys=['img', 'gt_semantic_seg']), | ||
] | ||
test_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict( | ||
type='MultiScaleFlipAug', | ||
img_scale=(2048, 640), | ||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], | ||
flip=False, | ||
transforms=[ | ||
dict(type='Resize', keep_ratio=True), | ||
dict(type='RandomFlip'), | ||
dict(type='Normalize', **img_norm_cfg), | ||
dict(type='ImageToTensor', keys=['img']), | ||
dict(type='Collect', keys=['img']), | ||
]) | ||
] | ||
data = dict( | ||
train=dict(pipeline=train_pipeline), | ||
val=dict(pipeline=test_pipeline), | ||
test=dict(pipeline=test_pipeline)) | ||
|
||
# model settings | ||
model = dict( | ||
pretrained='pretrain/mit_b5.pth', | ||
backbone=dict( | ||
embed_dims=64, num_heads=[1, 2, 5, 8], num_layers=[3, 6, 40, 3]), | ||
decode_head=dict(in_channels=[64, 128, 320, 512])) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.