From b4fd32d049fb910cd6765f4ae35fc3992ef2e178 Mon Sep 17 00:00:00 2001
From: sennnnn <58427300+sennnnn@users.noreply.github.com>
Date: Fri, 13 Aug 2021 13:31:19 +0800
Subject: [PATCH] [Feature] Add segformer decode head and related train config (#599)

* [Feature] Segformer re-implementation
* Using act_cfg and norm_cfg to control activation and normalization
* Split this PR into several little PRs
* Fix lint error
* Remove SegFormerHead
* [Feature] Add segformer decode head and related train config
* Add ADE20K trainval support for segformer
  1. Add related train and val configs;
  2. Add AlignedResize;
* Set arg: find_unused_parameters = True
* Parameters init refactor
* 1. Refactor segformer backbone parameters init;
  2. Remove redundant functions and unit tests;
* Remove redundant code
* Replace Linear layer with 1x1 Conv
* Use nn.ModuleList to refactor segformer head.
* Remove local to_xtuple
* 1. Remove redundant code;
  2. Modify module name;
* Refactor the backbone of segformer using mmcv.cnn.bricks.transformer.py
* Fix some code logic bugs.
* Add mit_convert.py to match pretrain keys of segformer.
* Resolve some comments.
* 1. Add some asserts to ensure right params;
  2. Support flexible PE conv position;
* Add pe_index assert and fix unit test.
* 1. Add doc string for MixVisionTransformer;
  2. Add some unit tests for MixVisionTransformer;
* Use hw_shape to pass shape of feature map.
* 1. Fix doc string of MixVisionTransformer;
  2. Simplify MixFFN;
  3. Modify H, W to hw_shape;
* Add more unit tests.
* Add doc string for shape conversion functions.
* Add some unit tests to improve code coverage.
* Fix Segformer backbone pretrain weights match bug.
* Modify configs of segformer.
* Resolve the shape conversion functions doc string.
* Add pad_to_patch_size arg.
* Support progressive test with less memory cost.
* Modify default value of pad_to_patch_size arg.
* Temp code
* Using processor to refactor evaluation workflow.
* Refactor eval hook.
* Fix progress bar.
* Fix middle save argument.
* Modify some variable names of dataset evaluate api.
* Modify some variable names of eval hook.
* Fix some priority bugs of eval hook.
* Fix some bugs about model loading and eval hook.
* Add ade20k 640x640 dataset.
* Fix related segformer configs.
* Deprecated efficient_test.
* Fix training progress blocked by eval hook.
* Deprecated old test api.
* Fix wrong patch size.
* Fix pretrain of mit_b0
* Fix the test api error.
* Modify dataset base config.
* Fix test api error.
* Modify outer api.
* Build a sampler test api.
* TODO: Refactor format_results.
* Modify variable names.
* Fix num_classes bug.
* Fix sampler index bug.
* Fix grammar bug.
* Add part of benchmark results.
* Support batch sampler.
* More readable test api.
* Remove some command args and fix eval hook bug.
* Support format-only arg.
* Modify format_results of datasets.
* Modify tools which use test apis.
* Update readme.
* Update readme of segformer.
* Update readme of segformer.
* Update segformer readme and fix segformer mit_b4.
* Update readme of segformer.
* Clean AlignedResize related config.
* Clean code from pr #709
* Clean code from pr #709
* Add 512x512 segformer_mit-b5.
* Fix lint.
* Fix some segformer head bugs.
* Add segformer unit tests.
* Replace AlignedResize with ResizeToMultiple.
* Modify readme of segformer.
* Fix bug of ResizeToMultiple.
* Add ResizeToMultiple unit tests.
* Resolve conflict.
* Simplify the implementation of ResizeToMultiple.
* Update test results.
* Fix multi-scale test error when resize_ratio=1.75 and input size=640x640.
* Update segformer results.
* Update Segformer results.
* Fix some URL bugs and pipeline bugs.
* Move ckpt conversion to tools.
* Add segformer official pretrain weights usage.
* Clean redundant code.
* Remove redundant code.
* Unified format.
* Add description for segformer converter.
* Update workers.
---
 configs/_base_/models/segformer_mit-b0.py | 34 +++++++++
 configs/segformer/readme.md | 73 ++++++++++++++++++
 .../segformer_mit-b0_512x512_160k_ade20k.py | 33 ++++++++
 .../segformer_mit-b1_512x512_160k_ade20k.py | 8 ++
 .../segformer_mit-b2_512x512_160k_ade20k.py | 8 ++
 .../segformer_mit-b3_512x512_160k_ade20k.py | 8 ++
 .../segformer_mit-b4_512x512_160k_ade20k.py | 8 ++
 .../segformer_mit-b5_512x512_160k_ade20k.py | 8 ++
 .../segformer_mit-b5_640x640_160k_ade20k.py | 44 +++++++++++
 mmseg/datasets/pipelines/transforms.py | 57 ++++++++++++++
 mmseg/models/backbones/mit.py | 18 ++---
 mmseg/models/decode_heads/__init__.py | 4 +-
 mmseg/models/decode_heads/segformer_head.py | 65 ++++++++++++++++
 mmseg/models/utils/__init__.py | 4 +-
 mmseg/models/utils/ckpt_convert.py | 49 ------------
 tests/test_data/test_transform.py | 20 +++++
 .../test_heads/test_segformer_head.py | 39 ++++++++++
 tools/model_converters/mit_convert.py | 76 +++++++++++++++++++
 18 files changed, 494 insertions(+), 62 deletions(-)
 create mode 100644 configs/_base_/models/segformer_mit-b0.py
 create mode 100644 configs/segformer/readme.md
 create mode 100644 configs/segformer/segformer_mit-b0_512x512_160k_ade20k.py
 create mode 100644 configs/segformer/segformer_mit-b1_512x512_160k_ade20k.py
 create mode 100644 configs/segformer/segformer_mit-b2_512x512_160k_ade20k.py
 create mode 100644 configs/segformer/segformer_mit-b3_512x512_160k_ade20k.py
 create mode 100644 configs/segformer/segformer_mit-b4_512x512_160k_ade20k.py
 create mode 100644 configs/segformer/segformer_mit-b5_512x512_160k_ade20k.py
 create mode 100644 configs/segformer/segformer_mit-b5_640x640_160k_ade20k.py
 create mode 100644 mmseg/models/decode_heads/segformer_head.py
 create mode 100644 tests/test_models/test_heads/test_segformer_head.py
 create mode 100644 tools/model_converters/mit_convert.py

diff --git a/configs/_base_/models/segformer_mit-b0.py b/configs/_base_/models/segformer_mit-b0.py
new file mode 100644
index 0000000000..5b3e07331d
--- /dev/null
+++ b/configs/_base_/models/segformer_mit-b0.py
@@ -0,0 +1,34 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained=None,
+    backbone=dict(
+        type='MixVisionTransformer',
+        in_channels=3,
+        embed_dims=32,
+        num_stages=4,
+        num_layers=[2, 2, 2, 2],
+        num_heads=[1, 2, 5, 8],
+        patch_sizes=[7, 3, 3, 3],
+        sr_ratios=[8, 4, 2, 1],
+        out_indices=(0, 1, 2, 3),
+        mlp_ratio=4,
+        qkv_bias=True,
+        drop_rate=0.0,
+        attn_drop_rate=0.0,
+        drop_path_rate=0.1),
+    decode_head=dict(
+        type='SegformerHead',
+        in_channels=[32, 64, 160, 256],
+        in_index=[0, 1, 2, 3],
+        channels=256,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    # model training and testing settings
+    train_cfg=dict(),
+    test_cfg=dict(mode='whole'))
diff --git a/configs/segformer/readme.md b/configs/segformer/readme.md
new file mode 100644
index 0000000000..cf2fece512
--- /dev/null
+++ b/configs/segformer/readme.md
@@ -0,0 +1,73 @@
+# SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers
+
+## Introduction
+
+
+```latex
+@article{xie2021segformer,
+  title={SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers},
+  author={Xie, Enze and Wang, Wenhai and Yu, Zhiding and Anandkumar, Anima and Alvarez, Jose M and Luo, Ping},
+  journal={arXiv preprint arXiv:2105.15203},
+  year={2021}
+}
+```
+
+## Results and models
+
+### ADE20K
+
+| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
+| ------ | -------- | --------- | ------: | -------: | -------------- | ---: | ------------- | ------ | -------- |
+| Segformer | MIT-B0 | 512x512 | 160000 | 2.1 | 51.32 | 37.41 | 38.34 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b0_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b0_512x512_160k_ade20k/segformer_mit-b0_512x512_160k_ade20k_20210726_101530-8ffa8fda.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b0_512x512_160k_ade20k/segformer_mit-b0_512x512_160k_ade20k_20210726_101530.log.json) |
+| Segformer | MIT-B1 | 512x512 | 160000 | 2.6 | 47.66 | 40.97 | 42.54 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b1_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b1_512x512_160k_ade20k/segformer_mit-b1_512x512_160k_ade20k_20210726_112106-d70e859d.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b1_512x512_160k_ade20k/segformer_mit-b1_512x512_160k_ade20k_20210726_112106.log.json) |
+| Segformer | MIT-B2 | 512x512 | 160000 | 3.6 | 30.88 | 45.58 | 47.03 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b2_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b2_512x512_160k_ade20k/segformer_mit-b2_512x512_160k_ade20k_20210726_112103-cbd414ac.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b2_512x512_160k_ade20k/segformer_mit-b2_512x512_160k_ade20k_20210726_112103.log.json) |
+| Segformer | MIT-B3 | 512x512 | 160000 | 4.8 | 22.11 | 47.82 | 48.81 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b3_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b3_512x512_160k_ade20k/segformer_mit-b3_512x512_160k_ade20k_20210726_081410-962b98d2.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b3_512x512_160k_ade20k/segformer_mit-b3_512x512_160k_ade20k_20210726_081410.log.json) |
+| Segformer | MIT-B4 | 512x512 | 160000 | 6.1 | 15.45 | 48.46 | 49.76 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b4_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b4_512x512_160k_ade20k/segformer_mit-b4_512x512_160k_ade20k_20210728_183055-7f509d7d.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b4_512x512_160k_ade20k/segformer_mit-b4_512x512_160k_ade20k_20210728_183055.log.json) |
+| Segformer | MIT-B5 | 512x512 | 160000 | 7.2 | 11.89 | 49.13 | 50.22 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b5_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_512x512_160k_ade20k/segformer_mit-b5_512x512_160k_ade20k_20210726_145235-94cedf59.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_512x512_160k_ade20k/segformer_mit-b5_512x512_160k_ade20k_20210726_145235.log.json) |
+| Segformer | MIT-B5 | 640x640 | 160000 | 11.5 | 11.30 | 49.62 | 50.36 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/segformer/segformer_mit-b5_640x640_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_640x640_160k_ade20k/segformer_mit-b5_640x640_160k_ade20k_20210801_121243-41d2845b.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_640x640_160k_ade20k/segformer_mit-b5_640x640_160k_ade20k_20210801_121243.log.json) |
+
+Evaluation with AlignedResize:
+
+| Method | Backbone | Crop Size | Lr schd | mIoU | mIoU(ms+flip) |
+| ------ | -------- | --------- | ------: | ---: | ------------- |
+| Segformer | MIT-B0 | 512x512 | 160000 | 38.1 | 38.57 |
+| Segformer | MIT-B1 | 512x512 | 160000 | 41.64 | 42.76 |
+| Segformer | MIT-B2 | 512x512 | 160000 | 46.53 | 47.49 |
+| Segformer | MIT-B3 | 512x512 | 160000 | 48.46 | 49.14 |
+| Segformer | MIT-B4 | 512x512 | 160000 | 49.34 | 50.29 |
+| Segformer | MIT-B5 | 512x512 | 160000 | 50.08 | 50.72 |
+| Segformer | MIT-B5 | 640x640 | 160000 | 50.58 | 50.8 |
+
+We replace `AlignedResize` in the original implementation with `Resize + ResizeToMultiple`. If you want to test with
+`AlignedResize`, you can change the dataset pipeline like this:
+
+```python
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(2048, 512),
+        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            # Resize images to a multiple of 32; improves SegFormer by 0.5-1.0 mIoU.
+            dict(type='ResizeToMultiple', size_divisor=32),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+```
+
+## How to use SegFormer official pretrained weights
+
+We convert the backbone weights from the official repo (https://github.com/NVlabs/SegFormer) with
+`tools/model_converters/mit_convert.py`.
+
+You may follow the steps below to prepare for SegFormer training:
+
+1. Download the SegFormer pretrained weights (we suggest putting them in `pretrain/`);
+2. Run the conversion script on the official pretrained weights: `python tools/model_converters/mit_convert.py pretrain/mit_b0.pth pretrain/mit_b0.pth`;
+3. Set `pretrained` in the SegFormer model config; for example, `pretrained` of `segformer_mit-b0_512x512_160k_ade20k.py` is set to `pretrain/mit_b0.pth`.
diff --git a/configs/segformer/segformer_mit-b0_512x512_160k_ade20k.py b/configs/segformer/segformer_mit-b0_512x512_160k_ade20k.py
new file mode 100644
index 0000000000..03065a7940
--- /dev/null
+++ b/configs/segformer/segformer_mit-b0_512x512_160k_ade20k.py
@@ -0,0 +1,33 @@
+_base_ = [
+    '../_base_/models/segformer_mit-b0.py', '../_base_/datasets/ade20k.py',
+    '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
+]
+
+model = dict(
+    pretrained='pretrain/mit_b0.pth', decode_head=dict(num_classes=150))
+
+# optimizer
+optimizer = dict(
+    _delete_=True,
+    type='AdamW',
+    lr=0.00006,
+    betas=(0.9, 0.999),
+    weight_decay=0.01,
+    paramwise_cfg=dict(
+        custom_keys={
+            'pos_block': dict(decay_mult=0.),
+            'norm': dict(decay_mult=0.),
+            'head': dict(lr_mult=10.)
+        }))
+
+lr_config = dict(
+    _delete_=True,
+    policy='poly',
+    warmup='linear',
+    warmup_iters=1500,
+    warmup_ratio=1e-6,
+    power=1.0,
+    min_lr=0.0,
+    by_epoch=False)
+
+data = dict(samples_per_gpu=2, workers_per_gpu=2)
diff --git a/configs/segformer/segformer_mit-b1_512x512_160k_ade20k.py b/configs/segformer/segformer_mit-b1_512x512_160k_ade20k.py
new file mode 100644
index 0000000000..5fce602144
--- /dev/null
+++ b/configs/segformer/segformer_mit-b1_512x512_160k_ade20k.py
@@ -0,0 +1,8 @@
+_base_ = ['./segformer_mit-b0_512x512_160k_ade20k.py']
+
+# model settings
+model = dict(
+    pretrained='pretrain/mit_b1.pth',
+    backbone=dict(
+        embed_dims=64, num_heads=[1, 2, 5, 8], num_layers=[2, 2, 2, 2]),
+    decode_head=dict(in_channels=[64, 128, 320, 512]))
diff --git a/configs/segformer/segformer_mit-b2_512x512_160k_ade20k.py b/configs/segformer/segformer_mit-b2_512x512_160k_ade20k.py
new file mode 100644
index 0000000000..afb24b0170
--- /dev/null
+++ b/configs/segformer/segformer_mit-b2_512x512_160k_ade20k.py
@@ -0,0 +1,8 @@
+_base_ = ['./segformer_mit-b0_512x512_160k_ade20k.py']
+
+# model settings
+model = dict(
+    pretrained='pretrain/mit_b2.pth',
+    backbone=dict(
+        embed_dims=64, num_heads=[1, 2, 5, 8], num_layers=[3, 4, 6, 3]),
+    decode_head=dict(in_channels=[64, 128, 320, 512]))
diff --git a/configs/segformer/segformer_mit-b3_512x512_160k_ade20k.py b/configs/segformer/segformer_mit-b3_512x512_160k_ade20k.py
new file mode 100644
index 0000000000..52348f6fcc
--- /dev/null
+++ b/configs/segformer/segformer_mit-b3_512x512_160k_ade20k.py
@@ -0,0 +1,8 @@
+_base_ = ['./segformer_mit-b0_512x512_160k_ade20k.py']
+
+# model settings
+model = dict(
+    pretrained='pretrain/mit_b3.pth',
+    backbone=dict(
+        embed_dims=64, num_heads=[1, 2, 5, 8], num_layers=[3, 4, 18, 3]),
+    decode_head=dict(in_channels=[64, 128, 320, 512]))
diff --git a/configs/segformer/segformer_mit-b4_512x512_160k_ade20k.py b/configs/segformer/segformer_mit-b4_512x512_160k_ade20k.py
new file mode 100644
index 0000000000..7b50b75608
--- /dev/null
+++ b/configs/segformer/segformer_mit-b4_512x512_160k_ade20k.py
@@ -0,0 +1,8 @@
+_base_ = ['./segformer_mit-b0_512x512_160k_ade20k.py']
+
+# model settings
+model = dict(
+    pretrained='pretrain/mit_b4.pth',
+    backbone=dict(
+        embed_dims=64, num_heads=[1, 2, 5, 8], num_layers=[3, 8, 27, 3]),
+    decode_head=dict(in_channels=[64, 128, 320, 512]))
diff --git a/configs/segformer/segformer_mit-b5_512x512_160k_ade20k.py b/configs/segformer/segformer_mit-b5_512x512_160k_ade20k.py
new file mode 100644
index 0000000000..5212fb1f6a
--- /dev/null
+++ b/configs/segformer/segformer_mit-b5_512x512_160k_ade20k.py
@@ -0,0 +1,8 @@
+_base_ = ['./segformer_mit-b0_512x512_160k_ade20k.py']
+
+# model settings
+model = dict(
+    pretrained='pretrain/mit_b5.pth',
+    backbone=dict(
+        embed_dims=64, num_heads=[1, 2, 5, 8], num_layers=[3, 6, 40, 3]),
+    decode_head=dict(in_channels=[64, 128, 320, 512]))
diff --git a/configs/segformer/segformer_mit-b5_640x640_160k_ade20k.py b/configs/segformer/segformer_mit-b5_640x640_160k_ade20k.py
new file mode 100644
index 0000000000..d21774c4d6
--- /dev/null
+++ b/configs/segformer/segformer_mit-b5_640x640_160k_ade20k.py
@@ -0,0 +1,44 @@
+_base_ = ['./segformer_mit-b0_512x512_160k_ade20k.py']
+
+# dataset settings
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+crop_size = (640, 640)
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', reduce_zero_label=True),
+    dict(type='Resize', img_scale=(2048, 640), ratio_range=(0.5, 2.0)),
+    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
+    dict(type='RandomFlip', prob=0.5),
+    dict(type='PhotoMetricDistortion'),
+    dict(type='Normalize', **img_norm_cfg),
+    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
+    dict(type='DefaultFormatBundle'),
+    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
+]
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(
+        type='MultiScaleFlipAug',
+        img_scale=(2048, 640),
+        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+        flip=False,
+        transforms=[
+            dict(type='Resize', keep_ratio=True),
+            dict(type='RandomFlip'),
+            dict(type='Normalize', **img_norm_cfg),
+            dict(type='ImageToTensor', keys=['img']),
+            dict(type='Collect', keys=['img']),
+        ])
+]
+data = dict(
+    train=dict(pipeline=train_pipeline),
+    val=dict(pipeline=test_pipeline),
+    test=dict(pipeline=test_pipeline))
+
+# model settings
+model = dict(
+    pretrained='pretrain/mit_b5.pth',
+    backbone=dict(
+        embed_dims=64, num_heads=[1, 2, 5, 8], num_layers=[3, 6, 40, 3]),
+    decode_head=dict(in_channels=[64, 128, 320, 512]))
diff --git a/mmseg/datasets/pipelines/transforms.py b/mmseg/datasets/pipelines/transforms.py
index 1fcba69a2c..c5e94a0f14 100644
--- a/mmseg/datasets/pipelines/transforms.py
+++ b/mmseg/datasets/pipelines/transforms.py
@@ -6,6 +6,63 @@
 from ..builder import PIPELINES
 
 
+@PIPELINES.register_module()
+class ResizeToMultiple(object):
+    """Resize images & seg to a multiple of divisor.
+
+    Args:
+        size_divisor (int): images and gt seg maps are resized to a
+            multiple of size_divisor. Default: 32.
+        interpolation (str, optional): The interpolation mode of image
+            resize. Default: None.
+    """
+
+    def __init__(self, size_divisor=32, interpolation=None):
+        self.size_divisor = size_divisor
+        self.interpolation = interpolation
+
+    def __call__(self, results):
+        """Call function to resize images and semantic segmentation maps to
+        a multiple of the size divisor.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Resized results, 'img_shape', 'pad_shape' keys are updated.
+        """
+        # Align image to multiple of size divisor.
+        img = results['img']
+        img = mmcv.imresize_to_multiple(
+            img,
+            self.size_divisor,
+            scale_factor=1,
+            interpolation=self.interpolation
+            if self.interpolation else 'bilinear')
+
+        results['img'] = img
+        results['img_shape'] = img.shape
+        results['pad_shape'] = img.shape
+
+        # Align segmentation map to multiple of size divisor.
+        for key in results.get('seg_fields', []):
+            gt_seg = results[key]
+            gt_seg = mmcv.imresize_to_multiple(
+                gt_seg,
+                self.size_divisor,
+                scale_factor=1,
+                interpolation='nearest')
+            results[key] = gt_seg
+
+        return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += (f'(size_divisor={self.size_divisor}, '
+                     f'interpolation={self.interpolation})')
+        return repr_str
+
+
 @PIPELINES.register_module()
 class Resize(object):
     """Resize images & seg.
diff --git a/mmseg/models/backbones/mit.py b/mmseg/models/backbones/mit.py
index cad0b43134..9d41ea58c1 100644
--- a/mmseg/models/backbones/mit.py
+++ b/mmseg/models/backbones/mit.py
@@ -11,7 +11,7 @@
 from ...utils import get_root_logger
 from ..builder import BACKBONES
-from ..utils import PatchEmbed, mit_convert, nchw_to_nlc, nlc_to_nchw
+from ..utils import PatchEmbed, nchw_to_nlc, nlc_to_nchw
 
 
 class MixFFN(BaseModule):
@@ -159,7 +159,13 @@ def forward(self, x, hw_shape, identity=None):
 
         if identity is None:
             identity = x_q
-        out = self.attn(query=x_q, key=x_kv, value=x_kv)[0]
+        # `need_weights=True` makes nn.MultiheadAttention return
+        # `attn_output, attn_output_weights.sum(dim=1) / num_heads`. The
+        # `attn_output_weights.sum(dim=1)` may cause a CUDA error, so we
+        # set `need_weights=False` to skip computing it. The issue
+        # https://github.com/pytorch/pytorch/issues/37583 reports that
+        # large-scale tensor sum operations may cause CUDA errors.
+        out = self.attn(query=x_q, key=x_kv, value=x_kv, need_weights=False)[0]
 
         return identity + self.dropout_layer(self.proj_drop(out))
 
@@ -387,17 +393,9 @@ def init_weights(self):
             self.pretrained, logger=logger, map_location='cpu')
         if 'state_dict' in checkpoint:
             state_dict = checkpoint['state_dict']
-        elif 'model' in checkpoint:
-            state_dict = checkpoint['model']
         else:
             state_dict = checkpoint
 
-        if self.pretrain_style == 'official':
-            # Because segformer backbone is not support by mmcls,
-            # so we need to convert pretrain weights to match this
-            # implementation.
-            state_dict = mit_convert(state_dict)
-
         self.load_state_dict(state_dict, False)
 
     def forward(self, x):
diff --git a/mmseg/models/decode_heads/__init__.py b/mmseg/models/decode_heads/__init__.py
index fcd0fa60bc..5b64125056 100644
--- a/mmseg/models/decode_heads/__init__.py
+++ b/mmseg/models/decode_heads/__init__.py
@@ -16,6 +16,7 @@
 from .point_head import PointHead
 from .psa_head import PSAHead
 from .psp_head import PSPHead
+from .segformer_head import SegformerHead
 from .sep_aspp_head import DepthwiseSeparableASPPHead
 from .sep_fcn_head import DepthwiseSeparableFCNHead
 from .setr_mla_head import SETRMLAHead
@@ -26,5 +27,6 @@
     'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
     'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
     'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
-    'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead', 'SETRMLAHead'
+    'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead',
+    'SETRMLAHead', 'SegformerHead'
 ]
diff --git a/mmseg/models/decode_heads/segformer_head.py b/mmseg/models/decode_heads/segformer_head.py
new file mode 100644
index 0000000000..9ae1ff69d8
--- /dev/null
+++ b/mmseg/models/decode_heads/segformer_head.py
@@ -0,0 +1,65 @@
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+
+from mmseg.models.builder import HEADS
+from mmseg.models.decode_heads.decode_head import BaseDecodeHead
+from mmseg.ops import resize
+
+
+@HEADS.register_module()
+class SegformerHead(BaseDecodeHead):
+    """The all-MLP Head of SegFormer.
+
+    This head is the implementation of
+    `SegFormer <https://arxiv.org/abs/2105.15203>`_.
+
+    Args:
+        interpolate_mode (str): The interpolation mode of the MLP head
+            upsample operation. Default: 'bilinear'.
+    """
+
+    def __init__(self, interpolate_mode='bilinear', **kwargs):
+        super().__init__(input_transform='multiple_select', **kwargs)
+
+        self.interpolate_mode = interpolate_mode
+        num_inputs = len(self.in_channels)
+
+        assert num_inputs == len(self.in_index)
+
+        self.convs = nn.ModuleList()
+        for i in range(num_inputs):
+            self.convs.append(
+                ConvModule(
+                    in_channels=self.in_channels[i],
+                    out_channels=self.channels,
+                    kernel_size=1,
+                    stride=1,
+                    norm_cfg=self.norm_cfg,
+                    act_cfg=self.act_cfg))
+
+        self.fusion_conv = ConvModule(
+            in_channels=self.channels * num_inputs,
+            out_channels=self.channels,
+            kernel_size=1,
+            norm_cfg=self.norm_cfg)
+
+    def forward(self, inputs):
+        # Receive 4 stage backbone feature maps: 1/4, 1/8, 1/16, 1/32.
+        inputs = self._transform_inputs(inputs)
+        outs = []
+        for idx in range(len(inputs)):
+            x = inputs[idx]
+            conv = self.convs[idx]
+            outs.append(
+                resize(
+                    input=conv(x),
+                    size=inputs[0].shape[2:],
+                    mode=self.interpolate_mode,
+                    align_corners=self.align_corners))
+
+        out = self.fusion_conv(torch.cat(outs, dim=1))
+
+        out = self.cls_seg(out)
+
+        return out
diff --git a/mmseg/models/utils/__init__.py b/mmseg/models/utils/__init__.py
index 32a953b834..6ef12bb9ba 100644
--- a/mmseg/models/utils/__init__.py
+++ b/mmseg/models/utils/__init__.py
@@ -1,4 +1,4 @@
-from .ckpt_convert import mit_convert, swin_convert, vit_convert
+from .ckpt_convert import swin_convert, vit_convert
 from .embed import PatchEmbed
 from .inverted_residual import InvertedResidual, InvertedResidualV3
 from .make_divisible import make_divisible
@@ -11,5 +11,5 @@
 __all__ = [
     'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
     'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'vit_convert',
-    'mit_convert', 'swin_convert', 'PatchEmbed', 'nchw_to_nlc', 'nlc_to_nchw'
+    'swin_convert', 'PatchEmbed', 'nchw_to_nlc', 'nlc_to_nchw'
 ]
diff --git a/mmseg/models/utils/ckpt_convert.py b/mmseg/models/utils/ckpt_convert.py
index 26a1b96df9..0b1b27707d 100644
--- a/mmseg/models/utils/ckpt_convert.py
+++ b/mmseg/models/utils/ckpt_convert.py
@@ -1,7 +1,5 @@
 from collections import OrderedDict
 
-import torch
-
 
 def swin_convert(ckpt):
     new_ckpt = OrderedDict()
@@ -90,50 +88,3 @@ def vit_convert(ckpt):
         new_ckpt[new_k] = v
 
     return new_ckpt
-
-
-def mit_convert(ckpt):
-    new_ckpt = OrderedDict()
-    # Process the concat between q linear weights and kv linear weights
-    for k, v in ckpt.items():
-        if k.startswith('head'):
-            continue
-        elif k.startswith('patch_embed'):
-            stage_i = int(k.split('.')[0].replace('patch_embed', ''))
-            new_k = k.replace(f'patch_embed{stage_i}', f'layers.{stage_i-1}.0')
-            new_v = v
-            if 'proj.' in new_k:
-                new_k = new_k.replace('proj.', 'projection.')
-        elif k.startswith('block'):
-            stage_i = int(k.split('.')[0].replace('block', ''))
-            new_k = k.replace(f'block{stage_i}', f'layers.{stage_i-1}.1')
-            new_v = v
-            if 'attn.q.' in new_k:
-                sub_item_k = k.replace('q.', 'kv.')
-                new_k = new_k.replace('q.', 'attn.in_proj_')
-                new_v = torch.cat([v, ckpt[sub_item_k]], dim=0)
-            elif 'attn.kv.' in new_k:
-                continue
-            elif 'attn.proj.' in new_k:
-                new_k = new_k.replace('proj.', 'attn.out_proj.')
-            elif 'attn.sr.' in new_k:
-                new_k = new_k.replace('sr.', 'sr.')
-            elif 'mlp.' in new_k:
-                string = f'{new_k}-'
-                new_k = new_k.replace('mlp.', 'ffn.layers.')
-                if 'fc1.weight' in new_k or 'fc2.weight' in new_k:
-                    new_v = v.reshape((*v.shape, 1, 1))
-                new_k = new_k.replace('fc1.', '0.')
-                new_k = new_k.replace('dwconv.dwconv.', '1.')
-                new_k = new_k.replace('fc2.', '4.')
-                string += f'{new_k} {v.shape}-{new_v.shape}'
-                # print(string)
-        elif k.startswith('norm'):
-            stage_i = int(k.split('.')[0].replace('norm', ''))
-            new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i-1}.2')
-            new_v = v
-        else:
-            new_k = k
-            new_v = v
-        new_ckpt[new_k] = new_v
-    return new_ckpt
diff --git a/tests/test_data/test_transform.py b/tests/test_data/test_transform.py
index a6417575c3..33ed4ecb14 100644
--- a/tests/test_data/test_transform.py
+++ b/tests/test_data/test_transform.py
@@ -10,6 +10,26 @@
 from mmseg.datasets.builder import PIPELINES
 
 
+def test_resize_to_multiple():
+    transform = dict(type='ResizeToMultiple', size_divisor=32)
+    transform = build_from_cfg(transform, PIPELINES)
+
+    img = np.random.randn(213, 232, 3)
+    seg = np.random.randint(0, 19, (213, 232))
+    results = dict()
+    results['img'] = img
+    results['gt_semantic_seg'] = seg
+    results['seg_fields'] = ['gt_semantic_seg']
+    results['img_shape'] = img.shape
+    results['pad_shape'] = img.shape
+
+    results = transform(results)
+    assert results['img'].shape == (224, 256, 3)
+    assert results['gt_semantic_seg'].shape == (224, 256)
+    assert results['img_shape'] == (224, 256, 3)
+    assert results['pad_shape'] == (224, 256, 3)
+
+
 def test_resize():
     # test assertion if img_scale is a list
     with pytest.raises(AssertionError):
diff --git a/tests/test_models/test_heads/test_segformer_head.py b/tests/test_models/test_heads/test_segformer_head.py
new file mode 100644
index 0000000000..aa8dedb1a8
--- /dev/null
+++ b/tests/test_models/test_heads/test_segformer_head.py
@@ -0,0 +1,39 @@
+import pytest
+import torch
+
+from mmseg.models.decode_heads import SegformerHead
+
+
+def test_segformer_head():
+    with pytest.raises(AssertionError):
+        # `in_channels` must have same length as `in_index`
+        SegformerHead(
+            in_channels=(1, 2, 3), in_index=(0, 1),
+            channels=5, num_classes=2)
+
+    H, W = (64, 64)
+    in_channels = (32, 64, 160, 256)
+    shapes = [(H // 2**(i + 2), W // 2**(i + 2))
+              for i in range(len(in_channels))]
+    model = SegformerHead(
+        in_channels=in_channels,
+        in_index=[0, 1, 2, 3],
+        channels=256,
+        num_classes=19)
+
+    with pytest.raises(IndexError):
+        # in_index must match the input feature maps.
+        inputs = [
+            torch.randn((1, in_channel, *shape))
+            for in_channel, shape in zip(in_channels, shapes)
+        ][:3]
+        temp = model(inputs)
+
+    # Normal input:
+    # ((1, 32, 16, 16), (1, 64, 8, 8), (1, 160, 4, 4), (1, 256, 2, 2))
+    inputs = [
+        torch.randn((1, in_channel, *shape))
+        for in_channel, shape in zip(in_channels, shapes)
+    ]
+    temp = model(inputs)
+
+    assert temp.shape == (1, 19, H // 4, W // 4)
diff --git a/tools/model_converters/mit_convert.py b/tools/model_converters/mit_convert.py
new file mode 100644
index 0000000000..c914c4edba
--- /dev/null
+++ b/tools/model_converters/mit_convert.py
@@ -0,0 +1,76 @@
+import argparse
+from collections import OrderedDict
+
+import torch
+
+
+def mit_convert(ckpt):
+    new_ckpt = OrderedDict()
+    # Process the concatenation between q linear weights and kv linear weights
+    for k, v in ckpt.items():
+        if k.startswith('head'):
+            continue
+        # patch embedding conversion
+        elif k.startswith('patch_embed'):
+            stage_i = int(k.split('.')[0].replace('patch_embed', ''))
+            new_k = k.replace(f'patch_embed{stage_i}', f'layers.{stage_i-1}.0')
+            new_v = v
+            if 'proj.' in new_k:
+                new_k = new_k.replace('proj.', 'projection.')
+        # transformer encoder layer conversion
+        elif k.startswith('block'):
+            stage_i = int(k.split('.')[0].replace('block', ''))
+            new_k = k.replace(f'block{stage_i}', f'layers.{stage_i-1}.1')
+            new_v = v
+            if 'attn.q.' in new_k:
+                sub_item_k = k.replace('q.', 'kv.')
+                new_k = new_k.replace('q.', 'attn.in_proj_')
+                new_v = torch.cat([v, ckpt[sub_item_k]], dim=0)
+            elif 'attn.kv.' in new_k:
+                continue
+            elif 'attn.proj.' in new_k:
+                new_k = new_k.replace('proj.', 'attn.out_proj.')
+            elif 'attn.sr.' in new_k:
+                new_k = new_k.replace('sr.', 'sr.')
+            elif 'mlp.' in new_k:
+                string = f'{new_k}-'
+                new_k = new_k.replace('mlp.', 'ffn.layers.')
+                if 'fc1.weight' in new_k or 'fc2.weight' in new_k:
+                    new_v = v.reshape((*v.shape, 1, 1))
+                new_k = new_k.replace('fc1.', '0.')
+                new_k = new_k.replace('dwconv.dwconv.', '1.')
+                new_k = new_k.replace('fc2.', '4.')
+                string += f'{new_k} {v.shape}-{new_v.shape}'
+        # norm layer conversion
+        elif k.startswith('norm'):
+            stage_i = int(k.split('.')[0].replace('norm', ''))
+            new_k = k.replace(f'norm{stage_i}', f'layers.{stage_i-1}.2')
+            new_v = v
+        else:
+            new_k = k
+            new_v = v
+        new_ckpt[new_k] = new_v
+    return new_ckpt
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        'Convert official segformer backbone weights to mmseg style.')
+    parser.add_argument(
+        'src', help='Source path of official segformer backbone weights.')
+    parser.add_argument(
+        'dst',
+        help='Destination path of converted segformer backbone weights.')
+
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    src_path = args.src
+    dst_path = args.dst
+
+    ckpt = torch.load(src_path, map_location='cpu')
+
+    ckpt = mit_convert(ckpt)
+    torch.save(ckpt, dst_path)
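
As a quick sanity check of the key mapping the converter above performs, here is a minimal sketch run on a hypothetical slice of an official MiT checkpoint. The toy keys and shapes are illustrative only (modeled on stage 1 of MiT-B0), and the import assumes the repo root is on `PYTHONPATH` so that `tools.model_converters.mit_convert` resolves:

```python
from collections import OrderedDict

import torch

# Hypothetical import path; assumes the repo root is on PYTHONPATH.
from tools.model_converters.mit_convert import mit_convert

# A toy slice of an official checkpoint (stage 1 of MiT-B0, embed dim 32).
official = OrderedDict({
    'patch_embed1.proj.weight': torch.zeros(32, 3, 7, 7),
    'block1.0.attn.q.weight': torch.zeros(32, 32),
    'block1.0.attn.kv.weight': torch.zeros(64, 32),
    'norm1.weight': torch.zeros(32),
})

converted = mit_convert(official)

# Patch embedding: 'patch_embed1.proj.weight' -> 'layers.0.0.projection.weight'.
assert 'layers.0.0.projection.weight' in converted
# The separate q and kv weights are concatenated into the single
# in-projection expected by nn.MultiheadAttention: (32 + 64, 32) -> (96, 32).
assert converted['layers.0.1.0.attn.attn.in_proj_weight'].shape == (96, 32)
# Stage-final norm: 'norm1.weight' -> 'layers.0.2.weight'.
assert 'layers.0.2.weight' in converted
```

The q/kv concatenation is the essential step: it lets the official per-projection weights drop into the single `in_proj_weight` of the `nn.MultiheadAttention` used inside the MiT backbone's `EfficientMultiheadAttention`, which is why unconverted official checkpoints do not load directly.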