Skip to content

Commit

Permalink
[Feature] Add model ensemble tools (open-mmlab#2218)
Browse files Browse the repository at this point in the history
* [Feature] Add model ensemble tool

* [Enhance] Add en and zh_cn instructions for model_ensemble

* [Enhance] Add default-value for --out and modify instruction

* [Enhance] Add arg-type for --out

* [Enhance] Delete redundant code
  • Loading branch information
zhijiejia authored Oct 24, 2022
1 parent 76a5138 commit 8dbbdd8
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 0 deletions.
29 changes: 29 additions & 0 deletions docs/en/useful_tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -424,3 +424,32 @@ result/pred_result.pkl \
result/confusion_matrix \
--show
```
## Model ensemble
We provide `tools/model_ensemble.py` to ensemble the prediction probabilities of multiple models.
### Usage
```bash
python tools/model_ensemble.py \
--config ${CONFIG_FILE1} ${CONFIG_FILE2} ... \
--checkpoint ${CHECKPOINT_FILE1} ${CHECKPOINT_FILE2} ... \
--aug-test \
--out ${OUTPUT_DIR} \
--gpus ${GPU_USED}
```
### Description of all arguments
- `--config`: Paths to the config files of the models to ensemble, one per model
- `--checkpoint`: Paths to the checkpoint files of the models to ensemble, one per config and in the same order as `--config`
- `--aug-test`: Whether to use flip and multi-scale test
- `--out`: Folder in which the model ensemble results are saved (default: `results`)
- `--gpus`: GPU ids used for model ensemble (default: `0`)
### Result of model ensemble
- For each input image of shape `[H, W]`, the model ensemble generates an unrendered segmentation mask of shape `[H, W]`; each pixel value in the mask is the predicted category at that position.
- Each result file is given the same filename as its `Ground Truth` file. For example, if the `Ground Truth` file is named `1.png`, the result file is also named `1.png` and is saved in the folder specified by `--out`.
28 changes: 28 additions & 0 deletions docs/zh_cn/useful_tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -366,3 +366,31 @@ configs/fcn/fcn_r50-d8_512x1024_40k_cityscapes.py \
checkpoint/fcn_r50-d8_512x1024_40k_cityscapes_20200604_192608-efe53f0d.pth \
fcn
```
## 模型集成
我们提供了 `tools/model_ensemble.py` 脚本，用于对多个模型的预测概率进行集成
### 使用方法
```bash
python tools/model_ensemble.py \
--config ${CONFIG_FILE1} ${CONFIG_FILE2} ... \
--checkpoint ${CHECKPOINT_FILE1} ${CHECKPOINT_FILE2} ... \
--aug-test \
--out ${OUTPUT_DIR} \
--gpus ${GPU_USED}
```
### 各个参数的描述:
- `--config`: 集成模型的配置文件的路径
- `--checkpoint`: 集成模型的权重文件的路径
- `--aug-test`: 是否使用翻转和多尺度预测
- `--out`: 模型集成结果的保存文件夹路径
- `--gpus`: 模型集成使用的gpu-id
### 模型集成结果
- 模型集成会对每一张输入,形状为`[H, W]`,产生一张未渲染的分割掩膜文件(segmentation mask),形状为`[H, W]`,分割掩膜中的每个像素点的值代表该位置分割后的像素类别.
- 模型集成结果的文件名会采用和`Ground Truth`一致的文件命名,如`Ground Truth`文件名称为`1.png`,则模型集成结果文件也会被命名为`1.png`,并放置在`--out`指定的文件夹中.
27 changes: 27 additions & 0 deletions mmseg/models/segmentors/encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,15 @@ def simple_test(self, img, img_meta, rescale=True):
seg_pred = list(seg_pred)
return seg_pred

def simple_test_logits(self, img, img_metas, rescale=True):
    """Run inference without test-time augmentation.

    Only the first entry of ``img`` and of ``img_metas`` is used.

    Args:
        img: Indexable collection of inputs; ``img[0]`` is passed to
            ``self.inference``.
        img_metas: Indexable collection of meta information;
            ``img_metas[0]`` is passed to ``self.inference``.
        rescale (bool): Forwarded unchanged to ``self.inference``.
            Defaults to True.

    Returns:
        The output of ``self.inference`` moved to CPU and converted to a
        numpy array.
    """
    first_img, first_meta = img[0], img_metas[0]
    logits = self.inference(first_img, first_meta, rescale)
    return logits.cpu().numpy()

def aug_test(self, imgs, img_metas, rescale=True):
"""Test with augmentations.
Expand All @@ -300,3 +309,21 @@ def aug_test(self, imgs, img_metas, rescale=True):
# unravel batch dim
seg_pred = list(seg_pred)
return seg_pred

def aug_test_logits(self, img, img_metas, rescale=True):
    """Run inference on every augmented input and average the results.

    ``self.inference`` is called once per entry of ``img`` with the
    matching entry of ``img_metas``; the outputs are summed in place
    into the first output and divided by the number of entries.

    Args:
        img: Indexable collection of augmented inputs.
        img_metas: Indexable collection of meta information, indexed in
            parallel with ``img``.
        rescale (bool): Must be truthy; forwarded to ``self.inference``.
            Defaults to True.

    Returns:
        The averaged output moved to CPU and converted to a numpy array.

    Raises:
        AssertionError: If ``rescale`` is falsy.
    """
    # Only rescaled outputs are supported here.
    assert rescale

    num_augs = len(img)
    averaged = self.inference(img[0], img_metas[0], rescale)
    for aug_idx in range(1, num_augs):
        # In-place accumulation into the first output.
        averaged += self.inference(img[aug_idx], img_metas[aug_idx], rescale)
    averaged /= num_augs

    return averaged.cpu().numpy()
121 changes: 121 additions & 0 deletions tools/model_ensemble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os

import mmcv
import numpy as np
import torch
from mmcv.parallel import MMDataParallel
from mmcv.parallel.scatter_gather import scatter_kwargs
from mmcv.runner import load_checkpoint, wrap_fp16_model
from PIL import Image

from mmseg.datasets import build_dataloader, build_dataset
from mmseg.models import build_segmentor


@torch.no_grad()
def main(args):
    """Ensemble several segmentors over the test set and save the masks.

    The test dataset is built from the first config only. Each
    (config, checkpoint) pair yields one model wrapped in
    ``MMDataParallel`` on a single GPU, chosen round-robin from
    ``args.gpus``. For every sample the arrays returned by the models'
    ``*_test_logits`` methods are summed, the argmax over axis 1 is taken
    and the result is written as an image into ``args.out``.

    Args:
        args (argparse.Namespace): Parsed options providing ``config``,
            ``checkpoint``, ``aug_test``, ``out`` and ``gpus``.
    """
    gpu_ids = args.gpus
    configs = args.config
    ckpts = args.checkpoint

    # The output folder is the same for every model, so create it once
    # up front instead of once per model.
    out_dir = args.out
    mmcv.mkdir_or_exist(out_dir)

    cfg = mmcv.Config.fromfile(configs[0])

    # NOTE(review): assumes the second test-pipeline step is the
    # multi-scale/flip wrapper that owns `img_ratios` and `flip` - confirm
    # for configs with a different pipeline layout.
    test_aug = cfg.data.test.pipeline[1]
    if args.aug_test:
        test_aug.img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0]
        test_aug.flip = True
    else:
        test_aug.img_ratios = [1.0]
        test_aug.flip = False

    # Mark the dataset config as test-mode *before* building the dataset.
    # Previously the flag was set on each per-model config after the
    # dataset already existed, where nothing read it.
    cfg.data.test.test_mode = True

    torch.backends.cudnn.benchmark = True

    # build the dataloader
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=4,
        dist=False,
        shuffle=False,
    )

    # build one model per (config, checkpoint) pair
    models = []
    for idx, (config, ckpt) in enumerate(zip(configs, ckpts)):
        model_cfg = mmcv.Config.fromfile(config)
        model_cfg.model.pretrained = None

        model = build_segmentor(
            model_cfg.model, test_cfg=model_cfg.get('test_cfg'))
        if model_cfg.get('fp16', None):
            wrap_fp16_model(model)
        load_checkpoint(model, ckpt, map_location='cpu')
        torch.cuda.empty_cache()
        # Spread the models round-robin over the requested GPUs.
        model = MMDataParallel(
            model, device_ids=[gpu_ids[idx % len(gpu_ids)]])
        model.eval()
        models.append(model)

    dataset = data_loader.dataset
    prog_bar = mmcv.ProgressBar(len(dataset))
    # The batch sampler is iterated next to the loader to recover the
    # dataset index of each sample; this relies on both yielding the same
    # order (the loader is built with shuffle=False).
    for batch_indices, data in zip(data_loader.batch_sampler, data_loader):
        # Starting from 0 makes the first `+=` allocate a fresh array, so
        # no model output is modified in place.
        result_logits = 0
        for model in models:
            x, _ = scatter_kwargs(
                inputs=data, kwargs=None, target_gpus=model.device_ids)
            if args.aug_test:
                logits = model.module.aug_test_logits(**x[0])
            else:
                logits = model.module.simple_test_logits(**x[0])
            result_logits += logits

        pred = result_logits.argmax(axis=1).squeeze()
        img_info = dataset.img_infos[batch_indices[0]]
        # Save under the same file name as the sample's seg_map entry.
        file_name = os.path.join(
            out_dir, os.path.basename(img_info['ann']['seg_map']))
        Image.fromarray(pred.astype(np.uint8)).save(file_name)
        prog_bar.update()


def parse_args(argv=None):
    """Parse and validate the command-line options of the ensemble tool.

    Args:
        argv (list[str] | None): Argument strings to parse. Defaults to
            None, in which case argparse reads them from ``sys.argv``.

    Returns:
        argparse.Namespace: Namespace with ``config`` and ``checkpoint``
        (lists of str of equal length), ``aug_test`` (bool), ``out``
        (non-empty str) and ``gpus`` (list of int).

    Raises:
        SystemExit: Raised by argparse when a required option is missing,
            when the number of configs differs from the number of
            checkpoints, or when ``--out`` is empty.
    """
    parser = argparse.ArgumentParser(
        description='Model Ensemble with logits result')
    # Both lists are required: without them there is nothing to ensemble,
    # and the length check below would fail on None.
    parser.add_argument(
        '--config',
        type=str,
        nargs='+',
        required=True,
        help='ensemble config files path')
    parser.add_argument(
        '--checkpoint',
        type=str,
        nargs='+',
        required=True,
        help='ensemble checkpoint files path')
    parser.add_argument(
        '--aug-test',
        action='store_true',
        help='control ensemble aug-result or single-result (default)')
    parser.add_argument(
        '--out', type=str, default='results', help='the dir to save result')
    parser.add_argument(
        '--gpus', type=int, nargs='+', default=[0], help='id of gpu to use')

    args = parser.parse_args(argv)

    # Report invalid input through argparse instead of `assert`, which is
    # stripped when Python runs with -O.
    if len(args.config) != len(args.checkpoint):
        parser.error(
            'len(config) must equal len(checkpoint), '
            f'but len(config) = {len(args.config)} and '
            f'len(checkpoint) = {len(args.checkpoint)}')
    if not args.out:
        parser.error("ensemble result out-dir can't be None")
    return args


# Script entry point: parse the command line and run the ensemble.
if __name__ == '__main__':
    main(parse_args())

0 comments on commit 8dbbdd8

Please sign in to comment.