This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit 54348ea

andoorve authored and Robert Shaw committed
[Distributed] Add send and recv helpers (vllm-project#5719)
1 parent 569c905 commit 54348ea

6 files changed, +279 -24 lines changed


tests/distributed/test_comm_ops.py

Lines changed: 75 additions & 4 deletions
@@ -9,12 +9,12 @@
 import torch
 
 from tests.nm_utils.utils_skip import should_skip_test_group
-from tests.utils import (init_test_distributed_environment,
-                         multi_process_tensor_parallel)
-from vllm.distributed import (broadcast_tensor_dict,
+from vllm.distributed import (broadcast_tensor_dict, get_pp_group,
                               tensor_model_parallel_all_gather,
                               tensor_model_parallel_all_reduce)
 
+from ..utils import init_test_distributed_environment, multi_process_parallel
+
 if should_skip_test_group(group_name="TEST_DISTRIBUTED"):
     pytest.skip("TEST_DISTRIBUTED=DISABLE, skipping distributed test group",
                 allow_module_level=True)
@@ -109,6 +109,68 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
     assert torch.allclose(recv_dict["f"], test_dict["f"])
 
 
+@ray.remote(num_gpus=1, max_calls=1)
+def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
+                                      distributed_init_port: str):
+    del os.environ["CUDA_VISIBLE_DEVICES"]
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    init_test_distributed_environment(tp_size, pp_size, rank,
+                                      distributed_init_port)
+
+    test_dict = {
+        # device tensor
+        "a": torch.arange(8, dtype=torch.float32, device="cuda"),
+        # CPU tensor
+        "b": torch.arange(16, dtype=torch.int8, device="cpu"),
+        "c": "test",
+        "d": [1, 2, 3],
+        "e": {
+            "a": 1,
+            "b": 2
+        },
+        # empty tensor
+        "f": torch.tensor([], dtype=torch.float32, device="cuda"),
+    }
+
+    if not get_pp_group().is_first_rank:
+        recv_dict = get_pp_group().recv_tensor_dict()
+
+    if not get_pp_group().is_last_rank:
+        get_pp_group().send_tensor_dict(test_dict)
+
+    if not get_pp_group().is_first_rank:
+        assert len(recv_dict) == len(test_dict)
+        assert torch.allclose(recv_dict["a"], test_dict["a"])
+        assert torch.allclose(recv_dict["b"], test_dict["b"])
+        assert recv_dict["c"] == test_dict["c"]
+        assert recv_dict["d"] == test_dict["d"]
+        assert recv_dict["e"] == test_dict["e"]
+        assert torch.allclose(recv_dict["f"], test_dict["f"])
+
+
+@ray.remote(num_gpus=1, max_calls=1)
+def send_recv_test_worker(tp_size: int, pp_size: int, rank: int,
+                          distributed_init_port: str):
+    del os.environ["CUDA_VISIBLE_DEVICES"]
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    init_test_distributed_environment(tp_size, pp_size, rank,
+                                      distributed_init_port)
+
+    size = 64
+    test_tensor = torch.arange(64, dtype=torch.float32, device="cuda")
+
+    if not get_pp_group().is_first_rank:
+        recv_tensor = get_pp_group().recv(size, dtype=torch.float32)
+
+    if not get_pp_group().is_last_rank:
+        get_pp_group().send(test_tensor)
+
+    if not get_pp_group().is_first_rank:
+        assert torch.allclose(test_tensor, recv_tensor)
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
                     reason="Need at least 2 GPUs to run the test.")
 @pytest.mark.parametrize("tp_size", [2])
@@ -117,4 +179,13 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
     broadcast_tensor_dict_test_worker
 ])
 def test_multi_process_tensor_parallel(tp_size, test_target):
-    multi_process_tensor_parallel(tp_size, 1, test_target)
+    multi_process_parallel(tp_size, 1, test_target)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+@pytest.mark.parametrize("pp_size", [2])
+@pytest.mark.parametrize(
+    "test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker])
+def test_multi_process_pipeline_parallel(pp_size, test_target):
+    multi_process_parallel(1, pp_size, test_target)

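The new test workers above double as a usage reference for the helpers this commit adds. A minimal sketch of the point-to-point pattern, assuming the distributed environment and the pipeline-parallel group have already been initialized as in the workers:

import torch
from vllm.distributed import get_pp_group

pp_group = get_pp_group()
hidden = torch.arange(64, dtype=torch.float32, device="cuda")

# Every stage except the first receives from the previous stage.
if not pp_group.is_first_rank:
    hidden = pp_group.recv(64, dtype=torch.float32)

# Every stage except the last forwards to the next stage.
if not pp_group.is_last_rank:
    pp_group.send(hidden)

send_tensor_dict() / recv_tensor_dict() follow the same first-rank/last-rank guards but carry a whole dict of tensors and metadata, as exercised by send_recv_tensor_dict_test_worker.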
tests/distributed/test_custom_all_reduce.py

Lines changed: 2 additions & 3 deletions
@@ -13,8 +13,7 @@
                               get_tp_group, graph_capture)
 
 from ..utils import (ensure_model_parallel_initialized,
-                     init_test_distributed_environment,
-                     multi_process_tensor_parallel)
+                     init_test_distributed_environment, multi_process_parallel)
 
 if should_skip_test_group(group_name="TEST_DISTRIBUTED"):
     pytest.skip("TEST_DISTRIBUTED=DISABLE, skipping distributed test group",
@@ -133,4 +132,4 @@ def test_custom_allreduce(tp_size, pipeline_parallel_size, test_target):
     world_size = tp_size * pipeline_parallel_size
     if world_size > torch.cuda.device_count():
         pytest.skip("Not enough GPUs to run the test.")
-    multi_process_tensor_parallel(tp_size, pipeline_parallel_size, test_target)
+    multi_process_parallel(tp_size, pipeline_parallel_size, test_target)

tests/distributed/test_pynccl.py

Lines changed: 12 additions & 4 deletions
@@ -173,9 +173,13 @@ def send_recv_worker_fn():
                          dtype=torch.float32).cuda(pynccl_comm.rank)
     with pynccl_comm.change_state(enable=True):
         if pynccl_comm.rank == 0:
-            pynccl_comm.send(tensor)
+            pynccl_comm.send(tensor,
+                             dst=(pynccl_comm.rank + 1) %
+                             pynccl_comm.world_size)
         else:
-            pynccl_comm.recv(tensor)
+            pynccl_comm.recv(tensor,
+                             src=(pynccl_comm.rank - 1) %
+                             pynccl_comm.world_size)
     result = tensor.mean().cpu().item()
     assert result == 1
 
@@ -208,9 +212,13 @@ def multiple_send_recv_worker_fn():
                          device=device)
     with pynccl_comm.change_state(enable=True):
         if torch.distributed.get_rank() in [0, 1]:
-            pynccl_comm.send(tensor)
+            pynccl_comm.send(tensor,
+                             dst=(pynccl_comm.rank + 1) %
+                             pynccl_comm.world_size)
         else:
-            pynccl_comm.recv(tensor)
+            pynccl_comm.recv(tensor,
+                             src=(pynccl_comm.rank - 1) %
+                             pynccl_comm.world_size)
     result = tensor.mean().cpu().item()
     if torch.distributed.get_rank() in [0, 2]:
         assert result == 1

tests/utils.py

Lines changed: 1 addition & 1 deletion
@@ -129,7 +129,7 @@ def init_test_distributed_environment(
     ensure_model_parallel_initialized(tp_size, pp_size)
 
 
-def multi_process_tensor_parallel(
+def multi_process_parallel(
     tp_size: int,
     pp_size: int,
     test_target,

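The rename reflects that this helper now spawns workers for any tp_size x pp_size layout rather than tensor parallelism only. As a sketch of the two call patterns the updated tests use, with the hypothetical my_test_worker standing in for one of the ray remotes defined in test_comm_ops.py:

# Tensor-parallel only: 2 TP ranks, 1 PP stage.
multi_process_parallel(2, 1, my_test_worker)

# Pipeline-parallel only: 1 TP rank, 2 PP stages.
multi_process_parallel(1, 2, my_test_worker)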
vllm/distributed/device_communicators/pynccl.py

Lines changed: 2 additions & 12 deletions
@@ -121,36 +121,26 @@ def all_reduce(self,
                             ncclRedOpTypeEnum.from_torch(op), self.comm,
                             cudaStream_t(stream.cuda_stream))
 
-    def send(self,
-             tensor: torch.Tensor,
-             dst: Optional[int] = None,
-             stream=None):
+    def send(self, tensor: torch.Tensor, dst: int, stream=None):
         if self.disabled:
             return
         assert tensor.device == self.device, (
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
             stream = self.stream
-        if dst is None:
-            dst = (self.rank + 1) % self.world_size
         self.nccl.ncclSend(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), dst,
                            self.comm, cudaStream_t(stream.cuda_stream))
 
-    def recv(self,
-             tensor: torch.Tensor,
-             src: Optional[int] = None,
-             stream=None):
+    def recv(self, tensor: torch.Tensor, src: int, stream=None):
         if self.disabled:
             return
         assert tensor.device == self.device, (
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {tensor.device}")
         if stream is None:
             stream = self.stream
-        if src is None:
-            src = (self.rank - 1) % self.world_size
         self.nccl.ncclRecv(buffer_type(tensor.data_ptr()), tensor.numel(),
                            ncclDataTypeEnum.from_torch(tensor.dtype), src,
                            self.comm, cudaStream_t(stream.cuda_stream))

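With the default-neighbor fallback removed, PyNcclCommunicator.send() and recv() no longer guess a peer; the caller names it explicitly. A sketch of the ring-neighbor pattern the updated tests adopt, assuming pynccl_comm is an initialized PyNcclCommunicator and tensor lives on its device:

# Explicit peers: the caller decides which rank to talk to.
next_rank = (pynccl_comm.rank + 1) % pynccl_comm.world_size
prev_rank = (pynccl_comm.rank - 1) % pynccl_comm.world_size

with pynccl_comm.change_state(enable=True):
    if pynccl_comm.rank == 0:
        pynccl_comm.send(tensor, dst=next_rank)
    else:
        pynccl_comm.recv(tensor, src=prev_rank)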