Skip to content

Commit 67308e4

Browse files
committed
use amdsmi for device name on rocm
Signed-off-by: Divakar Verma <divakar.verma@amd.com>
1 parent ce77eb9 commit 67308e4

File tree

1 file changed

+30
-6
lines changed

1 file changed

+30
-6
lines changed

vllm/platforms/rocm.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
# SPDX-License-Identifier: Apache-2.0
22

3-
from functools import lru_cache
3+
import os
4+
from functools import lru_cache, wraps
45
from typing import TYPE_CHECKING, Dict, List, Optional
56

67
import torch
8+
from amdsmi import (amdsmi_get_gpu_asic_info, amdsmi_get_processor_handles,
9+
amdsmi_init, amdsmi_shut_down)
710

811
import vllm.envs as envs
912
from vllm.logger import init_logger
@@ -54,6 +57,28 @@
5457
}
5558

5659

60+
def with_amdsmi_context(fn):
61+
62+
@wraps(fn)
63+
def wrapper(*args, **kwargs):
64+
amdsmi_init()
65+
try:
66+
return fn(*args, **kwargs)
67+
finally:
68+
amdsmi_shut_down()
69+
70+
return wrapper
71+
72+
73+
def device_id_to_physical_device_id(device_id: int) -> int:
74+
if "CUDA_VISIBLE_DEVICES" in os.environ:
75+
device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",")
76+
physical_device_id = device_ids[device_id]
77+
return int(physical_device_id)
78+
else:
79+
return device_id
80+
81+
5782
class RocmPlatform(Platform):
5883
_enum = PlatformEnum.ROCM
5984
device_name: str = "rocm"
@@ -96,13 +121,12 @@ def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
96121
return DeviceCapability(major=major, minor=minor)
97122

98123
@classmethod
124+
@with_amdsmi_context
99125
@lru_cache(maxsize=8)
100126
def get_device_name(cls, device_id: int = 0) -> str:
101-
# NOTE: When using V1 this function is called when overriding the
102-
# engine args. Calling torch.cuda.get_device_name(device_id) here
103-
# will result in the ROCm context being initialized before other
104-
# processes can be created.
105-
return "AMD"
127+
physical_device_id = device_id_to_physical_device_id(device_id)
128+
handle = amdsmi_get_processor_handles()[physical_device_id]
129+
return amdsmi_get_gpu_asic_info(handle)["market_name"]
106130

107131
@classmethod
108132
def get_device_total_memory(cls, device_id: int = 0) -> int:

0 commit comments

Comments
 (0)