
[Core] V1: Use multiprocessing by default #11074

Merged
merged 6 commits into vllm-project:main from v1-multiproc-by-default
Dec 14, 2024

Conversation

russellb
Member

@russellb russellb commented Dec 10, 2024

This PR turns on VLLM_ENABLE_V1_MULTIPROCESSING by default.

It handles the choice of Python multiprocessing start method in a way that
gives things the best chance of working, and produces clear failures when that
isn't possible. This topic is discussed in detail in the included design doc.

https://github.com/russellb/vllm/blob/v1-multiproc-by-default/docs/source/design/multiprocessing.md
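
For anyone who wants to opt back out after this change, the flag can still be set explicitly before vLLM is imported. A minimal sketch, assuming the usual 0/1 integer convention vLLM uses for boolean environment variables:

import os

# Opt out of the new default and keep the V1 engine in-process.
# (Assumes VLLM_ENABLE_V1_MULTIPROCESSING follows vLLM's 0/1 convention.)
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"

from vllm import LLM  # import vllm only after the environment is set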


b30dad7 docs: Explain vllm's use of Python multiprocessing
39553b9 v1: use multiprocessing by default

commit b30dad7
Author: Russell Bryant rbryant@redhat.com
Date: Tue Dec 10 17:01:13 2024 +0000

docs: Explain vllm's use of Python multiprocessing

vLLM uses Python's `multiprocessing` library, but its use is
complicated by vLLM also being used as a library and by compatibility
issues with vLLM's dependencies.

This design doc:

- provides context for the topic
- reviews the current state of how the multiprocessing method is handled
- proposes next steps
- discusses alternatives considered
- lists possible future work

Signed-off-by: Russell Bryant <rbryant@redhat.com>

commit 39553b9
Author: Russell Bryant rbryant@redhat.com
Date: Tue Dec 10 20:33:23 2024 +0000

v1: use multiprocessing by default

Previously, this code forced the use of the `spawn` multiprocessing
method. Since we know this causes problems in some configurations,
multiprocessing was off by default.

This change turns it on by default and makes use of existing code that
tries to choose the best multiprocessing method based on what we can
detect.

- use `fork` by default
- use `spawn` if CUDA has already been initialized, but give a warning

This same logic is already in use for spawning multiple workers for v1
tensor parallelism support.

The design doc `docs/design/multiprocessing.md` covers this topic in
more detail.

Signed-off-by: Russell Bryant <rbryant@redhat.com>
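
As a rough illustration of the selection logic this commit message describes (a minimal sketch with simplified names, not the actual vLLM code):

import multiprocessing
import os

import torch


def choose_mp_start_method() -> str:
    """Prefer `fork`, but fall back to `spawn` if CUDA is already initialized."""
    if torch.cuda.is_initialized():
        # Forking after CUDA initialization breaks the CUDA context in the
        # child, so fall back to `spawn` and warn, since `spawn` re-imports
        # the main module and therefore needs a __main__ guard.
        print("WARNING: CUDA was already initialized; falling back to 'spawn'. "
              "Guard your entry point with `if __name__ == '__main__':`.")
        return "spawn"
    return os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", "fork")


if __name__ == "__main__":
    ctx = multiprocessing.get_context(choose_mp_start_method())
    proc = ctx.Process(target=print, args=("hello from the engine process",))
    proc.start()
    proc.join()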


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of tests to quickly catch errors. You can run the other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@mergify mergify bot added the documentation Improvements or additions to documentation label Dec 10, 2024
@russellb russellb force-pushed the v1-multiproc-by-default branch from 39553b9 to 9d0bc9b on December 10, 2024 at 22:16
@russellb
Member Author

The case that's failing for me is running a simple offline inference script:

if __name__ == "__main__":

    from vllm import LLM, SamplingParams

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct")

    outputs = llm.generate(prompts, sampling_params)

    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

With V1 + multiprocessing enabled, this script will hang at the end and not exit until you interrupt it with Ctrl-c. The script exits cleanly on its own without multiprocessing enabled, or when using V0.

I'm still digging into the root cause.

@russellb russellb force-pushed the v1-multiproc-by-default branch from 9d0bc9b to be8449c on December 10, 2024 at 22:27
@russellb
Member Author

The bug I discussed above is fixed in #11076

This PR should be ready once #11076 goes in. I'll take this out of Draft status after that.

@robertgshaw2-redhat
Collaborator

robertgshaw2-redhat commented Dec 11, 2024

One (weird) failure case where torch.cuda.is_initialized() returns False, but when we fork, the CUDA context complains

import torch
# If I run this line before importing vllm, we will fail with the forked-CUDA-context error
print(f"{torch.cuda.is_available()=}")
# This reports False (this is the function we use internally to detect failure)
print(f"{torch.cuda.is_initialized()=}")

from vllm import LLM

model = LLM(model="Qwen/Qwen2-0.5B-Instruct",
            max_model_len=2048,
            enforce_eager=True)

out = model.generate("hi my name is")
print(out)

@russellb
Member Author

One (weird) failure case where torch.cuda.is_initialized() returns False, but when we fork, the CUDA context complains

Cool, thanks for the reproducer. I'll take a look.

@@ -0,0 +1,185 @@
# Python Multiprocessing

*Note that source code references are to the state of the code at the time of writing in December, 2024.*
Collaborator

I think we should add a TLDR at the top which says if you encounter the following error:

RuntimeError:
        An attempt has been made to start a new process before the
        current process has finished its bootstrapping phase.
        This probably means that you are not using fork to start your
        child processes and you have forgotten to use the proper idiom
        in the main module:
            if __name__ == '__main__':
                freeze_support()
                ...
        The "freeze_support()" line can be omitted if the program
        is not going to be frozen to produce an executable.
        To fix this issue, refer to the "Safe importing of main module"
        section in https://docs.python.org/3/library/multiprocessing.html

Then guard your import with if __name__ == "__main__"
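
For reference, the suggested guard looks like this (a minimal illustration, mirroring the offline-inference script earlier in this thread):

if __name__ == "__main__":
    # The vllm import and everything that constructs the LLM stay behind the
    # guard, so the module can be safely re-imported by a spawned child
    # process without re-running this code.
    from vllm import LLM, SamplingParams

    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct")
    params = SamplingParams(temperature=0.8, top_p=0.95)
    for output in llm.generate(["Hello, my name is"], params):
        print(output.outputs[0].text)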

Member Author

I updated the user-oriented "Debugging Tips" page with a new section on Python multiprocessing. It lists the errors / warnings someone might see and gives instructions on how to use the main guard to solve them.

I updated the dev-oriented design doc with a link at the top over to the debugging page, as well.

def _check_multiproc_method():
    if (cuda_is_initialized()
            and os.environ.get("VLLM_WORKER_MULTIPROC_METHOD") != "spawn"):
        logger.warning("CUDA was previously initialized. We must use "
Collaborator

I think this should link back to the doc telling people to guard the import with if __name__ == "__main__"

Member Author

Done -- added a link to the appropriate section of the debugging tips page

Collaborator

@robertgshaw2-redhat robertgshaw2-redhat left a comment

Generally, this looks good; I like the design doc. I think we should link back to the source code.

My only concern is that, because we have multiprocessing, the exception from the engine process is not immediately visible in the main process. This is a problem for all exceptions raised during startup:

(venv) rshaw@beaker:~/vllm$ python3 test.py 
here!
torch.cuda.is_available()=True
torch.cuda.is_initialized()=False
torch.cuda.is_initialized()=True
WARNING 12-11 01:04:14 arg_utils.py:1232] Setting max_num_batched_tokens to 8192 for LLM_CLASS usage context.
INFO 12-11 01:04:22 config.py:405] This model supports multiple tasks: {'generate', 'embedding'}. Defaulting to 'generate'.
WARNING 12-11 01:04:22 cuda.py:98] To see benefits of async output processing, enable CUDA graph. Since, enforce-eager is enabled, async output processor cannot be used
WARNING 12-11 01:04:22 config.py:517] Async output processing is not supported on the current platform type cuda.
WARNING 12-11 01:04:24 multiproc_worker_utils.py:280] CUDA was previously initialized. We must use the `spawn` multiprocessing start method by setting VLLM_WORKER_MULTIPROC_METHOD to 'spawn'. This can cause issues if you do not guarding with __name__ == '__main__' See [TODO: LINK TO DOC] for workarounds.
here!
torch.cuda.is_available()=True
torch.cuda.is_initialized()=False
torch.cuda.is_initialized()=True
WARNING 12-11 01:04:31 arg_utils.py:1232] Setting max_num_batched_tokens to 8192 for LLM_CLASS usage context.
INFO 12-11 01:04:38 config.py:405] This model supports multiple tasks: {'embedding', 'generate'}. Defaulting to 'generate'.
WARNING 12-11 01:04:38 cuda.py:98] To see benefits of async output processing, enable CUDA graph. Since, enforce-eager is enabled, async output processor cannot be used
WARNING 12-11 01:04:38 config.py:517] Async output processing is not supported on the current platform type cuda.
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/rshaw/.pyenv/versions/3.12.4/lib/python3.12/multiprocessing/spawn.py", line 122, in spawn_main
    exitcode = _main(fd, parent_sentinel)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rshaw/.pyenv/versions/3.12.4/lib/python3.12/multiprocessing/spawn.py", line 131, in _main
    prepare(preparation_data)
  File "/home/rshaw/.pyenv/versions/3.12.4/lib/python3.12/multiprocessing/spawn.py", line 246, in prepare
    _fixup_main_from_path(data['init_main_from_path'])
  File "/home/rshaw/.pyenv/versions/3.12.4/lib/python3.12/multiprocessing/spawn.py", line 297, in _fixup_main_from_path
    main_content = runpy.run_path(main_path,
                   ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<frozen runpy>", line 286, in run_path
  File "<frozen runpy>", line 98, in _run_module_code
  File "<frozen runpy>", line 88, in _run_code
  File "/home/rshaw/vllm/test.py", line 11, in <module>
    model = LLM(model="Qwen/Qwen2-0.5B-Instruct",
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rshaw/vllm/vllm/utils.py", line 1057, in inner
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/rshaw/vllm/vllm/entrypoints/llm.py", line 228, in __init__
    self.llm_engine = self.engine_class.from_engine_args(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rshaw/vllm/vllm/v1/engine/llm_engine.py", line 98, in from_engine_args
    return cls(vllm_config=vllm_config,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rshaw/vllm/vllm/v1/engine/llm_engine.py", line 71, in __init__
    self.engine_core = EngineCoreClient.make_client(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rshaw/vllm/vllm/v1/engine/core_client.py", line 48, in make_client
    return SyncMPClient(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rshaw/vllm/vllm/v1/engine/core_client.py", line 184, in __init__
    super().__init__(*args, asyncio_mode=False, **kwargs)
  File "/home/rshaw/vllm/vllm/v1/engine/core_client.py", line 152, in __init__
    self.proc = EngineCoreProc.make_engine_core_process(
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rshaw/vllm/vllm/v1/engine/core.py", line 213, in make_engine_core_process
    proc.start()
  File "/home/rshaw/.pyenv/versions/3.12.4/lib/python3.12/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
                  ^^^^^^^^^^^^^^^^^
  File "/home/rshaw/.pyenv/versions/3.12.4/lib/python3.12/multiprocessing/context.py", line 289, in _Popen
    return Popen(process_obj)
           ^^^^^^^^^^^^^^^^^^
  File "/home/rshaw/.pyenv/versions/3.12.4/lib/python3.12/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/home/rshaw/.pyenv/versions/3.12.4/lib/python3.12/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/home/rshaw/.pyenv/versions/3.12.4/lib/python3.12/multiprocessing/popen_spawn_posix.py", line 42, in _launch
    prep_data = spawn.get_preparation_data(process_obj._name)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rshaw/.pyenv/versions/3.12.4/lib/python3.12/multiprocessing/spawn.py", line 164, in get_preparation_data
    _check_not_importing_main()
  File "/home/rshaw/.pyenv/versions/3.12.4/lib/python3.12/multiprocessing/spawn.py", line 140, in _check_not_importing_main
    raise RuntimeError('''
RuntimeError: 
        An attempt has been made to start a new process before the
        current process has finished its bootstrapping phase.

        This probably means that you are not using fork to start your
        child processes and you have forgotten to use the proper idiom
        in the main module:

            if __name__ == '__main__':
                freeze_support()
                ...

        The "freeze_support()" line can be omitted if the program
        is not going to be frozen to produce an executable.

        To fix this issue, refer to the "Safe importing of main module"
        section in https://docs.python.org/3/library/multiprocessing.html
        
Traceback (most recent call last):
  File "/home/rshaw/vllm/test.py", line 11, in <module>
    model = LLM(model="Qwen/Qwen2-0.5B-Instruct",
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rshaw/vllm/vllm/utils.py", line 1057, in inner
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/rshaw/vllm/vllm/entrypoints/llm.py", line 228, in __init__
    self.llm_engine = self.engine_class.from_engine_args(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rshaw/vllm/vllm/v1/engine/llm_engine.py", line 98, in from_engine_args
    return cls(vllm_config=vllm_config,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rshaw/vllm/vllm/v1/engine/llm_engine.py", line 71, in __init__
    self.engine_core = EngineCoreClient.make_client(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rshaw/vllm/vllm/v1/engine/core_client.py", line 48, in make_client
    return SyncMPClient(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rshaw/vllm/vllm/v1/engine/core_client.py", line 184, in __init__
    super().__init__(*args, asyncio_mode=False, **kwargs)
  File "/home/rshaw/vllm/vllm/v1/engine/core_client.py", line 152, in __init__
    self.proc = EngineCoreProc.make_engine_core_process(
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/rshaw/vllm/vllm/v1/engine/core.py", line 216, in make_engine_core_process
    EngineCoreProc.wait_for_startup(proc, ready_path)
  File "/home/rshaw/vllm/vllm/v1/engine/core.py", line 184, in wait_for_startup
    raise e
  File "/home/rshaw/vllm/vllm/v1/engine/core.py", line 178, in wait_for_startup
    raise RuntimeError("EngineCoreProc failed to start.")
RuntimeError: EngineCoreProc failed to start.

Out of scope for this PR, but maybe something we should consider looking into (generally the strategy for propagating errors).
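
As one illustration of what such a strategy could look like (a hedged sketch, not what vLLM does today): the child catches startup exceptions and sends the formatted traceback back to the parent over a pipe, so the parent can re-raise something more informative than a generic "failed to start".

import multiprocessing as mp
import traceback


def _engine_main(conn, should_fail):
    # Stand-in for engine-core startup work.
    try:
        if should_fail:
            raise RuntimeError("engine failed during startup")
    except Exception:
        # Send the formatted traceback to the parent before dying, so the
        # real error is visible there, not just a generic failure message.
        conn.send(traceback.format_exc())
        raise
    conn.send(None)  # startup succeeded


if __name__ == "__main__":
    parent_conn, child_conn = mp.Pipe()
    proc = mp.Process(target=_engine_main, args=(child_conn, True))
    proc.start()
    err = parent_conn.recv()
    proc.join()
    if err is not None:
        raise RuntimeError(f"Engine process failed to start:\n{err}")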

@russellb russellb force-pushed the v1-multiproc-by-default branch from be8449c to 74f3392 on December 11, 2024 at 14:02
@russellb
Member Author

The bug I discussed above is fixed in #11076

This PR should be ready once #11076 goes in. I'll take this out of Draft status after that.

That PR has merged and I've rebased this one on main, so I'm going to drop the draft status.

I still need to address Rob's feedback. I'll be doing that shortly.

@russellb russellb marked this pull request as ready for review December 11, 2024 14:03
@robertgshaw2-redhat robertgshaw2-redhat enabled auto-merge (squash) December 11, 2024 15:59
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Dec 11, 2024
@russellb
Member Author

One (weird) failure case where torch.cuda.is_initialized() returns False, but when we fork, the CUDA context complains

Cool, thanks for the reproducer. I'll take a look.

For the record, I haven't been able to come up with anything that handles this case any better.

As of this PR, at least if someone goes to docs.vllm.ai and searches for the Python RuntimeError, they should find the updated docs that explain what should be changed. Hopefully this doesn't happen often! Perhaps if we see enough complaints, we can keep digging (or eventually rip out our use of multiprocessing and replace it with something else that avoids all of this!)

Member

@mgoin mgoin left a comment

Very nice work with the documentation. I agree with the tradeoffs made

Member

@ywang96 ywang96 left a comment

This is exciting! I left a comment; please take a look!


mergify bot commented Dec 12, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @russellb.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 12, 2024
auto-merge was automatically disabled December 12, 2024 14:06

Head branch was pushed to by a user without write access

@WoosukKwon
Collaborator

@russellb Amazing! Is this PR ready to go?

@russellb
Member Author

@russellb Amazing! Is this PR ready to go?

It is to me - up to you all, though. :-)

@WoosukKwon
Collaborator

It is to me - up to you all, though. :-)

@russellb Can you take a look at the failed tests? Thanks!

@russellb
Member Author

It is to me - up to you all, though. :-)

@russellb Can you take a look at the failed tests? Thanks!

They didn't obviously look related to me, so I'll see if I can re-run them just in case (I have access to do that now!)

@russellb
Member Author

It is to me - up to you all, though. :-)

@russellb Can you take a look at the failed tests? Thanks!

They didn't obviously look related to me, so I'll see if I can re-run them just in case (I have access to do that now!)

The test failures are real and related to V1 behavior with multiprocessing turned on. The PR should be held until we sort it out.

Add info to the debugging tips doc on how to update the code to use a
`__main__` guard to avoid conflicts with vllm's use of `spawn`.

Update the log message to include a link to the debugging tips page.

Update the more detailed design doc with a reference to the debugging
tips page, as well.

Signed-off-by: Russell Bryant <rbryant@redhat.com>
The wording about the state of v1 was changed to past tense, since
once the PR is merged, none of it will be current state anymore.

Signed-off-by: Russell Bryant <rbryant@redhat.com>
It's possible that code doing offline inference may create an `LLM`
object, delete it, and create a new one within the same process. In
that case we need to be very explicit about ensuring the right cleanup
tasks occur rather than assuming garbage collection will result in the
right things at the right time.

The primary fix is that when `del` is called on an `LLM` object, we
explicitly call `shutdown()` on the engine if it has a shutdown
method. This ensures the v1 shutdown code runs.  I also needed to
propogate the shutdown through the v1 LLMEngine.

The v1 engine client shutdown handling needed some changes, as well:

- No longer rely on `atexit` exclusively. Unregister our atexit
  callback if shutdown occurs prior to the process exiting.

- Harden the shutdown code for when it runs in late stages of process
  shutdown. I observed cases when the final garbage collection in the
  process triggered this code, that key imports were no longer
  accessible (`atexit` and `os`), so don't try those operations if
  they are None.

Signed-off-by: Russell Bryant <rbryant@redhat.com>
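
A minimal sketch of the shutdown pattern described above (simplified, with hypothetical class names, not the actual vLLM code):

import atexit


class EngineClient:
    def __init__(self):
        self._closed = False
        # Make sure cleanup runs even if the user never calls shutdown().
        atexit.register(self.shutdown)

    def shutdown(self):
        # May be reached via an explicit call, LLM.__del__, or the atexit
        # hook -- possibly during late interpreter shutdown, when module
        # globals such as `atexit` may already have been cleared to None.
        if atexit:
            atexit.unregister(self.shutdown)
        if self._closed:
            return
        self._closed = True
        # ... terminate the engine core process, close sockets, etc.


class LLM:
    def __init__(self):
        self.llm_engine = EngineClient()

    def __del__(self):
        engine = getattr(self, "llm_engine", None)
        if engine is not None and hasattr(engine, "shutdown"):
            engine.shutdown()


llm = LLM()
del llm  # shutdown() runs now, not at interpreter exit
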
@russellb russellb force-pushed the v1-multiproc-by-default branch from fe0d80f to 36080d8 on December 13, 2024 at 21:07
@mergify mergify bot added the frontend label Dec 13, 2024
shutdown may be called explicitly if someone does `del` on their `LLM`
object. In that case, we should unregister our atexit handler. I also
adjusted the code to handle the case where shutdown() gets called in the
late stages of process shutdown, when globals have been replaced by None.

Signed-off-by: Russell Bryant <rbryant@redhat.com>
Comment on lines +164 to +166
if atexit:
    # in case shutdown gets called via __del__ first
    atexit.unregister(self.shutdown)
Collaborator

What is the behavior when we don't do this?

Member Author

  • Best case: self.shutdown() gets called more than once, but we handle it gracefully.
  • If we're doing offline inference and someone creates / deletes LLM instances multiple times, these registered handlers will prevent garbage collection of the objects. Our explicit cleanup calls should still make things work, but it's effectively a memory leak (see the small demo below).
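
To illustrate the second point (a small self-contained demo, not vLLM code): an atexit-registered bound method holds a strong reference to its object, so the object cannot be garbage collected until the callback is unregistered.

import atexit
import gc
import weakref


class Handle:
    def close(self):
        pass


h = Handle()
cb = h.close                 # bound method; holds a strong reference to h
atexit.register(cb)
ref = weakref.ref(h)
del h
gc.collect()
print(ref() is None)         # False: the atexit registry keeps h alive via cb

atexit.unregister(cb)
del cb
gc.collect()
print(ref() is None)         # True: h can be collected once unregistered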

Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

PR looks great, and it works for me. Thanks!

@russellb
Member Author

It is to me - up to you all, though. :-)

@russellb Can you take a look at the failed tests? Thanks!

They didn't obviously look related to me, so I'll see if I can re-run them just in case (I have access to do that now!)

The test failures are real and related to V1 behavior with multiprocessing turned on. The PR should be held until we sort it out.

I think I've got the cleanup issues sorted out. CI is green now!

@comaniac comaniac merged commit 4863e5f into vllm-project:main Dec 14, 2024
55 checks passed
BKitor pushed a commit to BKitor/vllm that referenced this pull request Dec 30, 2024
Signed-off-by: Russell Bryant <rbryant@redhat.com>