
[Misc]: how to fix the proper cleanup problem in tests #7053

@youkaichao

Description


Anything you want to discuss about vllm.

We are having a very difficult time cleaning up resources properly, especially in distributed inference. This has been making our tests flaky recently.

To understand this, we have to understand the process model of pytest:

For this simple test:

import pytest
@pytest.mark.parametrize('arg', [1, 2, 3])
def test_pass(arg):
    import os
    print((arg, os.getpid()))

pytest creates a single process to run all three tests one by one.

So the output is (note that the process id is the same for all three tests):

testf.py::test_pass[1] (1, 15068)
PASSED
testf.py::test_pass[2] (2, 15068)
PASSED
testf.py::test_pass[3] (3, 15068)
PASSED

The fact that these three tests share the same process makes some low-level handling difficult.

  1. When a test segfaults, the following tests will not run because the process has died. For example:
import pytest
@pytest.mark.parametrize('arg', [1, 2, 3])
def test_pass(arg):
    import os
    print((arg, os.getpid()))

    if arg == 2:
        import ctypes
        func_ptr = ctypes.CFUNCTYPE(ctypes.c_int)(0)
        # Calling the function pointer with an invalid address will cause a segmentation fault
        func_ptr()

this test produces:

Running 3 items in this shard: testf.py::test_pass[1], testf.py::test_pass[2], testf.py::test_pass[3]

testf.py::test_pass[1] (1, 24492)
PASSED
testf.py::test_pass[2] (2, 24492)
Fatal Python error: Segmentation fault

So test 3 is not executed in this case. Arguably, this is okay. If test 2 fails with a segmentation fault, we should definitely investigate the reason and fix it.

  2. When a test makes the process dirty, the process cannot be used for later tests anymore. For example:
import pytest
@pytest.mark.parametrize('arg', [1, 2, 3])
def test_pass(arg):
    import torch
    assert not torch.cuda.is_initialized()
    data = torch.ones(10, 10).cuda()
    print(data)
    assert data.sum().item() == 100

In this example, every test needs a clean process (without CUDA initialized), but it leaves the process dirty after it finishes.

The process can also become dirty if some objects are not garbage-collected, so GPU memory is not released.
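For example, an object held by a module-level registry is never garbage-collected, so any memory it owns (GPU or otherwise) stays allocated for the rest of the process. A pure-Python sketch (the `_registry` and `Resource` names are illustrative, not vLLM internals), using `weakref` to observe when collection actually happens:

```python
import gc
import weakref

_registry = []  # a module-level registry, as engines/executors often keep

class Resource:
    """Stand-in for an object that owns memory (e.g. GPU buffers)."""
    pass

r = Resource()
_registry.append(r)       # the registry keeps the object alive
ref = weakref.ref(r)
del r
gc.collect()
assert ref() is not None  # still alive: the registry holds a reference

_registry.clear()         # only after the registry releases it...
gc.collect()
assert ref() is None      # ...is the object actually collected
```

Any test running later in the same process inherits whatever such registries still hold.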

One solution is to fork a new process for every test, i.e. run the tests with pytest --forked -s test.py. This mostly works, but with one caveat: output is not captured. Note that I added print(data) to print something, but pytest --forked discards the output. This is not friendly to developers.

An alternative solution we explored is to manually create one process for each test case:

import os
arg = int(os.environ['arg'])
def test_pass():
    import torch
    assert not torch.cuda.is_initialized()
    data = torch.ones(10, 10).cuda()
    print(data)
    assert data.sum().item() == 100
    if arg == 2:
        raise RuntimeError("invalid arg")

and use an environment variable to launch each test:

arg=1 pytest -v -s test.py
arg=2 pytest -v -s test.py
arg=3 pytest -v -s test.py

for example:

- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py

This is tedious, and it does not scale when we want to test multiple combinations of arguments.

The proposed solution is to fork manually:

import functools

def fork_new_process_for_each_test(f):
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        import os
        pid = os.fork()
        if pid == 0:
            try:
                f(*args, **kwargs)
            except Exception:
                import traceback
                traceback.print_exc()
                os._exit(1)
            else:
                os._exit(0)
        else:
            _pid, _exitcode = os.waitpid(pid, 0)
            assert _exitcode == 0, f"function {f} failed when called with args {args} and kwargs {kwargs}"
    return wrapper

import pytest
@pytest.mark.parametrize('arg', [1, 2, 3])
@fork_new_process_for_each_test
def test_pass(arg):
    import torch
    assert not torch.cuda.is_initialized()
    data = torch.ones(10, 10).cuda()
    print(data)
    assert data.sum().item() == 100
    if arg == 2:
        raise RuntimeError("invalid arg")

The output is:

================================================================== test session starts ===================================================================
platform linux -- Python 3.9.19, pytest-8.2.2, pluggy-1.5.0 -- /data/youkaichao/miniconda/envs/vllm/bin/python
cachedir: .pytest_cache
rootdir: /data/youkaichao/vllm
configfile: pyproject.toml
plugins: asyncio-0.23.7, shard-0.1.2, anyio-4.4.0, rerunfailures-14.0, forked-1.6.0
asyncio: mode=strict
collected 3 items                                                                                                                                        
Running 3 items in this shard: teste.py::test_pass[1], teste.py::test_pass[2], teste.py::test_pass[3]

teste.py::test_pass[1] tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], device='cuda:0')
PASSED
teste.py::test_pass[2] tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], device='cuda:0')
Traceback (most recent call last):
  File "/data/youkaichao/vllm/teste.py", line 10, in wrapper
    f(*args, **kwargs)
  File "/data/youkaichao/vllm/teste.py", line 32, in test_pass
    raise RuntimeError("invalid arg")
RuntimeError: invalid arg
FAILED
teste.py::test_pass[3] tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], device='cuda:0')
PASSED

======================================================================== FAILURES ========================================================================
______________________________________________________________________ test_pass[2] ______________________________________________________________________

args = (), kwargs = {'arg': 2}, os = <module 'os' from '/data/youkaichao/miniconda/envs/vllm/lib/python3.9/os.py'>, pid = 2986722, _pid = 2986722
_exitcode = 256, @py_assert2 = 0, @py_assert1 = False, @py_format4 = '256 == 0'
@py_format6 = "function <function test_pass at 0x7fc791fe9ee0> failed when called with args () and kwargs {'arg': 2}\n>assert 256 == 0"

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        import os
        pid = os.fork()
        if pid == 0:
            try:
                f(*args, **kwargs)
            except Exception:
                import traceback
                traceback.print_exc()
                os._exit(1)
            else:
                os._exit(0)
        else:
            _pid, _exitcode = os.waitpid(pid, 0)
>           assert _exitcode == 0, f"function {f} failed when called with args {args} and kwargs {kwargs}"
E           AssertionError: function <function test_pass at 0x7fc791fe9ee0> failed when called with args () and kwargs {'arg': 2}
E           assert 256 == 0

teste.py:19: AssertionError
================================================================ short test summary info =================================================================
FAILED teste.py::test_pass[2] - AssertionError: function <function test_pass at 0x7fc791fe9ee0> failed when called with args () and kwargs {'arg': 2}
============================================================== 1 failed, 2 passed in 32.03s ==============================================================

Note that:

  • every test gets a new clean process
  • the output is captured (the print statement)
  • test 2 fails, but it does not block test 3.
  • even if test 2 fails to clean up, it should not affect test 3, because each test runs in a new process and exits directly via os._exit.
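One rough edge: the `256` in the assertion message is the raw status returned by `os.waitpid`, not the child's exit code (the exit code lives in the high byte). A sketch of decoding it with the standard POSIX helpers in `os` (the `describe_wait_status` helper name is hypothetical), which would also let the wrapper report a child segfault as a signal instead of a cryptic number:

```python
import os

def describe_wait_status(status: int) -> str:
    """Decode the raw status from os.waitpid into a readable message."""
    if os.WIFEXITED(status):
        # Normal exit: the exit code is in the high byte of the status,
        # which is why a child exiting with code 1 shows up as 256.
        return f"exited with code {os.WEXITSTATUS(status)}"
    if os.WIFSIGNALED(status):
        # Killed by a signal, e.g. signal 11 (SIGSEGV) for a segfault.
        return f"killed by signal {os.WTERMSIG(status)}"
    return f"unknown status {status}"
```

The decorator's assertion message could then say `exited with code 1` or `killed by signal 11` instead of `assert 256 == 0`.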

The only thing we need to ensure is that the process is clean when we enter each test.

Of course, the perfect solution would be to handle resources elegantly, with explicit cleanup. That is very difficult once we introduce multiprocessing, ray, and asyncio. There is no ETA for perfect cleanup.
