We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 62effc7 · commit f6b992a — Copy full SHA for f6b992a
vllm/model_executor/models/qwen2_5_vl.py
@@ -198,8 +198,11 @@ def forward(self, x: torch.Tensor):
198
199
def all_gather_interleave(local_tensor, hidden_size: int, tp_size: int):
200
"""All-gather the input tensor in an interleaved fashion across the model parallel group."""
201
+ import torch.distributed as dist
202
gathered_tensors = [torch.zeros_like(local_tensor) for _ in range(tp_size)]
- parallel_state.get_tp_group().all_gather(gathered_tensors, local_tensor)
203
+ dist.all_gather(gathered_tensors,
204
+ local_tensor,
205
+ group=parallel_state.get_tp_group().device_group)
206
207
gathered_tensors_split = [
208
torch.split(tensor, hidden_size // tp_size, -1)
0 commit comments