[Core][Optimization] remove vllm-nccl (vllm-project#5091)
youkaichao authored May 29, 2024
1 parent 57905a6 commit 41b73e5
Showing 7 changed files with 21 additions and 100 deletions.
1 change: 0 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -37,7 +37,6 @@ steps:
   working_dir: "/vllm-workspace/tests"
   num_gpus: 2
   commands:
-  - pytest -v -s distributed/test_pynccl_library.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
1 change: 0 additions & 1 deletion requirements-cuda.txt
@@ -4,7 +4,6 @@
 # Dependencies for NVIDIA GPUs
 ray >= 2.9
 nvidia-ml-py # for pynvml package
-vllm-nccl-cu12>=2.18,<2.19  # for downloading nccl library
 torch == 2.3.0
 xformers == 0.0.26.post1  # Requires PyTorch 2.3.0
 vllm-flash-attn == 2.5.8.post2  # Requires PyTorch 2.3.0
7 changes: 2 additions & 5 deletions setup.py
@@ -358,11 +358,8 @@ def _read_requirements(filename: str) -> List[str]:
         cuda_major, cuda_minor = torch.version.cuda.split(".")
         modified_requirements = []
         for req in requirements:
-            if "vllm-nccl-cu12" in req:
-                req = req.replace("vllm-nccl-cu12",
-                                  f"vllm-nccl-cu{cuda_major}")
-            elif ("vllm-flash-attn" in req
-                  and not (cuda_major == "12" and cuda_minor == "1")):
+            if ("vllm-flash-attn" in req
+                    and not (cuda_major == "12" and cuda_minor == "1")):
                 # vllm-flash-attn is built only for CUDA 12.1.
                 # Skip for other versions.
                 continue
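For reference, a minimal runnable sketch (not part of this commit) of the filtering pattern that remains in setup.py, with a made-up requirements list and a hard-coded version standing in for torch.version.cuda:

# Hypothetical inputs for illustration only.
requirements = ["ray >= 2.9", "torch == 2.3.0", "vllm-flash-attn == 2.5.8.post2"]
cuda_major, cuda_minor = "11", "8"  # pretend torch was built for CUDA 11.8

modified_requirements = []
for req in requirements:
    if ("vllm-flash-attn" in req
            and not (cuda_major == "12" and cuda_minor == "1")):
        # vllm-flash-attn is built only for CUDA 12.1; skip elsewhere.
        continue
    modified_requirements.append(req)

print(modified_requirements)  # ['ray >= 2.9', 'torch == 2.3.0']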
43 changes: 0 additions & 43 deletions tests/distributed/test_pynccl_library.py

This file was deleted.

20 changes: 7 additions & 13 deletions vllm/distributed/device_communicators/pynccl_wrapper.py
@@ -28,7 +28,7 @@
 from torch.distributed import ReduceOp
 
 from vllm.logger import init_logger
-from vllm.utils import find_nccl_library, nccl_integrity_check
+from vllm.utils import find_nccl_library
 
 logger = init_logger(__name__)
 
@@ -188,28 +188,22 @@ def __init__(self, so_file: Optional[str] = None):
         so_file = so_file or find_nccl_library()
 
         try:
-            # load the library in another process.
-            # if it core dumps, it will not crash the current process
-            nccl_integrity_check(so_file)
+            if so_file not in NCCLLibrary.path_to_dict_mapping:
+                lib = ctypes.CDLL(so_file)
+                NCCLLibrary.path_to_library_cache[so_file] = lib
+            self.lib = NCCLLibrary.path_to_library_cache[so_file]
         except Exception as e:
             logger.error(
                 "Failed to load NCCL library from %s ."
                 "It is expected if you are not running on NVIDIA/AMD GPUs."
                 "Otherwise, the nccl library might not exist, be corrupted "
                 "or it does not support the current platform %s."
-                "One solution is to download libnccl2 version 2.18 from "
-                "https://developer.download.nvidia.com/compute/cuda/repos/ "
-                "and extract the libnccl.so.2 file. If you already have the "
-                "library, please set the environment variable VLLM_NCCL_SO_PATH"
+                "If you already have the library, please set the "
+                "environment variable VLLM_NCCL_SO_PATH"
                 " to point to the correct nccl library path.", so_file,
                 platform.platform())
             raise e
 
-        if so_file not in NCCLLibrary.path_to_dict_mapping:
-            lib = ctypes.CDLL(so_file)
-            NCCLLibrary.path_to_library_cache[so_file] = lib
-        self.lib = NCCLLibrary.path_to_library_cache[so_file]
-
         if so_file not in NCCLLibrary.path_to_dict_mapping:
             _funcs = {}
             for func in NCCLLibrary.exported_functions:
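For context, a minimal standalone sketch (not from this repository) of what loading NCCL through ctypes looks like, reusing the ncclGetVersion signature that also appeared in the removed integrity check; it assumes a machine where libnccl.so.2 is resolvable by the dynamic loader:

import ctypes

# Raises OSError if the library is missing or cannot be loaded.
nccl = ctypes.CDLL("libnccl.so.2")

# int ncclGetVersion(int *version); returns 0 (ncclSuccess) on success.
nccl.ncclGetVersion.restype = ctypes.c_int
nccl.ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]

version = ctypes.c_int()
assert nccl.ncclGetVersion(ctypes.byref(version)) == 0
print(version.value)  # e.g. 21806 for NCCL 2.18.6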
43 changes: 8 additions & 35 deletions vllm/utils.py
@@ -2,7 +2,6 @@
 import datetime
 import enum
 import gc
-import glob
 import os
 import socket
 import subprocess
@@ -565,28 +564,6 @@ def init_cached_hf_modules():
     init_hf_modules()
 
 
-def nccl_integrity_check(filepath):
-    """
-    when the library is corrupted, we cannot catch
-    the exception in python. it will crash the process.
-    instead, we use the exit code of `ldd` to check
-    if the library is corrupted. if not, we will return
-    the version of the library.
-    """
-    exit_code = os.system(f"ldd {filepath} 2>&1 > /dev/null")
-    if exit_code != 0:
-        raise RuntimeError(f"Failed to load NCCL library from {filepath} .")
-    import ctypes
-
-    nccl = ctypes.CDLL(filepath)
-    version = ctypes.c_int()
-    nccl.ncclGetVersion.restype = ctypes.c_int
-    nccl.ncclGetVersion.argtypes = [ctypes.POINTER(ctypes.c_int)]
-    result = nccl.ncclGetVersion(ctypes.byref(version))
-    assert result == 0
-    return version.value
-
-
 @lru_cache(maxsize=None)
 def find_library(lib_name: str) -> str:
     """
@@ -616,17 +593,13 @@ def find_library(lib_name: str) -> str:
 
 
 def find_nccl_library():
+    """
+    We either use the library file specified by the `VLLM_NCCL_SO_PATH`
+    environment variable, or we find the library file brought by PyTorch.
+    After importing `torch`, `libnccl.so.2` or `librccl.so.1` can be
+    found by `ctypes` automatically.
+    """
     so_file = envs.VLLM_NCCL_SO_PATH
-    VLLM_CONFIG_ROOT = envs.VLLM_CONFIG_ROOT
-
-    # check if we have vllm-managed nccl
-    vllm_nccl_path = None
-    if torch.version.cuda is not None:
-        cuda_major = torch.version.cuda.split(".")[0]
-        path = os.path.expanduser(
-            f"{VLLM_CONFIG_ROOT}/vllm/nccl/cu{cuda_major}/libnccl.so.*")
-        files = glob.glob(path)
-        vllm_nccl_path = files[0] if files else None
 
     # manually load the nccl library
     if so_file:
@@ -635,9 +608,9 @@ def find_nccl_library():
                     so_file)
     else:
         if torch.version.cuda is not None:
-            so_file = vllm_nccl_path or find_library("libnccl.so.2")
+            so_file = "libnccl.so.2"
         elif torch.version.hip is not None:
-            so_file = find_library("librccl.so.1")
+            so_file = "librccl.so.1"
         else:
             raise ValueError("NCCL only supports CUDA and ROCm backends.")
     logger.info("Found nccl from library %s", so_file)
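With the vllm-managed copy gone, find_nccl_library returns a bare soname and relies on the dynamic loader to resolve it against the NCCL/RCCL that PyTorch ships. A minimal sketch of that resolution, assuming a CUDA or ROCm build of PyTorch is installed:

import ctypes

import torch  # importing torch makes its bundled NCCL/RCCL resolvable

if torch.version.cuda is not None:
    lib = ctypes.CDLL("libnccl.so.2")  # found by the loader, no absolute path
elif torch.version.hip is not None:
    lib = ctypes.CDLL("librccl.so.1")
else:
    raise ValueError("NCCL only supports CUDA and ROCm backends.")
print(lib._name)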
6 changes: 4 additions & 2 deletions vllm/worker/worker_base.py
@@ -121,12 +121,14 @@ def update_environment_variables(envs: Dict[str, str]) -> None:
 
     def init_worker(self, *args, **kwargs):
         """
-        Actual initialization of the worker class, and set up
-        function tracing if required.
+        Here we inject some common logic before initializing the worker.
+        Arguments are passed to the worker class constructor.
         """
         enable_trace_function_call_for_thread()
 
+        # see https://github.com/NVIDIA/nccl/issues/1234
+        os.environ['NCCL_CUMEM_ENABLE'] = '0'
 
         mod = importlib.import_module(self.worker_module_name)
         worker_class = getattr(mod, self.worker_class_name)
         self.worker = worker_class(*args, **kwargs)
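Exporting NCCL_CUMEM_ENABLE before the worker is constructed matters because NCCL reads the variable when a communicator is created. Below is a minimal runnable sketch of the same export-then-dynamic-import pattern; the module and class names are stdlib stand-ins, not the ones vLLM actually passes:

import importlib
import os

# Must be set before anything creates an NCCL communicator;
# see https://github.com/NVIDIA/nccl/issues/1234
os.environ['NCCL_CUMEM_ENABLE'] = '0'

# Same importlib + getattr pattern as init_worker(), with stand-in names.
worker_module_name = "collections"
worker_class_name = "OrderedDict"

mod = importlib.import_module(worker_module_name)
worker_class = getattr(mod, worker_class_name)
worker = worker_class()  # init_worker() forwards *args/**kwargs here
print(type(worker))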
