[Bugfix][XPU] Fix xpu tp by introducing XpuCommunicator (vllm-project#10144)

Signed-off-by: yan ma <yan.ma@intel.com>
yma11 authored Nov 8, 2024
1 parent f4c2187 commit f10797c
Showing 2 changed files with 65 additions and 22 deletions.
47 changes: 47 additions & 0 deletions vllm/distributed/device_communicators/xpu_communicator.py
@@ -0,0 +1,47 @@
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from vllm.platforms import current_platform


class XpuCommunicator:

    def __init__(self, group: ProcessGroup):
        if not current_platform.is_xpu():
            self.disabled = True
            return
        self.disabled = False
        self.group = group
        self.world_size = dist.get_world_size(self.group)

    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
        dist.all_reduce(x, group=self.group)
        return x

    def gather(self,
               input_: torch.Tensor,
               rank_in_group: int,
               dst: int = 0,
               dim: int = -1):
        # For xpu path, gather doesn't work properly together with ray
        # cluster so we use all_gather instead for now.
        input_size = input_.size()
        # Allocate output tensor.
        output_tensor = torch.empty((self.world_size, ) + input_size,
                                    dtype=input_.dtype,
                                    device=input_.device)
        # All-gather.
        torch.distributed.all_gather_into_tensor(output_tensor,
                                                 input_,
                                                 group=self.group)
        if rank_in_group == dst:
            # Reshape
            output_tensor = output_tensor.movedim(0, dim)
            output_tensor = output_tensor.reshape(input_size[:dim] +
                                                  (self.world_size *
                                                   input_size[dim], ) +
                                                  input_size[dim + 1:])
        else:
            output_tensor = None
        return output_tensor
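
The movedim/reshape on the destination rank is what turns the stacked all-gather result back into the layout a plain gather along `dim` would produce. A minimal single-process sketch (plain PyTorch, hypothetical shapes, no process group) that checks this equivalence against torch.cat:

import torch

# Hypothetical stand-in for the all_gather_into_tensor result:
# world_size chunks, each shaped like the local input tensor.
world_size, dim = 4, 1
chunks = [torch.randn(2, 3, 5) for _ in range(world_size)]
input_size = chunks[0].size()

# all_gather_into_tensor fills a tensor of shape (world_size, *input_size).
stacked = torch.stack(chunks, dim=0)

# The reshape XpuCommunicator.gather applies on the destination rank.
out = stacked.movedim(0, dim)
out = out.reshape(input_size[:dim] + (world_size * input_size[dim], ) +
                  input_size[dim + 1:])

# Same result as gathering the chunks and concatenating along `dim`.
assert torch.equal(out, torch.cat(chunks, dim=dim))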
40 changes: 18 additions & 22 deletions vllm/distributed/parallel_state.py
@@ -177,6 +177,7 @@ def __init__(
         use_custom_allreduce: bool,
         use_tpu_communicator: bool,
         use_hpu_communicator: bool,
+        use_xpu_communicator: bool,
         use_message_queue_broadcaster: bool = False,
         group_name: Optional[str] = None,
     ):
@@ -214,6 +215,7 @@ def __init__(
         self.use_custom_allreduce = use_custom_allreduce
         self.use_tpu_communicator = use_tpu_communicator
         self.use_hpu_communicator = use_hpu_communicator
+        self.use_xpu_communicator = use_xpu_communicator
 
         # lazy import to avoid documentation build error
         from vllm.distributed.device_communicators.custom_all_reduce import (
@@ -248,6 +250,12 @@ def __init__(
         if use_hpu_communicator and self.world_size > 1:
             self.hpu_communicator = HpuCommunicator(group=self.device_group)
 
+        from vllm.distributed.device_communicators.xpu_communicator import (
+            XpuCommunicator)
+        self.xpu_communicator: Optional[XpuCommunicator]
+        if use_xpu_communicator and self.world_size > 1:
+            self.xpu_communicator = XpuCommunicator(group=self.device_group)
+
         from vllm.distributed.device_communicators.shm_broadcast import (
             MessageQueue)
         self.mq_broadcaster: Optional[MessageQueue] = None
@@ -373,6 +381,10 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
                 not self.hpu_communicator.disabled:
             return self.hpu_communicator.all_reduce(input_)
 
+        if self.xpu_communicator is not None and \
+                not self.xpu_communicator.disabled:
+            return self.xpu_communicator.all_reduce(input_)
+
         if self.ca_comm is not None and \
                 not self.ca_comm.disabled and \
                 self.ca_comm.should_custom_ar(input_):
@@ -459,28 +471,10 @@ def gather(self,
         if dim < 0:
             # Convert negative dim to positive.
             dim += input_.dim()
-        # For xpu path, gather doesn't work properly together with ray
-        # cluster so we use all_gather instead for now.
-        if current_platform.is_xpu():
-            input_size = input_.size()
-            # Allocate output tensor.
-            output_tensor = torch.empty((world_size, ) + input_size,
-                                        dtype=input_.dtype,
-                                        device=input_.device)
-            # All-gather.
-            torch.distributed.all_gather_into_tensor(output_tensor,
-                                                     input_,
-                                                     group=self.device_group)
-            if self.rank_in_group == dst:
-                # Reshape
-                output_tensor = output_tensor.movedim(0, dim)
-                output_tensor = output_tensor.reshape(input_size[:dim] +
-                                                      (world_size *
-                                                       input_size[dim], ) +
-                                                      input_size[dim + 1:])
-            else:
-                output_tensor = None
-            return output_tensor
+        if self.xpu_communicator is not None and \
+                not self.xpu_communicator.disabled:
+            return self.xpu_communicator.gather(input_, self.rank_in_group,
+                                                dst, dim)
         # Allocate output tensor.
         if self.rank_in_group == dst:
             gather_list = [torch.empty_like(input_) for _ in range(world_size)]
@@ -896,6 +890,7 @@ def init_world_group(ranks: List[int], local_rank: int,
         use_custom_allreduce=False,
         use_tpu_communicator=False,
         use_hpu_communicator=False,
+        use_xpu_communicator=False,
         group_name="world",
     )
 
@@ -918,6 +913,7 @@ def init_model_parallel_group(
         use_custom_allreduce=use_custom_allreduce,
         use_tpu_communicator=True,
         use_hpu_communicator=True,
+        use_xpu_communicator=True,
         use_message_queue_broadcaster=use_message_queue_broadcaster,
         group_name=group_name,
     )
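
For reference, the gather-via-all-gather pattern that GroupCoordinator.gather now delegates to can be exercised end to end on CPU with the gloo backend. The sketch below is an illustration only, not part of the commit: it uses the list-based dist.all_gather so it runs on gloo, whereas the committed XPU path gathers straight into one preallocated tensor with all_gather_into_tensor, and the helper name emulated_gather plus the hard-coded master address/port are hypothetical.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def emulated_gather(input_: torch.Tensor, rank: int, world_size: int,
                    dst: int = 0, dim: int = -1):
    """Gather input_ to dst along dim via an all-gather: every rank
    communicates, but only dst keeps and reshapes the result."""
    if dim < 0:
        dim += input_.dim()
    input_size = input_.size()
    chunks = [torch.empty_like(input_) for _ in range(world_size)]
    dist.all_gather(chunks, input_)
    if rank != dst:
        return None
    output = torch.stack(chunks, dim=0).movedim(0, dim)
    return output.reshape(input_size[:dim] +
                          (world_size * input_size[dim], ) +
                          input_size[dim + 1:])


def _worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29511"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # Each rank contributes a distinct 2x3 tensor filled with its rank id.
    x = torch.full((2, 3), float(rank))
    out = emulated_gather(x, rank, world_size, dst=0, dim=-1)
    if rank == 0:
        print(out.shape)  # torch.Size([2, 12]) with world_size=4
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(4, ), nprocs=4)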
