forked from open-mmlab/mmpretrain
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Enhance] Add setup multi-processing both in train and test. (open-mm…
- Loading branch information
Showing
6 changed files
with
130 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,8 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .collect_env import collect_env | ||
from .logger import get_root_logger, load_json_log | ||
from .setup_env import setup_multi_processes | ||
|
||
__all__ = ['collect_env', 'get_root_logger', 'load_json_log'] | ||
__all__ = [ | ||
'collect_env', 'get_root_logger', 'load_json_log', 'setup_multi_processes' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import os | ||
import platform | ||
import warnings | ||
|
||
import cv2 | ||
import torch.multiprocessing as mp | ||
|
||
|
||
def setup_multi_processes(cfg): | ||
"""Setup multi-processing environment variables.""" | ||
# set multi-process start method as `fork` to speed up the training | ||
if platform.system() != 'Windows': | ||
mp_start_method = cfg.get('mp_start_method', 'fork') | ||
current_method = mp.get_start_method(allow_none=True) | ||
if current_method is not None and current_method != mp_start_method: | ||
warnings.warn( | ||
f'Multi-processing start method `{mp_start_method}` is ' | ||
f'different from the previous setting `{current_method}`.' | ||
f'It will be force set to `{mp_start_method}`. You can change ' | ||
f'this behavior by changing `mp_start_method` in your config.') | ||
mp.set_start_method(mp_start_method, force=True) | ||
|
||
# disable opencv multithreading to avoid system being overloaded | ||
opencv_num_threads = cfg.get('opencv_num_threads', 0) | ||
cv2.setNumThreads(opencv_num_threads) | ||
|
||
# setup OMP threads | ||
# This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa | ||
if 'OMP_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1: | ||
omp_num_threads = 1 | ||
warnings.warn( | ||
f'Setting OMP_NUM_THREADS environment variable for each process ' | ||
f'to be {omp_num_threads} in default, to avoid your system being ' | ||
f'overloaded, please further tune the variable for optimal ' | ||
f'performance in your application as needed.') | ||
os.environ['OMP_NUM_THREADS'] = str(omp_num_threads) | ||
|
||
# setup MKL threads | ||
if 'MKL_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1: | ||
mkl_num_threads = 1 | ||
warnings.warn( | ||
f'Setting MKL_NUM_THREADS environment variable for each process ' | ||
f'to be {mkl_num_threads} in default, to avoid your system being ' | ||
f'overloaded, please further tune the variable for optimal ' | ||
f'performance in your application as needed.') | ||
os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
import multiprocessing as mp | ||
import os | ||
import platform | ||
|
||
import cv2 | ||
from mmcv import Config | ||
|
||
from mmcls.utils import setup_multi_processes | ||
|
||
|
||
def test_setup_multi_processes(): | ||
# temp save system setting | ||
sys_start_mehod = mp.get_start_method(allow_none=True) | ||
sys_cv_threads = cv2.getNumThreads() | ||
# pop and temp save system env vars | ||
sys_omp_threads = os.environ.pop('OMP_NUM_THREADS', default=None) | ||
sys_mkl_threads = os.environ.pop('MKL_NUM_THREADS', default=None) | ||
|
||
# test config without setting env | ||
config = dict(data=dict(workers_per_gpu=2)) | ||
cfg = Config(config) | ||
setup_multi_processes(cfg) | ||
assert os.getenv('OMP_NUM_THREADS') == '1' | ||
assert os.getenv('MKL_NUM_THREADS') == '1' | ||
# when set to 0, the num threads will be 1 | ||
assert cv2.getNumThreads() == 1 | ||
if platform.system() != 'Windows': | ||
assert mp.get_start_method() == 'fork' | ||
|
||
# test num workers <= 1 | ||
os.environ.pop('OMP_NUM_THREADS') | ||
os.environ.pop('MKL_NUM_THREADS') | ||
config = dict(data=dict(workers_per_gpu=0)) | ||
cfg = Config(config) | ||
setup_multi_processes(cfg) | ||
assert 'OMP_NUM_THREADS' not in os.environ | ||
assert 'MKL_NUM_THREADS' not in os.environ | ||
|
||
# test manually set env var | ||
os.environ['OMP_NUM_THREADS'] = '4' | ||
config = dict(data=dict(workers_per_gpu=2)) | ||
cfg = Config(config) | ||
setup_multi_processes(cfg) | ||
assert os.getenv('OMP_NUM_THREADS') == '4' | ||
|
||
# test manually set opencv threads and mp start method | ||
config = dict( | ||
data=dict(workers_per_gpu=2), | ||
opencv_num_threads=4, | ||
mp_start_method='spawn') | ||
cfg = Config(config) | ||
setup_multi_processes(cfg) | ||
assert cv2.getNumThreads() == 4 | ||
assert mp.get_start_method() == 'spawn' | ||
|
||
# revert setting to avoid affecting other programs | ||
if sys_start_mehod: | ||
mp.set_start_method(sys_start_mehod, force=True) | ||
cv2.setNumThreads(sys_cv_threads) | ||
if sys_omp_threads: | ||
os.environ['OMP_NUM_THREADS'] = sys_omp_threads | ||
else: | ||
os.environ.pop('OMP_NUM_THREADS') | ||
if sys_mkl_threads: | ||
os.environ['MKL_NUM_THREADS'] = sys_mkl_threads | ||
else: | ||
os.environ.pop('MKL_NUM_THREADS') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters