
Commit e08e665

Author: Muralidhar Andoorveedu (committed)

Add send and recv helpers

Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>

1 parent 3f3b6b2, commit e08e665

6 files changed: 359 additions & 8 deletions

tests/distributed/test_comm_ops.py

Lines changed: 53 additions & 3 deletions

@@ -9,11 +9,13 @@
 import torch
 
 from vllm.distributed import (broadcast_tensor_dict,
+                              is_pipeline_model_parallel_first_rank,
+                              is_pipeline_model_parallel_last_rank,
+                              recv_tensor_dict, send_tensor_dict,
                               tensor_model_parallel_all_gather,
                               tensor_model_parallel_all_reduce)
 
-from ..utils import (init_test_distributed_environment,
-                     multi_process_tensor_parallel)
+from ..utils import init_test_distributed_environment, multi_process_parallel
 
 
 @ray.remote(num_gpus=1, max_calls=1)
@@ -105,6 +107,46 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
     assert torch.allclose(recv_dict["f"], test_dict["f"])
 
 
+@ray.remote(num_gpus=1, max_calls=1)
+def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
+                                      distributed_init_port: str):
+    del os.environ["CUDA_VISIBLE_DEVICES"]
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    init_test_distributed_environment(tp_size, pp_size, rank,
+                                      distributed_init_port)
+
+    test_dict = {
+        # device tensor
+        "a": torch.arange(8, dtype=torch.float32, device="cuda"),
+        # CPU tensor
+        "b": torch.arange(16, dtype=torch.int8, device="cpu"),
+        "c": "test",
+        "d": [1, 2, 3],
+        "e": {
+            "a": 1,
+            "b": 2
+        },
+        # empty tensor
+        "f": torch.tensor([], dtype=torch.float32, device="cuda"),
+    }
+
+    if not is_pipeline_model_parallel_first_rank():
+        recv_dict = recv_tensor_dict()
+
+    if not is_pipeline_model_parallel_last_rank():
+        send_tensor_dict(test_dict)
+
+    if not is_pipeline_model_parallel_first_rank():
+        assert len(recv_dict) == len(test_dict)
+        assert torch.allclose(recv_dict["a"], test_dict["a"])
+        assert torch.allclose(recv_dict["b"], test_dict["b"])
+        assert recv_dict["c"] == test_dict["c"]
+        assert recv_dict["d"] == test_dict["d"]
+        assert recv_dict["e"] == test_dict["e"]
+        assert torch.allclose(recv_dict["f"], test_dict["f"])
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
                     reason="Need at least 2 GPUs to run the test.")
 @pytest.mark.parametrize("tp_size", [2])
@@ -113,4 +155,12 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
     broadcast_tensor_dict_test_worker
 ])
 def test_multi_process_tensor_parallel(tp_size, test_target):
-    multi_process_tensor_parallel(tp_size, 1, test_target)
+    multi_process_parallel(tp_size, 1, test_target)
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+@pytest.mark.parametrize("pp_size", [2])
+@pytest.mark.parametrize("test_target", [send_recv_tensor_dict_test_worker])
+def test_multi_process_pipeline_parallel(pp_size, test_target):
+    multi_process_parallel(1, pp_size, test_target)
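
With pp_size=2, the new worker degenerates to the first rank sending the dict and the last rank receiving it and checking every entry, including the non-tensor values. For context, here is a minimal sketch (not part of this commit; distributed initialization omitted, shapes illustrative) of how the same first/last-rank guards chain an arbitrary number of pipeline stages:

import torch

from vllm.distributed import (is_pipeline_model_parallel_first_rank,
                              is_pipeline_model_parallel_last_rank,
                              recv_tensor_dict, send_tensor_dict)


def pipeline_stage_step() -> None:
    # Every stage except the first blocks until the previous stage's dict
    # arrives; the first stage builds its inputs locally instead.
    if is_pipeline_model_parallel_first_rank():
        tensors = {"hidden_states": torch.randn(4, 8, device="cuda")}
    else:
        tensors = recv_tensor_dict()

    # ... this stage's computation on `tensors` would go here ...

    # Every stage except the last forwards its result to the next stage.
    if not is_pipeline_model_parallel_last_rank():
        send_tensor_dict(tensors)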

tests/distributed/test_custom_all_reduce.py

Lines changed: 2 additions & 3 deletions

@@ -12,8 +12,7 @@
                               get_tp_group, graph_capture)
 
 from ..utils import (ensure_model_parallel_initialized,
-                     init_test_distributed_environment,
-                     multi_process_tensor_parallel)
+                     init_test_distributed_environment, multi_process_parallel)
 
 random.seed(42)
 test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)]
@@ -113,4 +112,4 @@ def test_custom_allreduce(tp_size, pipeline_parallel_size, test_target):
     world_size = tp_size * pipeline_parallel_size
     if world_size > torch.cuda.device_count():
         pytest.skip("Not enough GPUs to run the test.")
-    multi_process_tensor_parallel(tp_size, pipeline_parallel_size, test_target)
+    multi_process_parallel(tp_size, pipeline_parallel_size, test_target)

tests/utils.py

Lines changed: 1 addition & 1 deletion

@@ -129,7 +129,7 @@ def init_test_distributed_environment(
     ensure_model_parallel_initialized(tp_size, pp_size)
 
 
-def multi_process_tensor_parallel(
+def multi_process_parallel(
     tp_size: int,
     pp_size: int,
     test_target,

vllm/distributed/communication_op.py

Lines changed: 28 additions & 1 deletion

@@ -3,7 +3,7 @@
 import torch
 import torch.distributed
 
-from .parallel_state import get_tp_group
+from .parallel_state import get_pp_group, get_tp_group
 
 
 def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
@@ -30,3 +30,30 @@ def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
     if not torch.distributed.is_initialized():
         return tensor_dict
     return get_tp_group().broadcast_tensor_dict(tensor_dict, src)
+
+
+def send_tensor_dict(tensors: Dict[str, torch.Tensor],
+                     dst: Optional[int] = None) -> None:
+    """
+    Send the tensors to the next pipeline model parallel rank.
+    Args:
+        tensors (Dict[torch.Tensor]): Dict of tensors to send.
+    """
+    if dst is None:
+        dst = get_pp_group().next_rank
+    get_pp_group().send_tensor_dict(tensors, dst)
+
+
+def recv_tensor_dict(
+    src: Optional[int] = None
+) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
+    """
+    Receive tensors from the previous pipeline model parallel rank assuming all
+    tensors are the same size.
+    Returns:
+        Dict[torch.Tensor]: Dict of received tensors.
+    """
+    if src is None:
+        src = get_pp_group().prev_rank
+    tensors = get_pp_group().recv_tensor_dict(src)
+    return tensors
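
A short usage sketch of the two new helpers (illustrative, not from the commit): with dst and src omitted they fall back to get_pp_group().next_rank and get_pp_group().prev_rank, so a matched send/recv pair on adjacent pipeline ranks needs no explicit rank arguments. The sketch assumes exactly two pipeline stages and an already-initialized distributed environment; keys and shapes are made up, and as the test above shows, non-tensor values round-trip alongside the tensors.

import torch

from vllm.distributed import (is_pipeline_model_parallel_last_rank,
                              recv_tensor_dict, send_tensor_dict)

payload = {
    "hidden_states": torch.ones(2, 4, device="cuda"),
    "metadata": {"seq_len": 4},
}

if not is_pipeline_model_parallel_last_rank():
    # First of the two stages: dst defaults to the next pipeline rank.
    send_tensor_dict(payload)
else:
    # Second stage: src defaults to the previous pipeline rank; this call
    # must be paired with the send above or both ranks will block.
    received = recv_tensor_dict()
    assert torch.allclose(received["hidden_states"], payload["hidden_states"])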

vllm/distributed/object_list_ops.py

Lines changed: 123 additions & 0 deletions

@@ -0,0 +1,123 @@
+"""
+This file is necessary until new version of torch.distributed is released with
+https://github.com/pytorch/pytorch/commit/b96b1e8cff029bb0a73283e6e7f6cc240313f1dc
+"""
+import torch
+import torch.distributed as dist
+from torch.distributed.distributed_c10d import (_get_pg_default_device,
+                                                _object_to_tensor,
+                                                _tensor_to_object)
+
+
+def send_object_list(object_list, dst, group=None, device=None):
+    """
+    Sends picklable objects in ``object_list`` synchronously.
+
+    Similar to :func:`send`, but Python objects can be passed in.
+    Note that all objects in ``object_list`` must be picklable in order to be
+    sent.
+
+    Args:
+        object_list (List[Any]): List of input objects to sent.
+            Each object must be picklable. Receiver must provide lists of
+            equal sizes.
+        dst (int): Destination rank to send ``object_list`` to.
+            Destination rank is based on global process group
+            (regardless of ``group`` argument)
+        group: (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used. Default is ``None``.
+        device (``torch.device``, optional): If not None, the objects are
+            serialized and converted to tensors which are moved to the
+            ``device`` before sending. Default is ``None``.
+
+    Returns:
+        ``None``.
+    """
+    if dist.get_rank() == dst:
+        raise ValueError(
+            "Invalid destination rank: destination rank should not be the "
+            "same as the rank of the current process.")
+
+    # Current device selection.
+    # To preserve backwards compatibility, ``device`` is default to ``None``
+    # in which case we run current logic of device selection, i.e.
+    # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the
+    # case it is not ``None`` we move the size and object tensors to be
+    # sent to this device.
+    current_device = device or _get_pg_default_device(group)
+    # Serialize object_list elements to tensors on src rank.
+    tensor_list, size_list = zip(
+        *
+        [_object_to_tensor(obj, current_device, group) for obj in object_list])
+    object_sizes_tensor = torch.cat(size_list)
+
+    # Send object sizes
+    dist.send(object_sizes_tensor, dst=dst, group=group)
+
+    # Concatenate and send serialized object tensors
+    # Note: torch.cat will do an extra memory copy to the current device,
+    # if the tensor_list has only one element, we can skip the copy.
+    if len(tensor_list) == 1:  # type: ignore[possibly-undefined]
+        object_tensor = tensor_list[0]
+    else:
+        object_tensor = torch.cat(tensor_list)
+
+    dist.send(object_tensor, dst=dst, group=group)
+
+
+def recv_object_list(object_list, src=None, group=None, device=None):
+    """
+    Receives picklable objects in ``object_list`` synchronously.
+
+    Similar to :func:`recv`, but can receive Python objects.
+
+    Args:
+        object_list (List[Any]): List of objects to receive into.
+            Must provide a list of sizes equal to the size of the list
+            being sent.
+        src (int, optional): Source rank from which to recv ``object_list``.
+            Source rank is based on global process group
+            (regardless of ``group`` argument)
+            Will receive from any rank if set to None. Default is ``None``.
+        group: (ProcessGroup, optional): The process group to work on. If None,
+            the default process group will be used. Default is ``None``.
+        device (``torch.device``, optional): If not None, receives on
+            this device. Default is ``None``.
+
+    Returns:
+        Sender rank. -1 if rank is not part of the group. If rank is part
+        of the group, ``object_list`` will contain the sent objects from
+        ``src`` rank.
+    """
+
+    # Current device selection.
+    # To preserve backwards compatibility, ``device`` is default to ``None``
+    # in which case we run current logic of device selection, i.e.
+    # ``current_device`` is CUDA if backend is NCCL otherwise CPU device. In the
+    # case it is not ``None`` we move the size and object tensors to be
+    # received to this device.
+    current_device = device or _get_pg_default_device(group)
+    object_sizes_tensor = torch.empty(len(object_list),
+                                      dtype=torch.long,
+                                      device=current_device)
+
+    # Receive object sizes
+    rank_sizes = dist.recv(object_sizes_tensor, src=src, group=group)
+
+    # Tensor to receive serialized objects into.
+    object_tensor = torch.empty(  # type: ignore[call-overload]
+        torch.sum(object_sizes_tensor).item(),  # type: ignore[arg-type]
+        dtype=torch.uint8,
+        device=current_device)
+
+    rank_objects = dist.recv(object_tensor, src=src, group=group)
+    assert (rank_sizes == rank_objects
+            ), "Mismatch in return ranks for object sizes and objects."
+    # Deserialize objects using their stored sizes.
+    offset = 0
+    for i, obj_size in enumerate(object_sizes_tensor):
+        obj_view = object_tensor[offset:offset + obj_size]
+        obj_view = obj_view.type(torch.uint8)
+        offset += obj_size
+        object_list[i] = _tensor_to_object(obj_view, obj_size, group)
+    return rank_objects
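
A minimal usage sketch for these two helpers (not part of the commit): rank 0 sends a small list of picklable objects to rank 1, which must pre-allocate a list of the same length to receive into. It assumes torch.distributed has already been initialized with world_size=2 and a backend that supports point-to-point send/recv (e.g. "gloo", or "nccl" with per-rank devices set).

import torch.distributed as dist

from vllm.distributed.object_list_ops import (recv_object_list,
                                              send_object_list)

objects = [{"layer": 3}, "done", [1, 2, 3]]

if dist.get_rank() == 0:
    # Serializes each object, sends the element sizes, then one payload tensor.
    send_object_list(objects, dst=1)
elif dist.get_rank() == 1:
    # The receiver provides a same-length list for the objects to land in.
    received = [None] * len(objects)
    recv_object_list(received, src=0)
    assert received == objects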
