Skip to content

Commit ae1b51a

Browse files
committed
fix ray dep
1 parent 12a859f commit ae1b51a

File tree

5 files changed

+39
-6
lines changed

5 files changed

+39
-6
lines changed

csrc/cpu/torch_bindings.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
#include <torch/library.h>
66

7-
void init_cpu_threads_env(const std::string& cpu_ids);
7+
std::string init_cpu_threads_env(const std::string& cpu_ids);
88

99
void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
1010
const torch::Tensor& b, const torch::Tensor& a_scales,
@@ -138,7 +138,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
138138

139139
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
140140
// CPU utils
141-
utils.def("init_cpu_threads_env(str cpu_ids) -> ()", &init_cpu_threads_env);
141+
utils.def("init_cpu_threads_env(str cpu_ids) -> str", &init_cpu_threads_env);
142142
}
143143

144144
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)

csrc/cpu/utils.cpp

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
#include "cpu_types.hpp"
77

8-
void init_cpu_threads_env(const std::string& cpu_ids) {
8+
std::string init_cpu_threads_env(const std::string& cpu_ids) {
99
bitmask* omp_cpu_mask = numa_parse_cpustring(cpu_ids.c_str());
1010
TORCH_CHECK(omp_cpu_mask->size > 0);
1111
std::vector<int> omp_cpu_ids;
@@ -51,6 +51,12 @@ void init_cpu_threads_env(const std::string& cpu_ids) {
5151
torch::set_num_threads((int)omp_cpu_ids.size());
5252
TORCH_CHECK_EQ(omp_cpu_ids.size(), torch::get_num_threads());
5353
TORCH_CHECK_EQ(omp_cpu_ids.size(), omp_get_max_threads());
54+
55+
std::vector<std::pair<int, int>> thread_core_mapping;
56+
thread_core_mapping.reserve(omp_cpu_ids.size());
57+
omp_lock_t writelock;
58+
omp_init_lock(&writelock);
59+
5460
#pragma omp parallel for schedule(static, 1)
5561
for (size_t i = 0; i < omp_cpu_ids.size(); ++i) {
5662
cpu_set_t mask;
@@ -61,7 +67,24 @@ void init_cpu_threads_env(const std::string& cpu_ids) {
6167
TORCH_CHECK(false,
6268
"sched_setaffinity failed. errno: " + std::to_string(errno));
6369
}
70+
71+
omp_set_lock(&writelock);
72+
thread_core_mapping.emplace_back(gettid(), omp_cpu_ids[i]);
73+
omp_unset_lock(&writelock);
6474
}
6575

76+
omp_destroy_lock(&writelock);
77+
6678
numa_free_nodemask(omp_cpu_mask);
79+
80+
std::stringstream ss;
81+
ss << "OMP threads binding of Process " << getpid() << ":\n";
82+
std::sort(thread_core_mapping.begin(), thread_core_mapping.end(),
83+
[](auto&& a, auto&& b) { return a.second < b.second; });
84+
for (auto&& item : thread_core_mapping) {
85+
ss << "\t"
86+
<< "OMP tid: " << item.first << ", core " << item.second << "\n";
87+
}
88+
89+
return ss.str();
6790
}

vllm/config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -872,7 +872,8 @@ def __init__(
872872
from vllm.executor import ray_utils
873873
backend = "mp"
874874
ray_found = ray_utils.ray_is_available()
875-
if cuda_device_count_stateless() < self.world_size:
875+
if (torch.cuda.is_available()
876+
and cuda_device_count_stateless() < self.world_size):
876877
if not ray_found:
877878
raise ValueError("Unable to load Ray which is "
878879
"required for multi-node inference, "

vllm/executor/cpu_executor.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import torch
66

77
import vllm.envs as envs
8-
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
8+
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
9+
SchedulerConfig)
910
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
1011
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
1112
ResultHandler, WorkerMonitor)
@@ -60,6 +61,8 @@ def _init_executor(self) -> None:
6061
self.cache_config = _verify_and_get_cache_config(self.cache_config)
6162
self.scheduler_config = _verify_and_get_scheduler_config(
6263
self.scheduler_config)
64+
self.parallel_config = _verify_and_get_parallel_config(
65+
self.parallel_config)
6366

6467
# Multiprocessing-based executor does not support multi-node setting.
6568
# Since it only works for single node, we can use the loopback address
@@ -353,6 +356,11 @@ def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
353356
return config
354357

355358

359+
def _verify_and_get_parallel_config(config: ParallelConfig) -> ParallelConfig:
360+
config.distributed_executor_backend = "mp"
361+
return config
362+
363+
356364
def _driver_method_invoker(driver, method: str, *args, **kwargs):
357365
return getattr(driver, method)(*args, **kwargs)
358366

vllm/worker/cpu_worker.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,8 @@ def stop_profile(self):
207207

208208
def init_device(self) -> None:
209209
if self.local_omp_cpuid != "all":
210-
torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
210+
ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
211+
logger.info(ret)
211212

212213
self.init_distributed_environment()
213214
# Set random seed.

0 commit comments

Comments
 (0)