[Core] avoid too many cuda context by caching p2p test #4021

Merged · 15 commits · Apr 14, 2024
53 changes: 21 additions & 32 deletions vllm/distributed/device_communicators/custom_all_reduce.py
@@ -42,19 +42,33 @@ def init_custom_ar() -> None:
" disable_custom_all_reduce=True explicitly.", world_size,
str(_SUPPORTED_WORLD_SIZES))
return
if not _can_p2p(rank, world_size):
num_dev = torch.cuda.device_count()
# note: num dev can be larger than world_size if we're only using
# first few GPUs
if num_dev < world_size:
logger.warn(
"Custom allreduce is disabled because your platform lacks GPU P2P"
" capability or P2P test failed. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly.")
return
"Cannot test GPU P2P because not all GPUs are visible to the "
"current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
" is set.")
return False
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
full_nvlink = _is_full_nvlink(rank, world_size)
if world_size > 2 and not full_nvlink:
logger.warn(
"Custom allreduce is disabled because it's not supported on more"
" than two PCIe-only GPUs. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly.")
return
# test P2P capability
# this is expensive to compute at the first time
# then we cache the result
if not _can_p2p(rank, world_size):
logger.warn(
"Custom allreduce is disabled because your platform lacks GPU P2P"
" capability or P2P test failed. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly.")
return
_CA_HANDLE = CustomAllreduce(rank, world_size, full_nvlink)


@@ -143,40 +157,15 @@ def _is_full_nvlink(rank, world_size):


def _can_p2p(rank: int, world_size: int) -> bool:
-    num_dev = torch.cuda.device_count()
-    # note: num dev can be larger than world_size if we're only using
-    # first few GPUs
-    if num_dev < world_size:
-        logger.warn(
-            "Cannot test GPU P2P because not all GPUs are visible to the "
-            "current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
-            " is set.")
-        return False
+    from vllm.distributed.utils import gpu_p2p_access_check
    for i in range(world_size):
        if i == rank:
            continue
-        if not torch.cuda.can_device_access_peer(rank, i):
-            return False
-        # on some platforms, P2P support might be buggy and we need
-        # additional checks. See also:
-        # https://github.com/vllm-project/vllm/issues/2728
-        if not _can_actually_p2p(rank, i):
+        if not gpu_p2p_access_check(rank, i):
            return False
    return True


-# code partly borrowed from
-# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10
-# License: MIT
-def _can_actually_p2p(idx_a, idx_b):
-    dev_i = f"cuda:{idx_a}"
-    dev_j = f"cuda:{idx_b}"
-    a = torch.randn(5, device=dev_i) + 123.0
-    b = a.to(dev_j)
-    c = b.to(dev_i)
-    return torch.all(a == c)


class CustomAllreduce:

    # max_size: max supported allreduce size
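The warnings in this hunk all point users at disable_custom_all_reduce=True. As an illustrative sketch only (not part of this PR), this is how that flag is typically supplied through vLLM's engine arguments; the model name is a placeholder, and the exact keyword plumbing may differ between vLLM versions.

# Hedged sketch: assumes the LLM entrypoint forwards extra keyword arguments
# to EngineArgs, which defines disable_custom_all_reduce. The model name is a
# placeholder, not something referenced by this PR.
from vllm import LLM

llm = LLM(
    model="<your-model>",            # placeholder
    tensor_parallel_size=2,          # custom allreduce only matters for TP > 1
    disable_custom_all_reduce=True,  # skip the NVLink/P2P probing shown above
)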
9 changes: 9 additions & 0 deletions vllm/distributed/parallel_state.py
@@ -37,6 +37,13 @@
# source rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None

+_LOCAL_RANK = -1


+def get_local_rank():
+    global _LOCAL_RANK
+    return _LOCAL_RANK


def init_distributed_environment(
world_size: int,
@@ -60,6 +67,8 @@ def init_distributed_environment(
ranks = list(range(torch.distributed.get_world_size()))
_CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks,
backend="gloo")
+    global _LOCAL_RANK
+    _LOCAL_RANK = local_rank


def initialize_model_parallel(
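For context, get_local_rank() exists so that exactly one process per node (local rank 0) can do node-local work such as writing the P2P cache file introduced below. A minimal sketch of that gating pattern follows; write_node_local_file is a hypothetical helper for illustration only, and only the gating logic mirrors this PR.

# Minimal sketch of the "local master writes, everyone else waits" pattern
# that the new get_local_rank() enables. Helper name and payload are
# hypothetical.
import torch.distributed as dist

from vllm.distributed.parallel_state import get_local_rank


def write_node_local_file(path: str, payload: str) -> None:
    if not dist.is_initialized() or get_local_rank() == 0:
        # exactly one process per node (local rank 0) writes the file
        with open(path, "w") as f:
            f.write(payload)
    if dist.is_initialized():
        # all processes synchronize, so readers proceed only after the
        # file exists
        dist.barrier()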
87 changes: 86 additions & 1 deletion vllm/distributed/utils.py
@@ -2,9 +2,18 @@
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
-from typing import Sequence
+import json
+import os
+from typing import Dict, Optional, Sequence

import torch
+import torch.distributed as dist

+from vllm.logger import init_logger

+from .parallel_state import get_cpu_world_group, get_local_rank

+logger = init_logger(__name__)


def ensure_divisibility(numerator, denominator):
@@ -46,3 +55,79 @@ def split_tensor_along_last_dim(
return tuple(chunk.contiguous() for chunk in tensor_list)

return tensor_list


+# code partly borrowed from
+# https://github.com/turboderp/exllamav2/blob/1c67f97f3d2a968605a9c31ab791a05c85bb7879/exllamav2/compat.py#L10
+# License: MIT
+def _can_actually_p2p(idx_a, idx_b):
+    dev_i = f"cuda:{idx_a}"
+    dev_j = f"cuda:{idx_b}"
+    a = torch.randn(5, device=dev_i) + 123.0
+    b = a.to(dev_j)
+    c = b.to(dev_i)
+    return torch.all(a == c).cpu().item()


+# why do we need this cache?
+# 1. we can have runtime checks for P2P access, where every process checks
+# P2P access to all other GPUs. Unfortunately, the test might cost many
+# (world_size * world_size) CUDA contexts, and reduce the memory available
+# for the model. see https://github.com/vllm-project/vllm/issues/3821
+# 2. alternatively, we can have a p2p map that is generated by the master
+# process and broadcast to all other processes. This still requires
+# #world_size CUDA contexts, belonging to the master process, on each GPU.
+# 3. we can have a cache file that records the p2p access status. The first
+# time the master process checks the p2p access, it will generate the cache
+# file, at the cost of #world_size CUDA contexts. Later on, all processes
+# can read the cache file to check the p2p access status without creating
+# any additional CUDA contexts.
+# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we
+# can have different cache files for different CUDA_VISIBLE_DEVICES settings,
+# e.g. used by different vllm engines. The device id in the cache file is a
+# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number
+# of visible devices in the vllm engine.
+_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None


+def gpu_p2p_access_check(i: int, j: int) -> bool:
+    """Check if GPU i can access GPU j."""
+
+    # if the cache variable is already calculated,
+    # read from the cache instead of checking it again
+    global _gpu_p2p_access_cache
+    if _gpu_p2p_access_cache is not None:
+        return _gpu_p2p_access_cache[f"{i}->{j}"]
+
+    is_distributed = dist.is_initialized()
+
+    num_dev = torch.cuda.device_count()
+    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
+    if cuda_visible_devices is None:
+        cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
+    path = os.path.expanduser(
+        f"~/.config/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json")
+    os.makedirs(os.path.dirname(path), exist_ok=True)
+    if (not is_distributed or get_local_rank() == 0) \
+            and (not os.path.exists(path)):
+        # only the local master process (with local_rank == 0) can
+        # enter this block to calculate the cache
+        logger.info(f"generating GPU P2P access cache in {path}")
+        cache = {}
+        for _i in range(num_dev):
+            for _j in range(num_dev):
+                # on some platforms, P2P support might be buggy and we need
+                # additional checks. See also:
+                # https://github.com/vllm-project/vllm/issues/2728
+                cache[f"{_i}->{_j}"] = torch.cuda.can_device_access_peer(
+                    _i, _j) and _can_actually_p2p(_i, _j)
+        with open(path, "w") as f:
+            json.dump(cache, f, indent=4)
+    if is_distributed:
+        cpu_world_group = get_cpu_world_group()
+        dist.barrier(cpu_world_group)
+    logger.info(f"reading GPU P2P access cache from {path}")
+    with open(path, "r") as f:
+        cache = json.load(f)
+    _gpu_p2p_access_cache = cache
+    return _gpu_p2p_access_cache[f"{i}->{j}"]
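The comment block above describes the cache protocol: the local master process writes the file once, everyone synchronizes on the gloo CPU group, then all processes read it. As an illustration only, not output produced by this PR, here is roughly what that file could contain for a two-GPU node and how a later process would consult it without creating extra CUDA contexts.

# Illustrative only: a plausible cache file for CUDA_VISIBLE_DEVICES=0,1 where
# P2P works in both directions (the real file also contains the self pairs
# "0->0" and "1->1").
#
# ~/.config/vllm/gpu_p2p_access_cache_for_0,1.json
# {
#     "0->1": true,
#     "1->0": true
# }
#
# Any later process with the same CUDA_VISIBLE_DEVICES answers the check
# straight from the file:
from vllm.distributed.utils import gpu_p2p_access_check

if gpu_p2p_access_check(0, 1):
    print("GPU 0 can use direct P2P copies to GPU 1")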