Skip to content

Commit

Permalink
[hardware][cuda] use device id under CUDA_VISIBLE_DEVICES for get_dev…
Browse files Browse the repository at this point in the history
…ice_capability (vllm-project#6216)
  • Loading branch information
youkaichao authored Jul 9, 2024
1 parent 4f0e0ea commit a3c9435
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions vllm/platforms/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
pynvml. However, it should not initialize cuda context.
"""

import os
from functools import lru_cache, wraps
from typing import Tuple

Expand All @@ -23,12 +24,27 @@ def wrapper(*args, **kwargs):
return wrapper


@lru_cache(maxsize=8)
@with_nvml_context
def get_physical_device_capability(device_id: int = 0) -> Tuple[int, int]:
handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
return pynvml.nvmlDeviceGetCudaComputeCapability(handle)


def device_id_to_physical_device_id(device_id: int) -> int:
if "CUDA_VISIBLE_DEVICES" in os.environ:
device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
device_ids = [int(device_id) for device_id in device_ids]
physical_device_id = device_ids[device_id]
else:
physical_device_id = device_id
return physical_device_id


class CudaPlatform(Platform):
_enum = PlatformEnum.CUDA

@staticmethod
@lru_cache(maxsize=8)
@with_nvml_context
def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
handle = pynvml.nvmlDeviceGetHandleByIndex(device_id)
return pynvml.nvmlDeviceGetCudaComputeCapability(handle)
physical_device_id = device_id_to_physical_device_id(device_id)
return get_physical_device_capability(physical_device_id)

0 comments on commit a3c9435

Please sign in to comment.