
[Feature] Prototype of vLLM execution on CPU-only devices #1028

Closed

Conversation

@bigPYJ1151 (Contributor) commented Sep 13, 2023

Hi, vLLM geniuses @WoosukKwon @zhuohan123. Motivated by some requirements to execute vLLM on the CPU (e.g., #176), we recently implemented an initial prototype for CPU-only execution on the x86 CPU platform.

What we have right now:

  • Minimized changes to vLLM core components to support CPU execution, including:
    • Introduced a new configuration argument device ('cuda' or 'cpu', 'cuda' by default) to specify the main device on which vLLM executes.
    • Replaced hard-coded device assignments (e.g., .cuda()) with .to(device=device), or with the set_default_device context, to support vLLM execution on different device types (a small sketch follows this list).
    • Modified CacheEngine to allocate blocks from the CPU cache tensor (originally used for swapping) in CPU-only mode. The size of the CPU cache can be specified with --swap-space.
  • Support for the FP32 and BF16 data types.
  • Native operators implemented for x86 CPUs using the AVX512_BF16 instruction set.
  • An operator dispatcher based on the device type of the input tensors.
  • Integration with the existing build script:
    • Building the CPU operators is controlled by the environment variable VLLM_BUILD_CPU_OPS, which is disabled by default.
    • Due to CUDA compatibility constraints, the CPU operators can only be built with gcc-12 and g++-12 to support the AVX512_BF16 instruction set.
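
As a rough illustration of the device-agnostic placement described above (a minimal sketch, not the actual PR code; device stands in for the new configuration argument):

    import torch

    device = torch.device("cpu")  # value of the new `device` config argument ("cuda" or "cpu")

    # Instead of a hard-coded .cuda(), tensors follow the configured device:
    x = torch.empty(16, 4096, dtype=torch.bfloat16).to(device=device)

    # Or make it the default for subsequent allocations:
    torch.set_default_device(device)
    y = torch.zeros(16, 4096)  # allocated on `device`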

Install Instructions

  • Make sure the default version of gcc/g++ is 12 (see the example after this list)
  • Install PyTorch with pip install torch==2.1.2+cpu --index-url https://download.pytorch.org/whl/cpu
  • Build the source with VLLM_BUILD_CPU_ONLY=1 MAX_JOBS=8 pip install --no-build-isolation -v -e .
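
If gcc-12/g++-12 are installed but not the system default, one way to switch the default on Debian/Ubuntu-style systems is update-alternatives (an illustrative example, not part of this PR):

sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 100 --slave /usr/bin/g++ g++ /usr/bin/g++-12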

Known Limits:

  • Tensor parallelism is not supported right now.
  • FP16 is not fully supported right now due to instruction set limitations.
  • Quantization is not supported right now.
  • Sliding window attention is not verified right now.

Model Support:

  • We have only verified LlamaForCausalLM, MistralForCausalLM, and OPTForCausalLM related models so far.
  • Ideally, this implementation can support all implemented models without modifying the model definitions.

Performance
We used the following commands to evaluate the performance with vicuna-7b-v1.5 on an Intel(R) Xeon(R) CPU Max 9462 platform with 32 physical cores:

OMP_NUM_THREADS=32 numactl --physcpubind=0-31 --membind=0 python benchmark_throughput.py --backend=vllm --dataset=/root/ShareGPT_V3_unfiltered_cleaned_split.json --model=/root/vicuna-7b-v1.5/ --n=1 --num-prompts=1000 --dtype=bfloat16 --trust-remote-code --device=cpu --swap-space=40

The implementation achieved good throughput on the CPU platform:
Throughput: 0.76 requests/s, 358.22 tokens/s
Throughput: 1.00 requests/s, 479.15 tokens/s

There is still much room for performance improvement; we will continue to optimize performance and add the remaining features, hoping this is helpful for users who want to deploy vLLM on the CPU.

Please help review the code; any feedback is welcome, thanks!

@JasmondL (Contributor)

Hi, do you intend to expose this feature in api_server.py? (from here)

@maktukmak

Hi, is this feature going to be merged soon? Or is there any problem preventing that? Please let me know the current status.

@bigPYJ1151 bigPYJ1151 force-pushed the PR_Branch branch 5 times, most recently from 0286d57 to 960ef34 Compare November 29, 2023 07:56
@bigPYJ1151 bigPYJ1151 marked this pull request as ready for review November 29, 2023 09:27
@bigPYJ1151 bigPYJ1151 changed the title from "Prototype of CPU-only execution" to "[Feature] Prototype of vLLM execution on CPU-only devices" Nov 30, 2023
@derekelewis commented Nov 30, 2023

@bigPYJ1151 This is great - I tried running the benchmark with llama-2-7b with no issues; however, with mistral-7b, I am getting an assertion error when running the benchmark:

python benchmark_throughput.py --backend=vllm --dataset=ShareGPT_V3_unfiltered_cleaned_split.json --model=mistralai/Mistral-7B-v0.1 --n=1 --num-prompts=1000 --dtype=bfloat16 --trust-remote-code --device=cpu

I did notice that the number of CPU blocks is 2048 for mistral-7b, but 512 for llama-2-7b and vicuna-7b-v1.5.

INFO 11-30 05:13:56 llm_engine.py:219] # GPU blocks: 0, # CPU blocks: 2048
Processed prompts:   0%|                               | 0/1000 [00:00<?, ?it/s]Traceback (most recent call last):
  File "/home/ec2-user/vllm/benchmarks/benchmark_throughput.py", line 314, in <module>
    main(args)
  File "/home/ec2-user/vllm/benchmarks/benchmark_throughput.py", line 200, in main
    elapsed_time = run_vllm(requests, args.model, args.tokenizer,
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/vllm/benchmarks/benchmark_throughput.py", line 107, in run_vllm
    llm._run_engine(use_tqdm=True)
  File "/home/ec2-user/miniconda3/envs/vllmenv/lib/python3.11/site-packages/vllm-0.2.2+cu123-py3.11-linux-x86_64.egg/vllm/entrypoints/llm.py", line 173, in _run_engine
    step_outputs = self.llm_engine.step()
                   ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/miniconda3/envs/vllmenv/lib/python3.11/site-packages/vllm-0.2.2+cu123-py3.11-linux-x86_64.egg/vllm/engine/llm_engine.py", line 575, in step
    output = self._run_workers(
             ^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/miniconda3/envs/vllmenv/lib/python3.11/site-packages/vllm-0.2.2+cu123-py3.11-linux-x86_64.egg/vllm/engine/llm_engine.py", line 738, in _run_workers
    self._run_workers_in_batch(workers, method, *args, **kwargs))
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/miniconda3/envs/vllmenv/lib/python3.11/site-packages/vllm-0.2.2+cu123-py3.11-linux-x86_64.egg/vllm/engine/llm_engine.py", line 712, in _run_workers_in_batch
    output = executor(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/miniconda3/envs/vllmenv/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/miniconda3/envs/vllmenv/lib/python3.11/site-packages/vllm-0.2.2+cu123-py3.11-linux-x86_64.egg/vllm/worker/worker.py", line 389, in execute_model
    output = self.model(
             ^^^^^^^^^^^
  File "/home/ec2-user/miniconda3/envs/vllmenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/miniconda3/envs/vllmenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/miniconda3/envs/vllmenv/lib/python3.11/site-packages/vllm-0.2.2+cu123-py3.11-linux-x86_64.egg/vllm/model_executor/models/mistral.py", line 283, in forward
    hidden_states = self.model(input_ids, positions, kv_caches,
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/miniconda3/envs/vllmenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/miniconda3/envs/vllmenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/miniconda3/envs/vllmenv/lib/python3.11/site-packages/vllm-0.2.2+cu123-py3.11-linux-x86_64.egg/vllm/model_executor/models/mistral.py", line 249, in forward
    hidden_states, residual = layer(
                              ^^^^^^
  File "/home/ec2-user/miniconda3/envs/vllmenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/miniconda3/envs/vllmenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/miniconda3/envs/vllmenv/lib/python3.11/site-packages/vllm-0.2.2+cu123-py3.11-linux-x86_64.egg/vllm/model_executor/models/mistral.py", line 199, in forward
    hidden_states = self.self_attn(
                    ^^^^^^^^^^^^^^^
  File "/home/ec2-user/miniconda3/envs/vllmenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/miniconda3/envs/vllmenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/miniconda3/envs/vllmenv/lib/python3.11/site-packages/vllm-0.2.2+cu123-py3.11-linux-x86_64.egg/vllm/model_executor/models/mistral.py", line 147, in forward
    attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/miniconda3/envs/vllmenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/miniconda3/envs/vllmenv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/miniconda3/envs/vllmenv/lib/python3.11/site-packages/vllm-0.2.2+cu123-py3.11-linux-x86_64.egg/vllm/model_executor/layers/attention.py", line 370, in forward
    return super().forward(
           ^^^^^^^^^^^^^^^^
  File "/home/ec2-user/miniconda3/envs/vllmenv/lib/python3.11/site-packages/vllm-0.2.2+cu123-py3.11-linux-x86_64.egg/vllm/model_executor/layers/attention.py", line 264, in forward
    self.set_attn_bias(input_metadata, dtype=query.dtype)
  File "/home/ec2-user/miniconda3/envs/vllmenv/lib/python3.11/site-packages/vllm-0.2.2+cu123-py3.11-linux-x86_64.egg/vllm/model_executor/layers/attention.py", line 76, in set_attn_bias
    assert not self.cpu_only
AssertionError
double free or corruption (!prev)
Aborted (core dumped)

@bigPYJ1151 (Contributor Author)

@derekelewis Thanks for your attention! Mistral uses sliding window attention, which we haven't adapted and verified yet. Removing the assert statement in set_attn_bias may make the code runnable.

The CPU cache size can be specified with --swap-space, in units of GB.
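
For example, appending it to the benchmark command above reserves a 40 GB CPU cache:

python benchmark_throughput.py --backend=vllm --dataset=ShareGPT_V3_unfiltered_cleaned_split.json --model=mistralai/Mistral-7B-v0.1 --n=1 --num-prompts=1000 --dtype=bfloat16 --trust-remote-code --device=cpu --swap-space=40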

@derekelewis

@bigPYJ1151 Again, thanks for the contribution; that was helpful. Still exciting that this works with llama-2 models. FYI, I did remove the assert and got this with mistral-7b:

File "/home/ec2-user/miniconda3/envs/vllmenv/lib/python3.11/site-packages/vllm-0.2.2+cu123-py3.11-linux-x86_64.egg/vllm/model_executor/layers/attention.py", line 265, in forward
    self.multi_query_kv_attention(
  File "/home/ec2-user/miniconda3/envs/vllmenv/lib/python3.11/site-packages/vllm-0.2.2+cu123-py3.11-linux-x86_64.egg/vllm/model_executor/layers/attention.py", line 124, in multi_query_kv_attention
    ) if not self.cpu_only else torch.nn.functional.scaled_dot_product_attention(
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ec2-user/miniconda3/envs/vllmenv/lib/python3.11/site-packages/torch/utils/_device.py", line 77, in __torch_function__
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
RuntimeError: The size of tensor a (4) must match the size of tensor b (747) at non-singleton dimension 3
double free or corruption (!prev)
Aborted (core dumped)

@Deepansharora27 commented Dec 14, 2023

I have been getting the following error, which says it needs CUDA for the build process to succeed:

[Screenshot attached]

Any workarounds for this?

@Deepansharora27 commented Dec 14, 2023

Also, if I try to run the benchmark script itself, it does not recognize the --device argument:

[Screenshot attached]

@Deepansharora27

Another question I have: is this prototype just for benchmarking on a CPU-based device, or can we build your repo from source and perform actual inference with a Llama model? @bigPYJ1151

@bigPYJ1151 (Contributor Author)

@Deepansharora27 Thanks for your attention!

  • For the first one, we have updated the branch to support building without the CUDA runtime, by setting the environment variable VLLM_BUILD_CPU_ONLY=1.
  • For the second, it seems you are not using this PR's branch; please check that.
  • For the last one, we haven't exposed the related arguments in the vLLM serving scripts yet, but you can actually use vLLM to do model serving on the CPU with this PR (a rough sketch of offline inference follows).
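
A rough sketch of offline inference on the CPU with this branch (the model path and prompt are placeholders; the arguments mirror the benchmark command above and the LLM configuration shown later in this thread):

    from vllm import LLM, SamplingParams

    llm = LLM(
        model="/root/vicuna-7b-v1.5/",  # placeholder model path
        dtype="bfloat16",               # use "float32" if AVX512-BF16 is unavailable
        device="cpu",
        swap_space=40,                  # CPU cache size in GB
        enforce_eager=True,
    )
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.8, max_tokens=64))
    print(outputs[0].outputs[0].text)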

@Deepansharora27 commented Dec 19, 2023

@bigPYJ1151 Hi, I have been trying to build this repo with your suggestions above, but it still fails. Here is the error trace I am getting:

`/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp: In constructor ‘vec_op::FP32Vec8::FP32Vec8(__m128bh)’:
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp:151:39: error: ‘_mm256_cvtpbh_ps’ was not declared in this scope; did you mean ‘_mm256_cvtph_ps’?
151 | explicit FP32Vec8(__m128bh v) : reg(_mm256_cvtpbh_ps(v)) {}
| ^~~~~~~~~~~~~~~~
| _mm256_cvtph_ps
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp: In constructor ‘vec_op::FP32Vec16::FP32Vec16(__m256bh)’:
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp:207:40: error: ‘_mm512_cvtpbh_ps’ was not declared in this scope; did you mean ‘_mm512_cvtph_ps’?
207 | explicit FP32Vec16(__m256bh v) : reg(_mm512_cvtpbh_ps(v)) {}
| ^~~~~~~~~~~~~~~~
| _mm512_cvtph_ps
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp: In function ‘void vec_op::storeFP32ToT(float, T*) [with T = c10::BFloat16]’:
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp:268:21: error: ‘__bfloat16’ does not name a type; did you mean ‘__float80’?
268 | *reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v);
| ^~~~~~~~~~
| __float80
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp:268:32: error: expected ‘>’ before ‘*’ token
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp:268:32: error: expected ‘(’ before ‘*’ token
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp:268:33: error: expected primary-expression before ‘>’ token
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp:268:42: error: ‘_mm_cvtness_sbh’ was not declared in this scope; did you mean ‘_mm_cvtneps_pbh’?
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp:268:60: error: expected ‘)’ before ‘;’ token
/home/deepanshu/ai-tooling/vllm/csrc/cpu/pos_encoding_impl.cpp: In instantiation of ‘void {anonymous}::rotary_embedding_impl(const int64_t*, scalar_t*, scalar_t*, const scalar_t*, int, int, int, int, int, int, int) [with scalar_t = float; int64_t = long int]’:
/home/deepanshu/ai-tooling/vllm/csrc/cpu/pos_encoding_impl.cpp:107:3: required from here
/home/deepanshu/ai-tooling/vllm/csrc/cpu/pos_encoding_impl.cpp:22:17: warning: unused variable ‘ELEM_SIZE’ [-Wunused-variable]
22 | constexpr int ELEM_SIZE = sizeof(scalar_t);
| ^~~~~~~~~
/home/deepanshu/ai-tooling/vllm/csrc/cpu/pos_encoding_impl.cpp: In instantiation of ‘void {anonymous}::rotary_embedding_impl(const int64_t*, scalar_t*, scalar_t*, const scalar_t*, int, int, int, int, int, int, int) [with scalar_t = c10::BFloat16; int64_t = long int]’:
/home/deepanshu/ai-tooling/vllm/csrc/cpu/pos_encoding_impl.cpp:107:3: required from here
/home/deepanshu/ai-tooling/vllm/csrc/cpu/pos_encoding_impl.cpp:22:17: warning: unused variable ‘ELEM_SIZE’ [-Wunused-variable]
[4/6] c++ -MMD -MF /home/deepanshu/ai-tooling/vllm/build/temp.linux-x86_64-cpython-310/csrc/cpu/cache_impl.o.d -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -fPIC -I/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include -I/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -I/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include/TH -I/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include/THC -I/usr/include/python3.10 -c -c /home/deepanshu/ai-tooling/vllm/csrc/cpu/cache_impl.cpp -o /home/deepanshu/ai-tooling/vllm/build/temp.linux-x86_64-cpython-310/csrc/cpu/cache_impl.o -g -O2 -std=c++17 -D_GLIBCXX_USE_CXX11_ABI=0 -DVLLM_BUILD_CPU_ONLY -DVLLM_BUILD_CPU_OPS -fopenmp -mavx512f -mavx512bf16 -mavx512vl -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_C -D_GLIBCXX_USE_CXX11_ABI=0
FAILED: /home/deepanshu/ai-tooling/vllm/build/temp.linux-x86_64-cpython-310/csrc/cpu/cache_impl.o
c++ -MMD -MF /home/deepanshu/ai-tooling/vllm/build/temp.linux-x86_64-cpython-310/csrc/cpu/cache_impl.o.d -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -fPIC -I/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include -I/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -I/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include/TH -I/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include/THC -I/usr/include/python3.10 -c -c /home/deepanshu/ai-tooling/vllm/csrc/cpu/cache_impl.cpp -o /home/deepanshu/ai-tooling/vllm/build/temp.linux-x86_64-cpython-310/csrc/cpu/cache_impl.o -g -O2 -std=c++17 -D_GLIBCXX_USE_CXX11_ABI=0 -DVLLM_BUILD_CPU_ONLY -DVLLM_BUILD_CPU_OPS -fopenmp -mavx512f -mavx512bf16 -mavx512vl -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_C -D_GLIBCXX_USE_CXX11_ABI=0
In file included from /home/deepanshu/ai-tooling/vllm/csrc/cpu/cache_impl.cpp:4:
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp: In constructor ‘vec_op::FP32Vec8::FP32Vec8(__m128bh)’:
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp:151:39: error: ‘_mm256_cvtpbh_ps’ was not declared in this scope; did you mean ‘_mm256_cvtph_ps’?
151 | explicit FP32Vec8(__m128bh v) : reg(_mm256_cvtpbh_ps(v)) {}
| ^~~~~~~~~~~~~~~~
| _mm256_cvtph_ps
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp: In constructor ‘vec_op::FP32Vec16::FP32Vec16(__m256bh)’:
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp:207:40: error: ‘_mm512_cvtpbh_ps’ was not declared in this scope; did you mean ‘_mm512_cvtph_ps’?
207 | explicit FP32Vec16(__m256bh v) : reg(_mm512_cvtpbh_ps(v)) {}
| ^~~~~~~~~~~~~~~~
| _mm512_cvtph_ps
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp: In function ‘void vec_op::storeFP32ToT(float, T*) [with T = c10::BFloat16]’:
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp:268:21: error: ‘__bfloat16’ does not name a type; did you mean ‘__float80’?
268 | *reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v);
| ^~~~~~~~~~
| __float80
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp:268:32: error: expected ‘>’ before ‘*’ token
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp:268:32: error: expected ‘(’ before ‘*’ token
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp:268:33: error: expected primary-expression before ‘>’ token
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp:268:42: error: ‘_mm_cvtness_sbh’ was not declared in this scope; did you mean ‘_mm_cvtneps_pbh’?
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp:268:60: error: expected ‘)’ before ‘;’ token
In file included from /tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include/c10/util/Exception.h:4,
from /tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include/c10/core/Device.h:5,
from /tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include/ATen/core/TensorBody.h:11,
from /tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include/ATen/core/Tensor.h:3,
from /tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include/ATen/Tensor.h:3,
from /tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include/torch/csrc/autograd/function_hook.h:3,
from /tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include/torch/csrc/autograd/cpp_hook.h:2,
from /tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include/torch/csrc/autograd/variable.h:6,
from /tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include/torch/csrc/autograd/autograd.h:3,
from /tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include/torch/autograd.h:3,
from /tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include/torch/all.h:7,
from /tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include/torch/extension.h:5,
from /home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp:6,
from /home/deepanshu/ai-tooling/vllm/csrc/cpu/cache_impl.cpp:4:
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cache_impl.cpp: In function ‘void copy_blocks_cpu(std::vector<at::Tensor>&, std::vector<at::Tensor>&, const std::map<long int, std::vector<long int> >&)’:
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cache_impl.cpp:89:26: warning: comparison of integer expressions of different signedness: ‘int’ and ‘std::vector<at::Tensor>::size_type’ {aka ‘long unsigned int’} [-Wsign-compare]
89 | TORCH_CHECK(num_layers == value_caches.size());
| ~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~
/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include/c10/macros/Macros.h:207:64: note: in definition of macro ‘C10_UNLIKELY’
207 | #define C10_UNLIKELY(expr) (__builtin_expect(static_cast<bool>(expr), 0))
| ^~~~
/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include/c10/util/Exception.h:503:7: note: in expansion of macro ‘C10_UNLIKELY_OR_CONST’
503 | if (C10_UNLIKELY_OR_CONST(!(cond))) {
| ^~~~~~~~~~~~~~~~~~~~~
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cache_impl.cpp:89:3: note: in expansion of macro ‘TORCH_CHECK’
89 | TORCH_CHECK(num_layers == value_caches.size());
| ^~~~~~~~~~~
[5/6] c++ -MMD -MF /home/deepanshu/ai-tooling/vllm/build/temp.linux-x86_64-cpython-310/csrc/cpu/layernorm_impl.o.d -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -fPIC -I/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include -I/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -I/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include/TH -I/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include/THC -I/usr/include/python3.10 -c -c /home/deepanshu/ai-tooling/vllm/csrc/cpu/layernorm_impl.cpp -o /home/deepanshu/ai-tooling/vllm/build/temp.linux-x86_64-cpython-310/csrc/cpu/layernorm_impl.o -g -O2 -std=c++17 -D_GLIBCXX_USE_CXX11_ABI=0 -DVLLM_BUILD_CPU_ONLY -DVLLM_BUILD_CPU_OPS -fopenmp -mavx512f -mavx512bf16 -mavx512vl -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_C -D_GLIBCXX_USE_CXX11_ABI=0
FAILED: /home/deepanshu/ai-tooling/vllm/build/temp.linux-x86_64-cpython-310/csrc/cpu/layernorm_impl.o
c++ -MMD -MF /home/deepanshu/ai-tooling/vllm/build/temp.linux-x86_64-cpython-310/csrc/cpu/layernorm_impl.o.d -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -fPIC -I/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include -I/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -I/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include/TH -I/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include/THC -I/usr/include/python3.10 -c -c /home/deepanshu/ai-tooling/vllm/csrc/cpu/layernorm_impl.cpp -o /home/deepanshu/ai-tooling/vllm/build/temp.linux-x86_64-cpython-310/csrc/cpu/layernorm_impl.o -g -O2 -std=c++17 -D_GLIBCXX_USE_CXX11_ABI=0 -DVLLM_BUILD_CPU_ONLY -DVLLM_BUILD_CPU_OPS -fopenmp -mavx512f -mavx512bf16 -mavx512vl -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_C -D_GLIBCXX_USE_CXX11_ABI=0
In file included from /home/deepanshu/ai-tooling/vllm/csrc/cpu/layernorm_impl.cpp:1:
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp: In constructor ‘vec_op::FP32Vec8::FP32Vec8(__m128bh)’:
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp:151:39: error: ‘_mm256_cvtpbh_ps’ was not declared in this scope; did you mean ‘_mm256_cvtph_ps’?
151 | explicit FP32Vec8(__m128bh v) : reg(_mm256_cvtpbh_ps(v)) {}
| ^~~~~~~~~~~~~~~~
| _mm256_cvtph_ps
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp: In constructor ‘vec_op::FP32Vec16::FP32Vec16(__m256bh)’:
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp:207:40: error: ‘_mm512_cvtpbh_ps’ was not declared in this scope; did you mean ‘_mm512_cvtph_ps’?
207 | explicit FP32Vec16(__m256bh v) : reg(_mm512_cvtpbh_ps(v)) {}
| ^~~~~~~~~~~~~~~~
| _mm512_cvtph_ps
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp: In function ‘void vec_op::storeFP32ToT(float, T*) [with T = c10::BFloat16]’:
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp:268:21: error: ‘__bfloat16’ does not name a type; did you mean ‘__float80’?
268 | *reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v);
| ^~~~~~~~~~
| __float80
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp:268:32: error: expected ‘>’ before ‘*’ token
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp:268:32: error: expected ‘(’ before ‘*’ token
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp:268:33: error: expected primary-expression before ‘>’ token
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp:268:42: error: ‘_mm_cvtness_sbh’ was not declared in this scope; did you mean ‘_mm_cvtneps_pbh’?
/home/deepanshu/ai-tooling/vllm/csrc/cpu/cpu_types.hpp:268:60: error: expected ‘)’ before ‘;’ token
[6/6] c++ -MMD -MF /home/deepanshu/ai-tooling/vllm/build/temp.linux-x86_64-cpython-310/csrc/pybind.o.d -Wno-unused-result -Wsign-compare -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -g -fwrapv -O2 -fPIC -I/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include -I/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include/torch/csrc/api/include -I/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include/TH -I/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/include/THC -I/usr/include/python3.10 -c -c /home/deepanshu/ai-tooling/vllm/csrc/pybind.cpp -o /home/deepanshu/ai-tooling/vllm/build/temp.linux-x86_64-cpython-310/csrc/pybind.o -g -O2 -std=c++17 -D_GLIBCXX_USE_CXX11_ABI=0 -DVLLM_BUILD_CPU_ONLY -DVLLM_BUILD_CPU_OPS -fopenmp -mavx512f -mavx512bf16 -mavx512vl -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1011"' -DTORCH_EXTENSION_NAME=_C -D_GLIBCXX_USE_CXX11_ABI=0
ninja: build stopped: subcommand failed.
Traceback (most recent call last):
File "/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py", line 2100, in _run_ninja_build
subprocess.run(
File "/usr/lib/python3.10/subprocess.py", line 526, in run
raise CalledProcessError(retcode, process.args,
subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.

  The above exception was the direct cause of the following exception:
  
  Traceback (most recent call last):
    File "/home/deepanshu/.local/lib/python3.10/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 353, in <module>
      main()
    File "/home/deepanshu/.local/lib/python3.10/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 335, in main
      json_out['return_val'] = hook(**hook_input['kwargs'])
    File "/home/deepanshu/.local/lib/python3.10/site-packages/pip/_vendor/pyproject_hooks/_in_process/_in_process.py", line 251, in build_wheel
      return _build_backend().build_wheel(wheel_directory, config_settings,
    File "/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/setuptools/build_meta.py", line 404, in build_wheel
      return self._build_with_temp_dir(
    File "/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/setuptools/build_meta.py", line 389, in _build_with_temp_dir
      self.run_setup()
    File "/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/setuptools/build_meta.py", line 311, in run_setup
      exec(code, locals())
    File "<string>", line 348, in <module>
    File "/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/setuptools/__init__.py", line 103, in setup
      return distutils.core.setup(**attrs)
    File "/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/setuptools/_distutils/core.py", line 185, in setup
      return run_commands(dist)
    File "/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/setuptools/_distutils/core.py", line 201, in run_commands
      dist.run_commands()
    File "/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/setuptools/_distutils/dist.py", line 969, in run_commands
      self.run_command(cmd)
    File "/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/setuptools/dist.py", line 963, in run_command
      super().run_command(command)
    File "/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/setuptools/_distutils/dist.py", line 988, in run_command
      cmd_obj.run()
    File "/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/wheel/bdist_wheel.py", line 368, in run
      self.run_command("build")
    File "/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/setuptools/_distutils/cmd.py", line 318, in run_command
      self.distribution.run_command(command)
    File "/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/setuptools/dist.py", line 963, in run_command
      super().run_command(command)
    File "/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/setuptools/_distutils/dist.py", line 988, in run_command
      cmd_obj.run()
    File "/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/setuptools/_distutils/command/build.py", line 131, in run
      self.run_command(cmd_name)
    File "/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/setuptools/_distutils/cmd.py", line 318, in run_command
      self.distribution.run_command(command)
    File "/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/setuptools/dist.py", line 963, in run_command
      super().run_command(command)
    File "/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/setuptools/_distutils/dist.py", line 988, in run_command
      cmd_obj.run()
    File "/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/setuptools/command/build_ext.py", line 88, in run
      _build_ext.run(self)
    File "/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/setuptools/_distutils/command/build_ext.py", line 345, in run
      self.build_extensions()
    File "/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py", line 873, in build_extensions
      build_ext.build_extensions(self)
    File "/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/setuptools/_distutils/command/build_ext.py", line 467, in build_extensions
      self._build_extensions_serial()
    File "/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/setuptools/_distutils/command/build_ext.py", line 493, in _build_extensions_serial
      self.build_extension(ext)
    File "/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/setuptools/command/build_ext.py", line 249, in build_extension
      _build_ext.build_extension(self, ext)
    File "/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/setuptools/_distutils/command/build_ext.py", line 548, in build_extension
      objects = self.compiler.compile(
    File "/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py", line 686, in unix_wrap_ninja_compile
      _write_ninja_file_and_compile_objects(
    File "/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py", line 1774, in _write_ninja_file_and_compile_objects
      _run_ninja_build(
    File "/tmp/pip-build-env-7p977_uw/overlay/local/lib/python3.10/dist-packages/torch/utils/cpp_extension.py", line 2116, in _run_ninja_build
      raise RuntimeError(message) from e
  RuntimeError: Error compiling objects for extension
  [end of output]

note: This error originates from a subprocess, and is likely not a problem with pip.
ERROR: Failed building wheel for vllm
Failed to build vllm
ERROR: Could not build wheels for vllm, which is required to install pyproject.toml-based projects`

@bigPYJ1151 (Contributor Author)

@Deepansharora27 It seems to be due to the GCC version. This branch requires at least GCC 12.

@Deepansharora27

@bigPYJ1151 Okay, let me see.

@Deepansharora27

@Deepansharora27 Seems it is due to the GCC version. This branch requires at least GCC-12.

Seems like I already have gcc-12 @bigPYJ1151

[Screenshot attached]

@bigPYJ1151 (Contributor Author)

@Deepansharora27 It seems you also have g++-11. Please check the default g++ version with g++ -v.

@bigPYJ1151 bigPYJ1151 force-pushed the PR_Branch branch 2 times, most recently from 4eca588 to 8529fb7 Compare January 10, 2024 08:23
@sd3ntato

Hello, it'd be very nice to have a Docker image to run CPU vLLM for local development!

@bigPYJ1151 (Contributor Author)

Hello, it'd be very nice to have a Docker image to run CPU vLLM for local development!

@sd3ntato Thanks for your attention! Actually, the nvcr.io/nvidia/pytorch:23.10-py3 image can also be used to run and develop CPU vLLM, even without a GPU.

@kruna6111

I tried pip install torch==2.1.0+cpu --index-url https://download.pytorch.org/whl/cpu but am still getting this error.

[Screenshot attached]

@bigPYJ1151 (Contributor Author)

Hi @kruna6111, it seems the source code has not been compiled and the operator library has not been generated. Please try again with the following (a sanity check is shown after the list):

  • pip install torch==2.1.2+cpu --index-url https://download.pytorch.org/whl/cpu
  • VLLM_BUILD_CPU_ONLY=1 MAX_JOBS=8 pip install --no-build-isolation -v -e .
    Make sure your gcc/g++ version is 12 or newer and your CPU supports AVX512-BF16.
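
A quick sanity check that the native operator library was actually built (assuming the compiled extension is exposed as vllm._C, as suggested by the TORCH_EXTENSION_NAME=_C flag in the build logs above):

python -c "import vllm._C; print('CPU ops loaded')"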

@kruna6111 commented Feb 1, 2024

@bigPYJ1151 Thanks for the help. I implemented the steps you mentioned and that solved the error; however, a new error has popped up.

[Screenshot attached]

Just to confirm, currently:

  • I have gcc/g++ version >12
  • Executed pip install torch==2.1.2+cpu --index-url https://download.pytorch.org/whl/cpu
  • Set VLLM_BUILD_CPU_ONLY=1 and MAX_JOBS=8 before running my Python file
  • Executed pip install --no-build-isolation -v -e . in the vLLM project
  • My CPU supports avx512cd, avx512bw, avx512vl, avx512_vnni

@bigPYJ1151 (Contributor Author)

@kruna6111 It seems you didn't specify the device type with --device=cpu. Also, your CPU has no support for AVX512-BF16 🥲, so this branch will not work on your device.

@kruna6111 commented Feb 1, 2024

Oh, I see. @bigPYJ1151 Anyway, thanks for your help and support. Is there any specific reason why it needs AVX512-BF16 specifically?

@bigPYJ1151 (Contributor Author)

@kruna6111 You can use the float32 data type, but the performance is not comparable with native bfloat16. Torch CPU needs AVX512-BF16 to support bfloat16 natively.
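
To check whether a CPU exposes that instruction set, one can look for the avx512_bf16 flag reported by the kernel, for example:

grep -o avx512_bf16 /proc/cpuinfo | head -1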

@kruna6111 commented Feb 2, 2024

@bigPYJ1151 I am using the float32 data type as a parameter (dtype=torch.float32) when calling the LLM class of vLLM, and I am running the file as python3 example.py --device=cpu.

  • Getting this error:
    [Screenshot attached]

  • Do I need to pass the float32 data type like python3 example.py --device=cpu --dtype=float32? Will that solve the above error?

@bigPYJ1151 (Contributor Author)

@kruna6111 Yes, you should pass --dtype=float32. It seems you want to import vLLM from outside the vLLM directory, so you shouldn't build vLLM in editable mode. Please remove vLLM with pip and rebuild it with:
VLLM_BUILD_CPU_ONLY=1 MAX_JOBS=8 python setup.py install

@kruna6111 commented Feb 2, 2024

Hey @bigPYJ1151, I implemented the steps mentioned. The issue was that there is a folder named vllm, which makes the import statement prioritize that folder over the vllm library; I just needed to rename the folder. However, I am still not able to run vLLM inference on my CPU because of a CUDA error.

[Screenshot attached]

@bigPYJ1151 (Contributor Author)

@kruna6111 It seems your LLM is initialized with the cuda device. Please check the configuration of the LLM class and make sure it is initialized as:

    llm = LLM(
        model=model,
        tokenizer=tokenizer,
        dtype="float32",
        enforce_eager=True,
        device="cpu",
        swap_space=swap_space,
        ...
    )

@kruna6111 commented Feb 7, 2024

@bigPYJ1151 Thank you for your help, I am able to run inference on the CPU with vLLM. You are a genius.

Review comment on cpu.Dockerfile:

RUN --mount=type=cache,target=/root/.cache/pip \
    pip install -r requirements-cpu.txt

FROM vllm-base AS vllm

vllm-base is not available if building cpu.Dockerfile only.

@WoosukKwon (Collaborator) commented Mar 25, 2024

@bigPYJ1151 Apologies for the very (x100) significant delay. Could you please update this PR for review? vLLM has undergone numerous changes since your last update. Also, I'm curious about the current status: is there any progress on TP and FP16 support?

@zhouyuan (Contributor)

link to the new patch: #3634

@WoosukKwon (Collaborator)

Closing this PR as we merged #3634

@WoosukKwon closed this Apr 2, 2024