
Conversation

@FlorianJoncour (Contributor) commented on Nov 28, 2023

I run several models with Ray Serve outside of vLLM.
vLLM asks Ray for all GPU resources, which makes it impossible to run other models alongside it.

So I use the gpu_memory_utilization parameter to limit the GPU resources requested by the worker, which then allows placement_group_bundles to be used in Ray Serve deployments.

The RayWorker class has also been renamed to RayWorkerVllm to avoid ambiguity with other Ray actors.

Edit: Clarification: vLLM doesn't use all the VRAM, but it requests all GPU resources from Ray.
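
For illustration, a minimal sketch of the kind of co-located deployment this enables. This is hedged: the model name, the 0.5 fractions, and the Serve placement-group options are my assumptions (placement_group_bundles / placement_group_strategy need a recent Ray Serve release), not part of this PR.

```python
from ray import serve
from vllm import LLM, SamplingParams

# Hedged sketch, not part of this PR: model, fractions, and Serve options
# are illustrative assumptions.
@serve.deployment(
    placement_group_bundles=[{"CPU": 1, "GPU": 0.5}],
    placement_group_strategy="STRICT_PACK",
)
class VllmDeployment:
    def __init__(self) -> None:
        # gpu_memory_utilization bounds the KV-cache profiling budget and,
        # with this PR, the GPU share that the vLLM worker requests from Ray,
        # leaving the rest of the device for other deployments.
        self.llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.5)

    async def __call__(self, request) -> str:
        prompt = (await request.json())["prompt"]
        outputs = self.llm.generate([prompt], SamplingParams(max_tokens=64))
        return outputs[0].outputs[0].text

app = VllmDeployment.bind()
# serve.run(app)  # other Ray Serve models can now share the same GPU
```

The point is that the vLLM worker no longer claims the whole GPU from Ray's scheduler, so another deployment's bundle can take the remaining share of the same device.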

@matt-psaltis commented

I think this feature fixes the underlying cause for the problem here: ray-project/ray-llm#94

Thanks @FlorianJoncour!

@FlorianJoncour (Contributor, Author) commented

Yes, it seems to be the same issue.

I almost went crazy trying to get this to work before diving into the vLLM code.

@Yard1 (Collaborator) left a comment

This looks good to me, though I would say that the rename is unnecessary :)

One important thing to note: this doesn't provide a hard boundary on process memory usage, so fractional GPUs don't come with a 100% guarantee that OOMs will be avoided. That said, OOMs should be quite unlikely thanks to vLLM's memory profiling.
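
To make the caveat concrete, here is a small standalone sketch (illustrative, not vLLM code; assumes a Ray cluster with at least one GPU) showing that Ray's fractional GPU accounting is purely logical, so nothing below the scheduler enforces the split:

```python
import ray

# Two actors each request half a GPU; Ray co-schedules them on the same
# device, but neither is prevented from allocating more VRAM than its
# "share" -- hence the reliance on vLLM's memory profiling.
@ray.remote(num_gpus=0.5)
class HalfGpuActor:
    def visible_devices(self) -> str:
        import os
        # Ray sets CUDA_VISIBLE_DEVICES per actor based on its GPU assignment.
        return os.environ.get("CUDA_VISIBLE_DEVICES", "")

if __name__ == "__main__":
    ray.init()
    a, b = HalfGpuActor.remote(), HalfGpuActor.remote()
    # Both actors typically report the same device ordinal.
    print(ray.get([a.visible_devices.remote(), b.visible_devices.remote()]))
```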

@zhuohan123 merged commit 0229c38 into vllm-project:main on Nov 29, 2023
@WoosukKwon (Collaborator) commented

@FlorianJoncour @Yard1 It seems this change causes a bug when gpu_memory_utilization < 0.5 and tensor_parallel_size > 1:

RuntimeError: CUDA error: invalid device ordinal

I guess this is because num_gpus is set to gpu_memory_utilization? Do you have any idea how to fix this?

@Yard1 (Collaborator) commented on Nov 30, 2023

@WoosukKwon I see. For now we should do:

num_gpus=self.cache_config.gpu_memory_utilization if self.parallel_config.tensor_parallel_size < 2 else 1
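
Presumably the failure comes from Ray packing more than one tensor-parallel worker onto the same physical GPU once num_gpus drops below 0.5, so a worker ends up addressing a device ordinal that doesn't exist in its view. A standalone sketch of the workaround (illustrative helper, not the actual vLLM code):

```python
# Illustrative helper, not the actual vLLM code: the num_gpus value each
# Ray worker would request under the workaround above.
def worker_num_gpus(gpu_memory_utilization: float, tensor_parallel_size: int) -> float:
    # With tensor parallelism every worker must own a whole GPU so that Ray
    # places each rank on a distinct device; a fractional request could
    # co-locate two ranks and break the device-ordinal mapping.
    return gpu_memory_utilization if tensor_parallel_size < 2 else 1.0
```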

xjpang pushed a commit to xjpang/vllm that referenced this pull request Dec 4, 2023
Co-authored-by: FlorianJoncour <florian@zetta-sys.com>
hongxiayang pushed a commit to hongxiayang/vllm that referenced this pull request Feb 13, 2024
Co-authored-by: FlorianJoncour <florian@zetta-sys.com>
jinyouzhi pushed a commit to jinyouzhi/vllm that referenced this pull request Oct 24, 2025
…oject#2018)

Without this fix, the code fails with the RuntimeError below when the list contains multi-element Tensors that are evaluated in a boolean context.

ERROR 08-26 01:36:36 [engine.py:160]   File "/workspace/aicse.vllm-habana.demo/vllm/worker/hpu_model_runner.py", line 3006, in prepare_model_input
ERROR 08-26 01:36:36 [engine.py:160]     model_input, sampling_metadata = self.prepare_input_tensors(
ERROR 08-26 01:36:36 [engine.py:160]   File "/workspace/aicse.vllm-habana.demo/vllm/worker/hpu_model_runner.py", line 1925, in prepare_input_tensors
ERROR 08-26 01:36:36 [engine.py:160]     ) = self._prepare_prompt(prefill_reqs)
ERROR 08-26 01:36:36 [engine.py:160]   File "/workspace/aicse.vllm-habana.demo/vllm/worker/hpu_model_runner.py", line 1389, in _prepare_prompt
ERROR 08-26 01:36:36 [engine.py:160]     self._get_mrope_positions_and_delta(
ERROR 08-26 01:36:36 [engine.py:160]   File "/workspace/aicse.vllm-habana.demo/vllm/worker/hpu_model_runner.py", line 1213, in _get_mrope_positions_and_delta
ERROR 08-26 01:36:36 [engine.py:160]     MRotaryEmbedding.get_input_positions(
ERROR 08-26 01:36:36 [engine.py:160]   File "/workspace/aicse.vllm-habana.demo/vllm/model_executor/layers/rotary_embedding.py", line 1174, in get_input_positions
ERROR 08-26 01:36:36 [engine.py:160]     cls.get_input_positions_tensor(
ERROR 08-26 01:36:36 [engine.py:160]   File "/workspace/aicse.vllm-habana.demo/vllm/model_executor/layers/rotary_embedding.py", line 1216, in get_input_positions_tensor
ERROR 08-26 01:36:36 [engine.py:160]     return cls._vl_get_input_positions_tensor(
ERROR 08-26 01:36:36 [engine.py:160]   File "/workspace/aicse.vllm-habana.demo/vllm/model_executor/layers/rotary_embedding.py", line 1284, in _vl_get_input_positions_tensor
ERROR 08-26 01:36:36 [engine.py:160]     if second_per_grid_ts:
ERROR 08-26 01:36:36 [engine.py:160] RuntimeError: Boolean value of Tensor with more than one value is ambiguous

Signed-off-by: Haihao Xiang <haihao.xiang@intel.com>
Co-authored-by: Xiang, Haihao <haihao.xiang@intel.com>
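
For reference, a minimal standalone sketch of the class of fix this commit describes (the variable name is taken from the traceback; the actual patch may differ): avoid evaluating a possibly multi-element Tensor in a boolean context.

```python
import torch

# Sketch only: reproduces and avoids the ambiguity from the traceback above.
second_per_grid_ts = torch.tensor([1.0, 2.0])  # multi-element tensor

# `if second_per_grid_ts:` raises:
#   RuntimeError: Boolean value of Tensor with more than one value is ambiguous

# Safe pattern: test for None / emptiness explicitly instead of truthiness.
if second_per_grid_ts is not None and second_per_grid_ts.numel() > 0:
    print("second_per_grid_ts:", second_per_grid_ts)
```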