import torch

from tests.nm_utils.utils_skip import should_skip_test_group
-from tests.utils import (init_test_distributed_environment,
-                         multi_process_tensor_parallel)
-from vllm.distributed import (broadcast_tensor_dict,
+from vllm.distributed import (broadcast_tensor_dict, get_pp_group,
                              tensor_model_parallel_all_gather,
                              tensor_model_parallel_all_reduce)

+from ..utils import init_test_distributed_environment, multi_process_parallel
+
if should_skip_test_group(group_name="TEST_DISTRIBUTED"):
    pytest.skip("TEST_DISTRIBUTED=DISABLE, skipping distributed test group",
                allow_module_level=True)
@@ -109,6 +109,68 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
    assert torch.allclose(recv_dict["f"], test_dict["f"])


+@ray.remote(num_gpus=1, max_calls=1)
+def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
+                                      distributed_init_port: str):
+    del os.environ["CUDA_VISIBLE_DEVICES"]
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    init_test_distributed_environment(tp_size, pp_size, rank,
+                                      distributed_init_port)
+
+    test_dict = {
+        # device tensor
+        "a": torch.arange(8, dtype=torch.float32, device="cuda"),
+        # CPU tensor
+        "b": torch.arange(16, dtype=torch.int8, device="cpu"),
+        "c": "test",
+        "d": [1, 2, 3],
+        "e": {
+            "a": 1,
+            "b": 2
+        },
+        # empty tensor
+        "f": torch.tensor([], dtype=torch.float32, device="cuda"),
+    }
+
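+    # Every rank except the first receives a dict, every rank except the last
+    # sends one; the receiving side then checks the contents arrived intact.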
+    if not get_pp_group().is_first_rank:
+        recv_dict = get_pp_group().recv_tensor_dict()
+
+    if not get_pp_group().is_last_rank:
+        get_pp_group().send_tensor_dict(test_dict)
+
+    if not get_pp_group().is_first_rank:
+        assert len(recv_dict) == len(test_dict)
+        assert torch.allclose(recv_dict["a"], test_dict["a"])
+        assert torch.allclose(recv_dict["b"], test_dict["b"])
+        assert recv_dict["c"] == test_dict["c"]
+        assert recv_dict["d"] == test_dict["d"]
+        assert recv_dict["e"] == test_dict["e"]
+        assert torch.allclose(recv_dict["f"], test_dict["f"])
+
+
+@ray.remote(num_gpus=1, max_calls=1)
+def send_recv_test_worker(tp_size: int, pp_size: int, rank: int,
+                          distributed_init_port: str):
+    del os.environ["CUDA_VISIBLE_DEVICES"]
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    init_test_distributed_environment(tp_size, pp_size, rank,
+                                      distributed_init_port)
+
+    size = 64
+    test_tensor = torch.arange(64, dtype=torch.float32, device="cuda")
+
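+    # Same pattern as the dict test above, but with a single flat tensor:
+    # non-first ranks receive, non-last ranks send, and receivers verify it.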
+    if not get_pp_group().is_first_rank:
+        recv_tensor = get_pp_group().recv(size, dtype=torch.float32)
+
+    if not get_pp_group().is_last_rank:
+        get_pp_group().send(test_tensor)
+
+    if not get_pp_group().is_first_rank:
+        assert torch.allclose(test_tensor, recv_tensor)
+
+
@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("tp_size", [2])
@@ -117,4 +179,13 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
    broadcast_tensor_dict_test_worker
])
def test_multi_process_tensor_parallel(tp_size, test_target):
-    multi_process_tensor_parallel(tp_size, 1, test_target)
+    multi_process_parallel(tp_size, 1, test_target)
+
+
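+# Pipeline-parallel counterpart of the test above: two ranks (pp_size=2)
+# exercising the point-to-point send/recv workers.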
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+@pytest.mark.parametrize("pp_size", [2])
+@pytest.mark.parametrize(
+    "test_target", [send_recv_test_worker, send_recv_tensor_dict_test_worker])
+def test_multi_process_pipeline_parallel(pp_size, test_target):
+    multi_process_parallel(1, pp_size, test_target)