
[Hardware] [Intel] Enable Multiprocessing and tensor parallel in CPU backend and update documentation #6125

Merged: 36 commits into vllm-project:main on Jul 26, 2024

Conversation

@bigPYJ1151 (Contributor) commented Jul 4, 2024

This PR enables vLLM multiprocessing in the CPU backend to improve async LLM engine performance and to support tensor parallelism (TP).

The main changes include:

  • Use utilities from vllm.executor.multiproc_worker_utils to manage workers in CPUExecutor.
  • Enable tensor parallel for the CPU backend.
  • Provide a new env VLLM_CPU_OMP_THREADS_BIND to simplify the OpenMP thread affinity setting, especially for multiple TP workers.
  • Transparently separate the api_server process and the async LLM engine process to avoid the impact of the GIL.
  • Update related documents.
  • Enable an E2E online serving test in CI.
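For illustration only, a launch combining the new pieces might look like the sketch below. The core ranges, the assumption that VLLM_CPU_OMP_THREADS_BIND accepts "|"-separated ranges for multiple TP workers, and the model choice are illustrative, not taken from this PR's description:

$ # hypothetical 2-way TP launch on a 64-core host; core ranges are examples only
$ export VLLM_CPU_KVCACHE_SPACE=40
$ export VLLM_CPU_OMP_THREADS_BIND="0-30|32-62"
$ python3 -m vllm.entrypoints.openai.api_server \
    --model facebook/opt-125m \
    --tensor-parallel-size 2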
PR Checklist

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR needs to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code needs to be well-documented to ensure future contributors can easily understand it.
  • Include sufficient tests to ensure the project stays correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM users understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag the PR with rfc-required and might not review it.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feels confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide a status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@bigPYJ1151 bigPYJ1151 marked this pull request as draft July 4, 2024 04:11
@bigPYJ1151 bigPYJ1151 marked this pull request as ready for review July 4, 2024 08:08
@bigPYJ1151 bigPYJ1151 changed the title [CPU] Enable Multiprocessing in CPU backend and documentation update [Hardware] [Intel] Enable Multiprocessing in CPU backend and documentation update Jul 4, 2024
@bigPYJ1151 bigPYJ1151 changed the title [Hardware] [Intel] Enable Multiprocessing in CPU backend and documentation update [Hardware] [Intel] Enable Multiprocessing in CPU backend and update documentation Jul 4, 2024
@bigPYJ1151 bigPYJ1151 mentioned this pull request Jul 5, 2024
@bigPYJ1151 bigPYJ1151 changed the title [Hardware] [Intel] Enable Multiprocessing in CPU backend and update documentation [Hardware] [Intel] Enable Multiprocessing and tensor parallel in CPU backend and update documentation Jul 16, 2024
@bigPYJ1151 (Contributor Author)

/ready

@github-actions bot added the ready label Jul 16, 2024
@bigPYJ1151 (Contributor Author)

Hi @WoosukKwon, I think this PR is ready. Would you please help review it when you have time? Thanks!

@lhtin commented Jul 18, 2024

@bigPYJ1151 Hi, do you have any plan to support pipeline parallelism in the CPU backend?

@bigPYJ1151 (Contributor Author)

@lhtin Not yet. It looks like PP has no significant benefit for CPU devices.

@WoosukKwon (Collaborator)

@bigPYJ1151 Sorry for the delay. I've been busy with some other tasks. I will review the PR in ~48 hours. Does this sound ok to you?

@bigPYJ1151 (Contributor Author)

@WoosukKwon No problem, just review it when you are available, thanks! :)

@WoosukKwon (Collaborator)

@bigPYJ1151 Thanks for your understanding!

@WoosukKwon WoosukKwon self-assigned this Jul 18, 2024
@WoosukKwon WoosukKwon self-requested a review July 22, 2024 02:44
@WoosukKwon (Collaborator) left a comment


@bigPYJ1151 Sorry for the delay, and many thanks for the PR. This is amazing!

I really like the overall direction and the effort you put into this. Just left some questions and minor suggestions for improving the clarity.

docs/source/getting_started/cpu-installation.rst (two review threads, outdated and resolved)
.. code-block:: console

$ export VLLM_CPU_KVCACHE_SPACE=40
$ export VLLM_CPU_OMP_THREADS_BIND=0-29
Collaborator:

QQ: Will the API server automatically use CPU 30 and 31?

Collaborator:

I really appreciate this simplification. However, can we further set this env variable internally in vLLM so that users don't have to care about it? Just wondering, because it's still not super easy for me.

For example, users may have the following questions:

  1. Can I use this arg to control the number of CPU cores I'd like to allocate to vLLM?
  2. How does this arg relate to the vLLM performance? Allocating more CPUs will improve the performance?

Contributor Author:

QQ: Will the API server automatically use CPU 30 and 31?

Yes, CPU 30 and 31 are reserved for non-OpenMP threads (e.g., Python threads, asyncio event loop, ...), and are leveraged by the OS scheduler automatically.

Contributor Author:

I really appreciate this simplification. However, can we further set this env variable internally in vLLM so that users don't have to care about it? Just wondering, because it's still not super easy for me.

For example, users may have the following questions:

  1. Can I use this arg to control the number of CPU cores I'd like to allocate to vLLM?
  2. How does this arg relate to the vLLM performance? Allocating more CPUs will improve the performance?

Yes, a fully automatic setting would be the best solution. It requires detecting the topology of CPU cores and memory nodes. We also want to achieve such out-of-the-box usage.

VLLM_CPU_OMP_THREADS_BIND controls the OpenMP thread behavior of model inference, including the thread count, thread affinity (pinning an inference thread to a fixed CPU core), and the memory allocation policy (allocating memory only from the closest memory node). We have added two performance tips about this arg for platforms with hyper-threading or multi-socket configurations.

For platforms without hyper-threading or multiple sockets, allocating more CPUs to model inference should theoretically improve performance.
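As a hedged illustration of what such topology detection involves (these are standard Linux tools, not something this PR adds; column availability depends on the lscpu version):

$ lscpu -e=CPU,NODE,SOCKET,CORE   # list logical CPUs with their NUMA node, socket, and physical core
$ numactl --hardware              # show memory nodes and the memory attached to each

The output can guide which core ranges to put in VLLM_CPU_OMP_THREADS_BIND, e.g. keeping each TP worker's threads on a single socket/NUMA node.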

csrc/cpu/utils.cpp (outdated, resolved)
.buildkite/run-cpu-test.sh (resolved)
.buildkite/run-cpu-test.sh (outdated, resolved)
Comment on lines 28 to 42
# online inference
docker exec cpu-test bash -c "
export VLLM_CPU_KVCACHE_SPACE=10
export VLLM_CPU_OMP_THREADS_BIND=48-92
python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m &
wget -q https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
python3 benchmarks/benchmark_serving.py \
--backend vllm \
--dataset-name sharegpt \
--dataset ./ShareGPT_V3_unfiltered_cleaned_split.json \
--model facebook/opt-125m \
--num-prompts 20 \
--endpoint /v1/completions \
--tokenizer facebook/opt-125m"
Collaborator:

I think this is a bit too much for the CI test, since it's essentially benchmarking. How long does this take?

Actually, we have a simpler and shorter online serving test in https://github.com/vllm-project/vllm/blob/main/.buildkite/run-neuron-test.sh. Can we do something similar to that instead?
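For reference, a shorter check in the spirit of that test could be a single completion request against the running server; the prompt and token count below are illustrative only and not taken from this PR:

$ curl -s http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{"model": "facebook/opt-125m", "prompt": "Hello, my name is", "max_tokens": 16}'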

@bigPYJ1151 (Contributor Author) commented Jul 22, 2024:

From the CI log, it will take ~20 seconds (~5s for service setup and ~10s for prompt processing), so I think it is acceptable.

Collaborator:

Oh yes, I just realized that num-prompts is only 20. But I think it takes time to download and process the ShareGPT dataset. What about using --random instead? This will generate random inputs.

Contributor Author:

Good hint 👍 Replaced with the random dataset.
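A hedged sketch of what the revised invocation might look like, assuming benchmark_serving.py exposes the random dataset via --dataset-name random (so the ShareGPT download and the --dataset argument can be dropped):

python3 benchmarks/benchmark_serving.py \
    --backend vllm \
    --dataset-name random \
    --model facebook/opt-125m \
    --num-prompts 20 \
    --endpoint /v1/completions \
    --tokenizer facebook/opt-125m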

vllm/worker/cpu_model_runner.py (resolved)
vllm/distributed/parallel_state.py (resolved)
vllm/envs.py (outdated, resolved)
@bigPYJ1151 (Contributor Author)

Hi @WoosukKwon, thanks for your comments! I have addressed them, please take a look. Thanks!

@zhouyuan (Contributor)

@bigPYJ1151 since the mp-based backend is enabled, shall we also update this assert:

"Distributed execution is not supported with the CPU backend.")

@WoosukKwon (Collaborator)

@bigPYJ1151 The PR looks good to me overall. Can you fix the failed CI test? Thanks!

@bigPYJ1151 (Contributor Author)

Hi @WoosukKwon, thanks for your feedback! The CPU CI is fixed. The CUDA CI should pass in general but some cases may fail due to random connection errors.

@WoosukKwon (Collaborator) left a comment

@bigPYJ1151 LGTM! Thanks a lot for the PR and for addressing my comments!

@WoosukKwon WoosukKwon merged commit 3bbb493 into vllm-project:main Jul 26, 2024
72 checks passed
cadedaniel pushed a commit to cadedaniel/vllm-public that referenced this pull request Jul 27, 2024
@DamonFool (Contributor)

Hi @bigPYJ1151, the CPU target fails to build on my machine.
Could you please take a look at #6931?
Thanks.

@@ -2,6 +2,6 @@
 -r requirements-common.txt

 # Dependencies for x86_64 CPUs
-torch == 2.3.1+cpu; platform_machine != "ppc64le"
-torchvision == 0.18.1+cpu; platform_machine != "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch
+torch == 2.4.0; platform_machine != "ppc64le"
Contributor:

Hi @bigPYJ1151 and @WoosukKwon, this change dropped the +cpu suffix from the torch version, which leads to the build failure of the CPU target.
Please take a look at #6931.
Thanks.

Contributor:

This also does something more subtle: the required version in requirements-build.txt and pyproject.toml is still 2.3.1, so building this with pip install . (i.e. PEP 517 style builds) results in a broken build:

WARNING 08-01 17:51:34 _custom_ops.py:14] Failed to import from vllm._C with ImportError('/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/_C.abi3.so: undefined symbol: torch::jit::parseSchema(std::basic_string<char, std::char_traits<char>, std::allocator<char> > const&)')
...

Querying the engine then results in an error: AttributeError: '_OpNamespace' '_C_cache_ops' object has no attribute 'reshape_and_cache'.

Full traceback follows:

(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 _custom_ops.py:39] Error in calling custom op reshape_and_cache: '_OpNamespace' '_C_cache_ops' object has no attribute 'reshape_and_cache'
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 _custom_ops.py:39] Possibly you have built or installed an obsolete version of vllm.
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 _custom_ops.py:39] Please try a clean build and install of vllm,or remove old built files such as vllm/*cpython*.so and build/ .
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226] Exception in worker VllmWorkerProcess while processing method execute_model: '_OpNamespace' '_C_cache_ops' object has no attribute 'reshape_and_cache', Traceback (most recent call last):
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/executor/multiproc_worker_utils.py", line 223, in _run_worker_process
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     output = executor(*args, **kwargs)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]              ^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/worker/worker_base.py", line 273, in execute_model
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     output = self.model_runner.execute_model(
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     return func(*args, **kwargs)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]            ^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/worker/cpu_model_runner.py", line 374, in execute_model
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     hidden_states = model_executable(**execute_model_kwargs)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/model_executor/models/opt.py", line 322, in forward
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     hidden_states = self.model(input_ids, positions, kv_caches,
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/model_executor/models/opt.py", line 291, in forward
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     return self.decoder(input_ids,
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]            ^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/model_executor/models/opt.py", line 260, in forward
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/model_executor/models/opt.py", line 162, in forward
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     hidden_states = self.self_attn(hidden_states=hidden_states,
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/model_executor/models/opt.py", line 105, in forward
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     return self._call_impl(*args, **kwargs)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     return forward_call(*args, **kwargs)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/attention/layer.py", line 97, in forward
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     return self.impl.forward(query,
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]            ^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/attention/backends/torch_sdpa.py", line 177, in forward
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     PagedAttention.write_to_paged_cache(key, value, key_cache,
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/attention/ops/paged_attn.py", line 75, in write_to_paged_cache
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     ops.reshape_and_cache(
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/_custom_ops.py", line 40, in wrapper
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     raise e
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/_custom_ops.py", line 31, in wrapper
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     return fn(*args, **kwargs)
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]            ^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/_custom_ops.py", line 425, in reshape_and_cache
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache,
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/torch/_ops.py", line 1170, in __getattr__
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]     raise AttributeError(
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226] AttributeError: '_OpNamespace' '_C_cache_ops' object has no attribute 'reshape_and_cache'
(VllmWorkerProcess pid=2204235) ERROR 08-01 17:51:47 multiproc_worker_utils.py:226]
ERROR 08-01 17:51:47 async_llm_engine.py:56] Engine background task failed
ERROR 08-01 17:51:47 async_llm_engine.py:56] Traceback (most recent call last):
ERROR 08-01 17:51:47 async_llm_engine.py:56]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/engine/async_llm_engine.py", line 46, in _log_task_completion
ERROR 08-01 17:51:47 async_llm_engine.py:56]     return_value = task.result()
ERROR 08-01 17:51:47 async_llm_engine.py:56]                    ^^^^^^^^^^^^^
ERROR 08-01 17:51:47 async_llm_engine.py:56]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/engine/async_llm_engine.py", line 641, in run_engine_loop
ERROR 08-01 17:51:47 async_llm_engine.py:56]     result = task.result()
ERROR 08-01 17:51:47 async_llm_engine.py:56]              ^^^^^^^^^^^^^
ERROR 08-01 17:51:47 async_llm_engine.py:56]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/engine/async_llm_engine.py", line 584, in engine_step
ERROR 08-01 17:51:47 async_llm_engine.py:56]     request_outputs = await self.engine.step_async(virtual_engine)
ERROR 08-01 17:51:47 async_llm_engine.py:56]                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 08-01 17:51:47 async_llm_engine.py:56]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/engine/async_llm_engine.py", line 253, in step_async
ERROR 08-01 17:51:47 async_llm_engine.py:56]     output = await self.model_executor.execute_model_async(
ERROR 08-01 17:51:47 async_llm_engine.py:56]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 08-01 17:51:47 async_llm_engine.py:56]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/executor/cpu_executor.py", line 305, in execute_model_async
ERROR 08-01 17:51:47 async_llm_engine.py:56]     output = await make_async(self.execute_model
ERROR 08-01 17:51:47 async_llm_engine.py:56]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 08-01 17:51:47 async_llm_engine.py:56]   File "/home/dtrifiro/.pyenv/versions/3.11.9/lib/python3.11/concurrent/futures/thread.py", line 58, in run
ERROR 08-01 17:51:47 async_llm_engine.py:56]     result = self.fn(*self.args, **self.kwargs)
ERROR 08-01 17:51:47 async_llm_engine.py:56]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 08-01 17:51:47 async_llm_engine.py:56]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/executor/cpu_executor.py", line 223, in execute_model
ERROR 08-01 17:51:47 async_llm_engine.py:56]     output = self.driver_method_invoker(self.driver_worker,
ERROR 08-01 17:51:47 async_llm_engine.py:56]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 08-01 17:51:47 async_llm_engine.py:56]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/executor/cpu_executor.py", line 362, in _async_driver_method_invoker
ERROR 08-01 17:51:47 async_llm_engine.py:56]     return driver.execute_method(method, *args, **kwargs).get()
ERROR 08-01 17:51:47 async_llm_engine.py:56]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 08-01 17:51:47 async_llm_engine.py:56]   File "/home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/executor/multiproc_worker_utils.py", line 58, in get
ERROR 08-01 17:51:47 async_llm_engine.py:56]     raise self.result.exception
ERROR 08-01 17:51:47 async_llm_engine.py:56] AttributeError: '_OpNamespace' '_C_cache_ops' object has no attribute 'reshape_and_cache'
Exception in callback _log_task_completion(error_callback=>)() at /home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/engine/async_llm_engine.py:36
handle: >)() at /home/dtrifiro/work/vllm/tmpvenv/lib/python3.11/site-packages/vllm/engine/async_llm_engine.py:36>
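For anyone hitting this, the error text above already points at the usual remedy (a clean rebuild against a consistent torch). A hedged sketch of such a cleanup, where the torch pin, the index URL, and the --no-build-isolation choice are assumptions for illustration rather than guidance from this thread:

$ pip uninstall -y vllm
$ rm -rf build/ vllm/*cpython*.so   # remove stale compiled extensions, as the error message suggests
$ pip install "torch==2.4.0+cpu" --extra-index-url https://download.pytorch.org/whl/cpu
$ VLLM_TARGET_DEVICE=cpu pip install . --no-build-isolation   # sidestep the pyproject-pinned torch during the build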

kylesayrs pushed a commit to neuralmagic/vllm that referenced this pull request Aug 17, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
…backend and update documentation (vllm-project#6125)

Signed-off-by: Alvant <alvasian@yandex.ru>
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
Labels: ready, x86 CPU

6 participants