Skip to content

Commit 5b17d6e

Browse files
committed
[test] added a decorator for address already in use error with backward compatibility
1 parent 8f7ce94 commit 5b17d6e

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

colossalai/testing/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from .comparison import assert_equal, assert_not_equal, assert_close, assert_close_loose, assert_equal_in_group
2-
from .utils import parameterize, rerun_on_exception
2+
from .utils import parameterize, rerun_on_exception, rerun_if_address_is_in_use
33

44
__all__ = [
55
'assert_equal', 'assert_not_equal', 'assert_close', 'assert_close_loose', 'assert_equal_in_group', 'parameterize',
6-
'rerun_on_exception'
6+
'rerun_on_exception', 'rerun_if_address_is_in_use'
77
]

colossalai/testing/utils.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import re
2+
import torch
23
from typing import Callable, List, Any
34
from functools import partial
45
from inspect import signature
6+
from packaging import version
57

68

79
def parameterize(argument: str, values: List[Any]) -> Callable:
@@ -144,3 +146,29 @@ def _run_until_success(*args, **kwargs):
144146
return _run_until_success
145147

146148
return _wrapper
149+
150+
151+
def rerun_if_address_is_in_use():
    """
    Build a decorator that reruns the wrapped function when it fails with an
    "Address already in use" error, which can occur in testing spawned with
    torch.multiprocessing.

    The exception type to match is chosen by feature detection instead of by
    parsing ``torch.__version__``: if ``torch.multiprocessing`` exposes
    ``ProcessRaisedException`` it is used, otherwise ``RuntimeError`` is used
    for backward compatibility. This keeps the decorator usable on any torch
    major version rather than asserting that the major version is 1.

    Usage::

        @rerun_if_address_is_in_use()
        def test_something():
            ...

    Returns:
        The decorator returned by ``rerun_on_exception`` configured with the
        selected exception type and the pattern ".*Address already in use.*".
    """
    # only torch >= 1.8 has ProcessRaisedException; probing for the attribute
    # avoids a hard failure on torch 2.x and a wrong minor-version comparison.
    exception = getattr(torch.multiprocessing, 'ProcessRaisedException', RuntimeError)

    func_wrapper = rerun_on_exception(exception_type=exception, pattern=".*Address already in use.*")
    return func_wrapper

0 commit comments

Comments
 (0)