1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
@@ -510,6 +510,7 @@ steps:
# NOTE: don't test llama model here, it seems hf implementation is buggy
# see https://github.com/vllm-project/vllm/pull/5689 for details
- pytest -v -s distributed/test_custom_all_reduce.py
- torchrun --nproc_per_node=2 distributed/test_ca_buffer_sharing.py
- TARGET_TEST_SUITE=A100 pytest basic_correctness/ -v -s -m distributed_2_gpus
- pytest -v -s -x lora/test_mixtral.py

59 changes: 59 additions & 0 deletions tests/distributed/test_ca_buffer_sharing.py
@@ -0,0 +1,59 @@
# can only run on machines with p2p access across GPUs
# can only run with torchrun:
# torchrun --nproc_per_node=2 tests/distributed/test_ca_buffer_sharing.py

import ctypes

import torch
import torch.distributed as dist

from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from vllm.distributed.device_communicators.custom_all_reduce import ( # noqa
    CustomAllreduce)

# create a cpu process group for communicating metadata (ipc handle)
dist.init_process_group(backend="gloo")
rank = local_rank = dist.get_rank()
world_size = dist.get_world_size()

# every process sets its own device (differently)
lib = CudaRTLibrary()
lib.cudaSetDevice(rank)

buffer_size_in_bytes = 1024
byte_value = 2 # the value we write to the buffer for verification

pointers = CustomAllreduce.create_shared_buffer(buffer_size_in_bytes)

print(f"Rank {rank} has pointers {pointers}")

dist.barrier()
torch.cuda.synchronize()

if rank == 0:
    # the first rank tries to write to all buffers
    for p in pointers:
        pointer = ctypes.c_void_p(p)
        lib.cudaMemset(pointer, byte_value, buffer_size_in_bytes)

dist.barrier()
torch.cuda.synchronize()

host_data = (ctypes.c_char * buffer_size_in_bytes)()

# all ranks read from all buffers, and check if the data is correct
for p in pointers:
    pointer = ctypes.c_void_p(p)
    lib.cudaMemcpy(host_data, pointer, buffer_size_in_bytes)
    for i in range(buffer_size_in_bytes):
        assert ord(host_data[i]) == byte_value, (
            f"Rank {rank} failed"
            f" to verify buffer {p}. Expected {byte_value}, "
            f"got {ord(host_data[i])}")

print(f"Rank {rank} verified all buffers")

dist.barrier()
torch.cuda.synchronize()

CustomAllreduce.free_shared_buffer(pointers)
31 changes: 31 additions & 0 deletions vllm/distributed/device_communicators/custom_all_reduce.py
@@ -1,3 +1,4 @@
import ctypes
from contextlib import contextmanager
from typing import Any, List, Optional, Union

@@ -7,6 +8,7 @@

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
from vllm.distributed.device_communicators.custom_all_reduce_utils import (
    gpu_p2p_access_check)
from vllm.distributed.parallel_state import in_the_same_node_as
@@ -174,6 +176,35 @@ def __init__(self,
                                       offsets, rank, self.full_nvlink)
        self.register_buffer(self.buffer)

    @staticmethod
    def create_shared_buffer(
            size_in_bytes: int,
            group: Optional[ProcessGroup] = None) -> List[int]:
        lib = CudaRTLibrary()
        pointer = lib.cudaMalloc(size_in_bytes)
        handle = lib.cudaIpcGetMemHandle(pointer)
        world_size = dist.get_world_size(group=group)
        rank = dist.get_rank(group=group)
        handles = [None] * world_size
        dist.all_gather_object(handles, handle, group=group)
Contributor:
Why do we need to use broadcast with device=cpu in _gather_ipc_meta but not here?

Member Author:
When you use this function, the group argument should be the cpu_group that is passed to the CustomAllreduce object.

Member Author:
See:

if use_custom_allreduce and self.world_size > 1:
    # Initialize a custom fast all-reduce implementation.
    self.ca_comm = CustomAllreduce(
        group=self.cpu_group,
        device=self.device,
    )

Contributor:
Why is all_gather fine here but not in _gather_ipc_meta?

Member Author:
Oh, that is because we ran into some issues when calling all_gather on tensors directly. Here we are using all_gather_object, so it should be fine; see pytorch/pytorch#126032 for the PyTorch issue.
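
As a minimal illustration of the distinction (not part of this PR; the variable names are made up), all_gather_object exchanges arbitrary picklable Python objects over the CPU/gloo group, while all_gather is a tensor collective, which is the path that previously hit issues:

import torch
import torch.distributed as dist

# Assumes dist.init_process_group(backend="gloo") has already been called.
world_size = dist.get_world_size()
handle = b"\x00" * 64  # stand-in for a raw IPC handle (any picklable object)

# all_gather_object: exchanges picklable objects over the gloo group,
# which is what create_shared_buffer relies on.
gathered = [None] * world_size
dist.all_gather_object(gathered, handle)

# all_gather: a tensor collective; every rank must supply same-shape tensors.
t = torch.frombuffer(bytearray(handle), dtype=torch.uint8)
outs = [torch.empty_like(t) for _ in range(world_size)]
dist.all_gather(outs, t)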

Contributor:
I see. Someone refactored this to return a Tensor
https://github.com/vllm-project/vllm/pull/5047/files#diff-44d9d733ee604800cbce9858a9201db1044aeff2c940fa4a0521d0c9b6541b3eL137

A better way would be to return a string, if the torch bindings don't support int8 directly.

Member Author:
Yeah, a string should be fine. #5047 aims to get rid of pybind11 so that we can release Python-version-agnostic wheels.
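
A purely hypothetical sketch of the "return a string" idea (these helpers are illustrative, not vllm APIs): the raw bytes of a ctypes handle such as cudaIpcMemHandle_t could be serialized to a hex string before being exchanged, sidestepping tensor dtype limitations.

import ctypes

def handle_to_hex(handle: ctypes.Structure) -> str:
    # Copy the structure's raw bytes and encode them as a hex string.
    raw = ctypes.string_at(ctypes.addressof(handle), ctypes.sizeof(handle))
    return raw.hex()

def hex_to_handle(s: str, handle_type: type) -> ctypes.Structure:
    # Rebuild the ctypes structure from the hex string.
    return handle_type.from_buffer_copy(bytes.fromhex(s))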


        pointers: List[int] = []
        for i, h in enumerate(handles):
            if i == rank:
                pointers.append(pointer.value)  # type: ignore
            else:
                pointers.append(
                    lib.cudaIpcOpenMemHandle(h).value)  # type: ignore

        return pointers

    @staticmethod
    def free_shared_buffer(pointers: List[int],
                           group: Optional[ProcessGroup] = None) -> None:
        rank = dist.get_rank(group=group)
        lib = CudaRTLibrary()
        lib.cudaFree(ctypes.c_void_p(pointers[rank]))

    @contextmanager
    def capture(self):
        """