Skip to content

Commit c90b605

Browse files
authored
[Enhance] Add setup multi-processing both in train and test. (open-mmlab#707)
* Faster training * Add setup multi-processing both in train and test.
1 parent 6163f4c commit c90b605

File tree

5 files changed

+127
-2
lines changed

5 files changed

+127
-2
lines changed

mmedit/utils/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
from .logger import get_root_logger
3+
from .setup_env import setup_multi_processes
34

4-
__all__ = ['get_root_logger']
5+
# Public names exported by this package; every entry must be a string.
__all__ = ['get_root_logger', 'setup_multi_processes']

mmedit/utils/setup_env.py

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import os
3+
import platform
4+
import warnings
5+
6+
import cv2
7+
import torch.multiprocessing as mp
8+
9+
10+
def setup_multi_processes(cfg):
    """Setup multi-processing environment variables.

    Configures, in order: the multi-processing start method (skipped on
    Windows), the OpenCV thread count, and the ``OMP_NUM_THREADS`` /
    ``MKL_NUM_THREADS`` environment variables.

    Args:
        cfg: Config object. It is read through ``cfg.get`` for the optional
            keys ``mp_start_method`` (default ``'fork'``) and
            ``opencv_num_threads`` (default ``0``), and through
            ``cfg.data.workers_per_gpu``. The latter is only accessed when
            one of the thread env vars is not already set.
    """
    # set multi-process start method as `fork` to speed up the training
    if platform.system() != 'Windows':
        mp_start_method = cfg.get('mp_start_method', 'fork')
        current_method = mp.get_start_method(allow_none=True)
        if current_method is not None and current_method != mp_start_method:
            warnings.warn(
                f'Multi-processing start method `{mp_start_method}` is '
                f'different from the previous setting `{current_method}`. '
                f'It will be force set to `{mp_start_method}`. You can change '
                f'this behavior by changing `mp_start_method` in your config.')
        # `force=True` overrides a start method that was already chosen
        mp.set_start_method(mp_start_method, force=True)

    # disable opencv multithreading to avoid system being overloaded
    opencv_num_threads = cfg.get('opencv_num_threads', 0)
    cv2.setNumThreads(opencv_num_threads)

    # setup OMP and MKL threads
    # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
    for env_var in ('OMP_NUM_THREADS', 'MKL_NUM_THREADS'):
        # a value already exported by the user always takes precedence
        if env_var not in os.environ and cfg.data.workers_per_gpu > 1:
            num_threads = 1
            warnings.warn(
                f'Setting {env_var} environment variable for each process '
                f'to be {num_threads} in default, to avoid your system being '
                f'overloaded, please further tune the variable for optimal '
                f'performance in your application as needed.')
            os.environ[env_var] = str(num_threads)

tests/test_utils/test_setup_env.py

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import multiprocessing as mp
3+
import os
4+
import platform
5+
6+
import cv2
7+
from mmcv import Config
8+
9+
from mmedit.utils import setup_multi_processes
10+
11+
12+
def test_setup_multi_processes():
    """Check the process-wide settings applied by `setup_multi_processes`.

    The start method, OpenCV thread count and the OMP/MKL env vars are
    process-global, so they are saved first and restored in a ``finally``
    block. This keeps a failing assertion from leaking modified state into
    other tests.
    """
    # temp save system setting
    sys_start_method = mp.get_start_method(allow_none=True)
    sys_cv_threads = cv2.getNumThreads()
    # pop and temp save system env vars
    saved_env = {
        name: os.environ.pop(name, None)
        for name in ('OMP_NUM_THREADS', 'MKL_NUM_THREADS')
    }

    try:
        # test config without setting env
        config = dict(data=dict(workers_per_gpu=2))
        cfg = Config(config)
        setup_multi_processes(cfg)
        assert os.getenv('OMP_NUM_THREADS') == '1'
        assert os.getenv('MKL_NUM_THREADS') == '1'
        # when set to 0, the num threads will be 1
        assert cv2.getNumThreads() == 1
        if platform.system() != 'Windows':
            assert mp.get_start_method() == 'fork'

        # test num workers <= 1
        os.environ.pop('OMP_NUM_THREADS')
        os.environ.pop('MKL_NUM_THREADS')
        config = dict(data=dict(workers_per_gpu=0))
        cfg = Config(config)
        setup_multi_processes(cfg)
        assert 'OMP_NUM_THREADS' not in os.environ
        assert 'MKL_NUM_THREADS' not in os.environ

        # test manually set env var
        os.environ['OMP_NUM_THREADS'] = '4'
        config = dict(data=dict(workers_per_gpu=2))
        cfg = Config(config)
        setup_multi_processes(cfg)
        assert os.getenv('OMP_NUM_THREADS') == '4'

        # test manually set opencv threads and mp start method
        config = dict(
            data=dict(workers_per_gpu=2),
            opencv_num_threads=4,
            mp_start_method='spawn')
        cfg = Config(config)
        setup_multi_processes(cfg)
        assert cv2.getNumThreads() == 4
        assert mp.get_start_method() == 'spawn'
    finally:
        # revert setting to avoid affecting other programs; passing `None`
        # with `force=True` resets a start method that was originally unset
        mp.set_start_method(sys_start_method, force=True)
        cv2.setNumThreads(sys_cv_threads)
        for name, value in saved_env.items():
            if value is None:
                # the variable did not exist before the test
                os.environ.pop(name, None)
            else:
                os.environ[name] = value

tools/test.py

+5
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from mmedit.core.distributed_wrapper import DistributedDataParallelWrapper
1212
from mmedit.datasets import build_dataloader, build_dataset
1313
from mmedit.models import build_model
14+
from mmedit.utils import setup_multi_processes
1415

1516

1617
def parse_args():
@@ -49,6 +50,10 @@ def main():
4950
args = parse_args()
5051

5152
cfg = mmcv.Config.fromfile(args.config)
53+
54+
# set multi-process settings
55+
setup_multi_processes(cfg)
56+
5257
# set cudnn_benchmark
5358
if cfg.get('cudnn_benchmark', False):
5459
torch.backends.cudnn.benchmark = True

tools/train.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from mmedit.apis import set_random_seed, train_model
1515
from mmedit.datasets import build_dataset
1616
from mmedit.models import build_model
17-
from mmedit.utils import collect_env, get_root_logger
17+
from mmedit.utils import collect_env, get_root_logger, setup_multi_processes
1818

1919

2020
def parse_args():
@@ -59,6 +59,10 @@ def main():
5959
args = parse_args()
6060

6161
cfg = Config.fromfile(args.config)
62+
63+
# set multi-process settings
64+
setup_multi_processes(cfg)
65+
6266
# set cudnn_benchmark
6367
if cfg.get('cudnn_benchmark', False):
6468
torch.backends.cudnn.benchmark = True

0 commit comments

Comments
 (0)