[Serve] On KubeRay, vLLM 0.7.2 reports "No CUDA GPUs are available" while vLLM 0.6.6.post1 works fine when deploying a RayService #51154

@pteric

What happened + What you expected to happen

Description

When deploying the Qwen2.5-0.5B model through KubeRay with vLLM 0.7.2, the replica fails with "RuntimeError: No CUDA GPUs are available". The same deployment works fine with vLLM 0.6.6.post1 under identical environment conditions.
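
To see what a given Ray worker is actually allowed to use, a minimal probe like the following sketch helps: Ray populates CUDA_VISIBLE_DEVICES from each task's or actor's num_gpus request, so a worker that requests no GPU sees no CUDA devices even on a GPU node.

# Minimal probe (sketch): report what a Ray task sees depending on its
# num_gpus request. Ray derives CUDA_VISIBLE_DEVICES from the request,
# so a CPU-only task gets an empty device list.
import os
import ray

@ray.remote
def probe():
    import torch
    return {
        "CUDA_VISIBLE_DEVICES": os.environ.get("CUDA_VISIBLE_DEVICES"),
        "cuda_available": torch.cuda.is_available(),
    }

ray.init()
print(ray.get(probe.options(num_gpus=0).remote()))  # expect cuda_available=False
print(ray.get(probe.options(num_gpus=1).remote()))  # expect cuda_available=True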

Environment Information

  • Container Image: rayproject/ray:2.43.0-py39-cu124
  • vLLM:
    • Failed version: 0.7.2
    • Working version: 0.6.6.post1
  • Model: Qwen2.5-0.5B

Steps to Reproduce

Deploy a RayService with KubeRay using the image rayproject/ray:2.43.0-py39-cu124. The RayService manifest is:

apiVersion: ray.io/v1
kind: RayService
metadata:
  name: qwen2005-0005b-vllm07
spec:
  serveConfigV2: |
    applications:
    - name: llm
      route_prefix: /
      import_path: latest-serve:model
      deployments:
      - name: VLLMDeployment
        num_replicas: 1
        ray_actor_options:
          num_cpus: 4
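          # Note: the replica actor requests CPUs only; the GPU is expected to
          # reach the vLLM workers through the placement group bundles created
          # in latest-serve.py (see build_app below).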
      runtime_env:
        working_dir: "https://xxx/vllm_script.zip"
        pip: 
          - "vllm==0.7.2"
        env_vars:
          MODEL_ID: "Qwen/Qwen2.5-0.5B"
          TENSOR_PARALLELISM: "1"
          PIPELINE_PARALLELISM: "1"
  rayClusterConfig:
    headGroupSpec:
      rayStartParams:
        dashboard-host: '0.0.0.0'
      template:
        spec:
          containers:
          - name: ray-head
            image: rayproject/ray:2.43.0-py39-cu124
            imagePullPolicy: IfNotPresent
            resources:
              limits:
                cpu: "8"
                memory: "16Gi"
              requests:
                cpu: "2"
                memory: "4Gi"
            ports:
            - containerPort: 6379
              name: gcs-server
            - containerPort: 8265
              name: dashboard
            - containerPort: 10001
              name: client
            - containerPort: 8000
              name: serve
            env:
            - name: HUGGING_FACE_HUB_TOKEN
              valueFrom:
                secretKeyRef:
                  name: hf-secret
                  key: hf_api_token
    workerGroupSpecs:
    - replicas: 1
      minReplicas: 1
      maxReplicas: 2
      groupName: gpu-group
      rayStartParams: {}
      template:
        spec:
          containers:
          - name: llm
            image: rayproject/ray:2.43.0-py39-cu124
            imagePullPolicy: IfNotPresent
            env:
            - name: HUGGING_FACE_HUB_TOKEN
              valueFrom:
                secretKeyRef:
                  name: hf-secret
                  key: hf_api_token
            resources:
              limits:
                cpu: "8"
                memory: "16Gi"
                nvidia.com/gpu: "1"
              requests:
                cpu: "4"
                memory: "8Gi"
                nvidia.com/gpu: "1"
          tolerations:
            - key: "nvidia.com/gpu"
              operator: "Exists"
              effect: "NoSchedule"

The latest-serve.py packaged in https://xxx/vllm_script.zip is taken from: https://github.com/ray-project/ray/blob/master/doc/source/serve/doc_code/vllm_openai_example.py
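
The pip entry in runtime_env makes Ray build a per-application virtualenv (the runtime_resources/pip/... paths in the traceback below come from it). A quick way to confirm that this virtualenv resolves vLLM and CUDA as expected, independent of Serve, is a sketch like:

# Sketch: exercise the same runtime_env virtualenv outside of Serve.
import ray

ray.init(runtime_env={"pip": ["vllm==0.7.2"]})

@ray.remote(num_gpus=1)
def check():
    import torch
    import vllm
    return vllm.__version__, torch.cuda.is_available()

print(ray.get(check.remote()))  # expect ('0.7.2', True) on the GPU worker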

The exception traceback:

ray::ServeReplica:llm:VLLMDeployment.initialize_and_get_metadata() (pid=1886, ip=10.58.29.125, actor_id=c3a99f2865a8a727c40545aa01000000, repr=<ray.serve._private.replica.ServeReplica:llm:VLLMDeployment object at 0x7f9966d21550>)
  File "/home/ray/anaconda3/lib/python3.9/concurrent/futures/_base.py", line 446, in result
    return self.__get_result()
  File "/home/ray/anaconda3/lib/python3.9/concurrent/futures/_base.py", line 391, in __get_result
    raise self._exception
  File "/home/ray/anaconda3/lib/python3.9/site-packages/ray/serve/_private/replica.py", line 965, in initialize_and_get_metadata
    await self._replica_impl.initialize(deployment_config)
  File "/home/ray/anaconda3/lib/python3.9/site-packages/ray/serve/_private/replica.py", line 694, in initialize
    raise RuntimeError(traceback.format_exc()) from None
RuntimeError: Traceback (most recent call last):
  File "/home/ray/anaconda3/lib/python3.9/site-packages/ray/serve/_private/replica.py", line 671, in initialize
    self._user_callable_asgi_app = await asyncio.wrap_future(
  File "/home/ray/anaconda3/lib/python3.9/site-packages/ray/serve/_private/replica.py", line 1363, in initialize_callable
    await self._call_func_or_gen(
  File "/home/ray/anaconda3/lib/python3.9/site-packages/ray/serve/_private/replica.py", line 1324, in _call_func_or_gen
    result = callable(*args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.9/site-packages/ray/serve/api.py", line 221, in __init__
    cls.__init__(self, *args, **kwargs)
  File "/tmp/ray/session_2025-03-06_22-57-24_631998_1/runtime_resources/working_dir_files/https_aistudio-ant-mpc_oss-cn-zhangjiakou_aliyuncs_com_pengtuo_kuberay_vllm_script/latest-serve.py", line 57, in __init__
    self.engine = AsyncLLMEngine.from_engine_args(engine_args)
  File "/tmp/ray/session_2025-03-06_22-57-24_631998_1/runtime_resources/pip/a425849cda8f3a2d8bc88454de4cdc8455c376c1/virtualenv/lib/python3.9/site-packages/vllm/engine/async_llm_engine.py", line 644, in from_engine_args
    engine = cls(
  File "/tmp/ray/session_2025-03-06_22-57-24_631998_1/runtime_resources/pip/a425849cda8f3a2d8bc88454de4cdc8455c376c1/virtualenv/lib/python3.9/site-packages/vllm/engine/async_llm_engine.py", line 594, in __init__
    self.engine = self._engine_class(*args, **kwargs)
  File "/tmp/ray/session_2025-03-06_22-57-24_631998_1/runtime_resources/pip/a425849cda8f3a2d8bc88454de4cdc8455c376c1/virtualenv/lib/python3.9/site-packages/vllm/engine/async_llm_engine.py", line 267, in __init__
    super().__init__(*args, **kwargs)
  File "/tmp/ray/session_2025-03-06_22-57-24_631998_1/runtime_resources/pip/a425849cda8f3a2d8bc88454de4cdc8455c376c1/virtualenv/lib/python3.9/site-packages/vllm/engine/llm_engine.py", line 273, in __init__
    self.model_executor = executor_class(vllm_config=vllm_config, )
  File "/tmp/ray/session_2025-03-06_22-57-24_631998_1/runtime_resources/pip/a425849cda8f3a2d8bc88454de4cdc8455c376c1/virtualenv/lib/python3.9/site-packages/vllm/executor/executor_base.py", line 51, in __init__
    self._init_executor()
  File "/tmp/ray/session_2025-03-06_22-57-24_631998_1/runtime_resources/pip/a425849cda8f3a2d8bc88454de4cdc8455c376c1/virtualenv/lib/python3.9/site-packages/vllm/executor/uniproc_executor.py", line 41, in _init_executor
    self.collective_rpc("init_device")
  File "/tmp/ray/session_2025-03-06_22-57-24_631998_1/runtime_resources/pip/a425849cda8f3a2d8bc88454de4cdc8455c376c1/virtualenv/lib/python3.9/site-packages/vllm/executor/uniproc_executor.py", line 51, in collective_rpc
    answer = run_method(self.driver_worker, method, args, kwargs)
  File "/tmp/ray/session_2025-03-06_22-57-24_631998_1/runtime_resources/pip/a425849cda8f3a2d8bc88454de4cdc8455c376c1/virtualenv/lib/python3.9/site-packages/vllm/utils.py", line 2220, in run_method
    return func(*args, **kwargs)
  File "/tmp/ray/session_2025-03-06_22-57-24_631998_1/runtime_resources/pip/a425849cda8f3a2d8bc88454de4cdc8455c376c1/virtualenv/lib/python3.9/site-packages/vllm/worker/worker.py", line 155, in init_device
    torch.cuda.set_device(self.device)
  File "/tmp/ray/session_2025-03-06_22-57-24_631998_1/runtime_resources/pip/a425849cda8f3a2d8bc88454de4cdc8455c376c1/virtualenv/lib/python3.9/site-packages/torch/cuda/__init__.py", line 478, in set_device
    torch._C._cuda_setDevice(device)
  File "/tmp/ray/session_2025-03-06_22-57-24_631998_1/runtime_resources/pip/a425849cda8f3a2d8bc88454de4cdc8455c376c1/virtualenv/lib/python3.9/site-packages/torch/cuda/__init__.py", line 319, in _lazy_init
    torch._C._cuda_init()
RuntimeError: No CUDA GPUs are available
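
Note that the traceback goes through vllm/executor/uniproc_executor.py: vLLM 0.7.2 picked the single-process executor and tried to initialize CUDA inside the Serve replica process itself, which requests CPUs only (see ray_actor_options above and the CPU-only first placement group bundle in build_app below). Under 0.6.6.post1, worker_use_ray pushed the workers into separate Ray actors holding the GPU bundle. If 0.7.x no longer honors worker_use_ray, one possible workaround (an untested sketch, assuming distributed_executor_backend is the replacement knob) would be to select the Ray backend explicitly in build_app:

# Untested sketch: request Ray-based workers explicitly so that CUDA is
# initialized in the GPU-holding actors, not in the CPU-only replica.
engine_args = AsyncEngineArgs.from_cli_args(parsed_args)
engine_args.distributed_executor_backend = "ray"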

Related issue

I searched for solutions and found two issues that match my symptoms, but the fixes suggested there don't work in my case:
vllm-project/vllm#6896
#50275

Versions / Dependencies

Ray image: rayproject/ray:2.43.0-py39-cu124
PyTorch version: 2.5.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.9.21 | packaged by conda-forge | (main, Dec  5 2024, 13:51:40)  [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-5.10.134-18.al8.x86_64-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A10
Nvidia driver version: 550.144.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.1.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.1.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-ml-py==12.570.86
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] pyzmq==26.2.1
[pip3] torch==2.5.1
[pip3] torchaudio==2.5.1
[pip3] torchvision==0.20.1
[pip3] transformers==4.49.0
[pip3] triton==3.1.0

NVIDIA_VISIBLE_DEVICES=0
NVIDIA_REQUIRE_CUDA=cuda>=12.4 brand=tesla,driver>=470,driver<471 brand=unknown,driver>=470,driver<471 brand=nvidia,driver>=470,driver<471 brand=nvidiartx,driver>=470,driver<471 brand=geforce,driver>=470,driver<471 brand=geforcertx,driver>=470,driver<471 brand=quadro,driver>=470,driver<471 brand=quadrortx,driver>=470,driver<471 brand=titan,driver>=470,driver<471 brand=titanrtx,driver>=470,driver<471 brand=tesla,driver>=525,driver<526 brand=unknown,driver>=525,driver<526 brand=nvidia,driver>=525,driver<526 brand=nvidiartx,driver>=525,driver<526 brand=geforce,driver>=525,driver<526 brand=geforcertx,driver>=525,driver<526 brand=quadro,driver>=525,driver<526 brand=quadrortx,driver>=525,driver<526 brand=titan,driver>=525,driver<526 brand=titanrtx,driver>=525,driver<526 brand=tesla,driver>=535,driver<536 brand=unknown,driver>=535,driver<536 brand=nvidia,driver>=535,driver<536 brand=nvidiartx,driver>=535,driver<536 brand=geforce,driver>=535,driver<536 brand=geforcertx,driver>=535,driver<536 brand=quadro,driver>=535,driver<536 brand=quadrortx,driver>=535,driver<536 brand=titan,driver>=535,driver<536 brand=titanrtx,driver>=535,driver<536
NCCL_VERSION=2.21.5-1
NVIDIA_DRIVER_CAPABILITIES=compute,utility
NVIDIA_PRODUCT_NAME=CUDA
CUDA_VERSION=12.4.1
LD_LIBRARY_PATH=/tmp/ray/session_2025-03-06_04-45-27_752822_1/runtime_resources/pip/a425849cda8f3a2d8bc88454de4cdc8455c376c1/virtualenv/lib/python3.9/site-packages/cv2/../../lib64:/usr/local/nvidia/lib:/usr/local/nvidia/lib64
NCCL_CUMEM_ENABLE=0
TORCHINDUCTOR_COMPILE_THREADS=1
CUDA_MODULE_LOADING=LAZY

Reproduction script

The vLLM deployment code:

import os

from typing import Dict, Optional, List
import logging

from fastapi import FastAPI
from starlette.requests import Request
from starlette.responses import StreamingResponse, JSONResponse

from ray import serve

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ErrorResponse,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import (
    BaseModelPath,
    LoRAModulePath,
    PromptAdapterPath,
    OpenAIServingModels,
)

from vllm.utils import FlexibleArgumentParser
from vllm.entrypoints.logger import RequestLogger

logger = logging.getLogger("ray.serve")

app = FastAPI()


@serve.deployment(name="VLLMDeployment")
@serve.ingress(app)
class VLLMDeployment:
    def __init__(
        self,
        engine_args: AsyncEngineArgs,
        response_role: str,
        lora_modules: Optional[List[LoRAModulePath]] = None,
        prompt_adapters: Optional[List[PromptAdapterPath]] = None,
        request_logger: Optional[RequestLogger] = None,
        chat_template: Optional[str] = None,
    ):
        logger.info(f"Starting with engine args: {engine_args}")
        self.openai_serving_chat = None
        self.engine_args = engine_args
        self.response_role = response_role
        self.lora_modules = lora_modules
        self.prompt_adapters = prompt_adapters
        self.request_logger = request_logger
        self.chat_template = chat_template
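        # NOTE: the engine is constructed here, inside the Serve replica
        # process; whether CUDA is initialized in this process or in separate
        # Ray workers depends on which executor vLLM selects at runtime.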
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)

    @app.post("/v1/chat/completions")
    async def create_chat_completion(self, request: ChatCompletionRequest, raw_request: Request):
        if not self.openai_serving_chat:
            model_config = await self.engine.get_model_config()
            models = OpenAIServingModels(
                self.engine,
                model_config,
                [
                    BaseModelPath(
                        name=self.engine_args.model, model_path=self.engine_args.model
                    )
                ],
                lora_modules=self.lora_modules,
                prompt_adapters=self.prompt_adapters,
            )
            self.openai_serving_chat = OpenAIServingChat(
                self.engine,
                model_config,
                models,
                self.response_role,
                request_logger=self.request_logger,
                chat_template=self.chat_template,
                chat_template_content_format="auto",
            )
        logger.info(f"Request: {request}")
        generator = await self.openai_serving_chat.create_chat_completion(
            request, raw_request
        )
        if isinstance(generator, ErrorResponse):
            return JSONResponse(
                content=generator.model_dump(), status_code=generator.code
            )
        if request.stream:
            return StreamingResponse(content=generator, media_type="text/event-stream")
        else:
            assert isinstance(generator, ChatCompletionResponse)
            return JSONResponse(content=generator.model_dump())


def parse_vllm_args(cli_args: Dict[str, str]):
    arg_parser = FlexibleArgumentParser(
        description="vLLM OpenAI-Compatible RESTful API server."
    )
    parser = make_arg_parser(arg_parser)
    arg_strings = []
    for key, value in cli_args.items():
        arg_strings.extend([f"--{key}", str(value)])
    logger.info(arg_strings)
    parsed_args = parser.parse_args(args=arg_strings)
    return parsed_args


# serve run latest-serve:build_app model="Qwen/Qwen2.5-0.5B" tensor-parallel-size=1 accelerator="GPU"
def build_app(cli_args: Dict[str, str]) -> serve.Application:
    logger.info("*" * 100)
    if "accelerator" in cli_args.keys():
        accelerator = cli_args.pop("accelerator")
    else:
        accelerator = "GPU"
    parsed_args = parse_vllm_args(cli_args)

    engine_args = AsyncEngineArgs.from_cli_args(parsed_args)
    engine_args.worker_use_ray = True
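    # NOTE: worker_use_ray selects Ray-based workers on vLLM 0.6.x; it appears
    # to be ignored by 0.7.x (the traceback shows the uniproc executor), where
    # distributed_executor_backend="ray" looks like the equivalent setting.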

    tp = engine_args.tensor_parallel_size
    logger.info(f"Tensor parallelism = {tp}")
    pg_resources = []
    pg_resources.append({"CPU": 4})  # for the deployment replica
    for i in range(tp):
        pg_resources.append({"CPU": 2, accelerator: 1})  # for the vLLM actors

    return VLLMDeployment.options(
        placement_group_bundles=pg_resources, placement_group_strategy="SPREAD"
    ).bind(
        engine_args,
        parsed_args.response_role,
        parsed_args.lora_modules,
        parsed_args.prompt_adapters,
        cli_args.get("request_logger"),
        parsed_args.chat_template,
    )

# The RayService above only sets MODEL_ID, TENSOR_PARALLELISM and
# PIPELINE_PARALLELISM, so forward the remaining settings only when they are
# present in the environment instead of raising KeyError on startup.
cli_args = {
    "model": os.environ["MODEL_ID"],
    "port": "8080",
    "tensor-parallel-size": os.environ["TENSOR_PARALLELISM"],
    "pipeline-parallel-size": os.environ["PIPELINE_PARALLELISM"],
}
for env_var, flag in [
    ("MODEL_LEN", "max-model-len"),
    ("GPU_MEMORY_UTILIZATION", "gpu-memory-utilization"),
    ("DTYPE", "dtype"),
    ("KV_CACHE_DTYPE", "kv-cache-dtype"),
]:
    if env_var in os.environ:
        cli_args[flag] = os.environ[env_var]

model = build_app(cli_args)
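
To isolate Serve scheduling from vLLM itself, a bare Ray actor that explicitly requests a GPU can try to build the same engine (a sketch, assuming the model fits on the single A10):

# Isolation sketch: construct the engine in an actor that owns the GPU,
# bypassing Serve replica scheduling entirely.
import ray

@ray.remote(num_gpus=1)
class EngineProbe:
    def start(self) -> str:
        from vllm.engine.arg_utils import AsyncEngineArgs
        from vllm.engine.async_llm_engine import AsyncLLMEngine
        AsyncLLMEngine.from_engine_args(AsyncEngineArgs(model="Qwen/Qwen2.5-0.5B"))
        return "engine initialized"

ray.init()
probe = EngineProbe.remote()
print(ray.get(probe.start.remote()))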

Issue Severity

High: It blocks me from completing my task.

Metadata

Labels

P1 (issue that should be fixed within a few weeks), bug (something that is supposed to be working, but isn't), community-backlog, llm, serve (Ray Serve related issue)
