forked from open-mmlab/mmsegmentation
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Add model ensemble tools (open-mmlab#2218)
* [Feature] Add model ensemble tool * [Enhance] Add en and zh_cn instructions for model_ensemble * [Enhance] Add default-value for --out and modify instruction * [Enhance] Add arg-type for --out * [Enhance] Delete redundant code
- Loading branch information
Showing
4 changed files
with
205 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import argparse | ||
import os | ||
|
||
import mmcv | ||
import numpy as np | ||
import torch | ||
from mmcv.parallel import MMDataParallel | ||
from mmcv.parallel.scatter_gather import scatter_kwargs | ||
from mmcv.runner import load_checkpoint, wrap_fp16_model | ||
from PIL import Image | ||
|
||
from mmseg.datasets import build_dataloader, build_dataset | ||
from mmseg.models import build_segmentor | ||
|
||
|
||
@torch.no_grad() | ||
def main(args): | ||
|
||
models = [] | ||
gpu_ids = args.gpus | ||
configs = args.config | ||
ckpts = args.checkpoint | ||
|
||
cfg = mmcv.Config.fromfile(configs[0]) | ||
|
||
if args.aug_test: | ||
cfg.data.test.pipeline[1].img_ratios = [ | ||
0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0 | ||
] | ||
cfg.data.test.pipeline[1].flip = True | ||
else: | ||
cfg.data.test.pipeline[1].img_ratios = [1.0] | ||
cfg.data.test.pipeline[1].flip = False | ||
|
||
torch.backends.cudnn.benchmark = True | ||
|
||
# build the dataloader | ||
dataset = build_dataset(cfg.data.test) | ||
data_loader = build_dataloader( | ||
dataset, | ||
samples_per_gpu=1, | ||
workers_per_gpu=4, | ||
dist=False, | ||
shuffle=False, | ||
) | ||
|
||
for idx, (config, ckpt) in enumerate(zip(configs, ckpts)): | ||
cfg = mmcv.Config.fromfile(config) | ||
cfg.model.pretrained = None | ||
cfg.data.test.test_mode = True | ||
|
||
model = build_segmentor(cfg.model, test_cfg=cfg.get('test_cfg')) | ||
if cfg.get('fp16', None): | ||
wrap_fp16_model(model) | ||
load_checkpoint(model, ckpt, map_location='cpu') | ||
torch.cuda.empty_cache() | ||
tmpdir = args.out | ||
mmcv.mkdir_or_exist(tmpdir) | ||
model = MMDataParallel(model, device_ids=[gpu_ids[idx % len(gpu_ids)]]) | ||
model.eval() | ||
models.append(model) | ||
|
||
dataset = data_loader.dataset | ||
prog_bar = mmcv.ProgressBar(len(dataset)) | ||
loader_indices = data_loader.batch_sampler | ||
for batch_indices, data in zip(loader_indices, data_loader): | ||
result = [] | ||
|
||
for model in models: | ||
x, _ = scatter_kwargs( | ||
inputs=data, kwargs=None, target_gpus=model.device_ids) | ||
if args.aug_test: | ||
logits = model.module.aug_test_logits(**x[0]) | ||
else: | ||
logits = model.module.simple_test_logits(**x[0]) | ||
result.append(logits) | ||
|
||
result_logits = 0 | ||
for logit in result: | ||
result_logits += logit | ||
|
||
pred = result_logits.argmax(axis=1).squeeze() | ||
img_info = dataset.img_infos[batch_indices[0]] | ||
file_name = os.path.join( | ||
tmpdir, img_info['ann']['seg_map'].split(os.path.sep)[-1]) | ||
Image.fromarray(pred.astype(np.uint8)).save(file_name) | ||
prog_bar.update() | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser( | ||
description='Model Ensemble with logits result') | ||
parser.add_argument( | ||
'--config', type=str, nargs='+', help='ensemble config files path') | ||
parser.add_argument( | ||
'--checkpoint', | ||
type=str, | ||
nargs='+', | ||
help='ensemble checkpoint files path') | ||
parser.add_argument( | ||
'--aug-test', | ||
action='store_true', | ||
help='control ensemble aug-result or single-result (default)') | ||
parser.add_argument( | ||
'--out', type=str, default='results', help='the dir to save result') | ||
parser.add_argument( | ||
'--gpus', type=int, nargs='+', default=[0], help='id of gpu to use') | ||
|
||
args = parser.parse_args() | ||
assert len(args.config) == len(args.checkpoint), \ | ||
f'len(config) must equal len(checkpoint), ' \ | ||
f'but len(config) = {len(args.config)} and' \ | ||
f'len(checkpoint) = {len(args.checkpoint)}' | ||
assert args.out, "ensemble result out-dir can't be None" | ||
return args | ||
|
||
|
||
if __name__ == '__main__': | ||
args = parse_args() | ||
main(args) |