[TPU] Support collective communications in XLA devices #6813
```diff
@@ -34,6 +34,10 @@
 import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 
+if current_platform.is_xla():
+    import torch_xla.core.xla_model as xm
+
+
 @dataclass
```
```diff
@@ -125,6 +129,7 @@ class GroupCoordinator:
     pynccl_comm: Optional[Any]  # PyNccl communicator
     ca_comm: Optional[Any]  # Custom allreduce communicator
     mq_broadcaster: Optional[Any]  # shared memory broadcaster
+    use_xla: bool  # Whether to use PyTorch XLA communicator
 
     def __init__(
         self,
```
```diff
@@ -140,6 +145,7 @@ def __init__(
         self.local_rank = local_rank
         self.device_group = None
         self.cpu_group = None
+        self.use_xla = current_platform.is_xla()
 
         for ranks in group_ranks:
             device_group = torch.distributed.new_group(
```
```diff
@@ -289,6 +295,11 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         # Bypass the function if we are using only 1 GPU.
         if self.world_size == 1:
             return input_
+
+        # For TPUs, use xm.all_reduce.
+        if self.use_xla:
+            return xm.all_reduce(xm.REDUCE_SUM, input_)
+
         if ca_comm is not None:
             out = ca_comm.custom_all_reduce(input_)
             if out is not None:
```
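For context, here is a minimal standalone sketch (not part of the PR) of what the `xm.all_reduce` call above does. It assumes a TPU host with `torch_xla` installed and one process per core launched through `xmp.spawn`; the `_mp_fn` name is purely illustrative.

```python
# Minimal sketch: element-wise sum of a tensor across all TPU replicas via
# xm.all_reduce, the primitive that GroupCoordinator.all_reduce dispatches to.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index: int):
    device = xm.xla_device()
    # Each replica contributes a different value so the reduction is visible.
    x = torch.full((2, 4), float(index + 1), device=device)

    # Sum across every replica in the default group; each replica sees the result.
    x = xm.all_reduce(xm.REDUCE_SUM, x)

    xm.mark_step()  # materialize the lazily built XLA graph
    print(f"replica {index}: value after all-reduce = {x[0, 0].item()}")


if __name__ == "__main__":
    xmp.spawn(_mp_fn, args=())
```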
```diff
@@ -307,6 +318,12 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
             return input_
         assert -input_.dim() <= dim < input_.dim(), (
             f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
+
+        # For TPUs, use xm.all_gather.
+        if self.use_xla:
+            assert dim == -1, "TPUs only support dim=-1 for all-gather."
+            return xm.all_gather(input_, dim)
+
         if dim < 0:
             # Convert negative dim to positive.
             dim += input_.dim()
```
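Similarly, a sketch (again not from the PR, using the same launch scaffold as above) of `xm.all_gather` with `dim=-1`; the assertion in the hunk presumably restricts the XLA path to the last-dimension concatenation that the callers in vLLM rely on.

```python
# Minimal sketch: gather a tensor from every TPU replica along the last
# dimension, mirroring the dim == -1 path taken by all_gather() above.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


def _mp_fn(index: int):
    device = xm.xla_device()
    x = torch.full((2, 4), float(index), device=device)

    # With N replicas the result has shape (2, 4 * N) on every replica.
    gathered = xm.all_gather(x, dim=-1)

    xm.mark_step()
    print(f"replica {index}: gathered shape = {tuple(gathered.shape)}")


if __name__ == "__main__":
    xmp.spawn(_mp_fn, args=())
```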
Review comment: Does the TPU platform support NCCL? If not, creating these communicators might lead to errors.
Reply: TPU doesn't support NCCL, but I didn't see any errors with the other communicators.
Reply: The TPU backend uses the gloo backend in addition to the distributed runtime in xm. Maybe that's the reason.
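To illustrate that point, a rough sketch (not vLLM code; it assumes RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT are set by the launcher) of how a gloo-backed process group for CPU tensors can coexist with torch_xla's own collective runtime, so no NCCL communicator is ever required:

```python
# Illustrative sketch: gloo handles CPU-side collectives, xm handles TPU-side
# collectives; NCCL is never initialized.
import os

import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm

# CPU-side collectives (e.g. metadata broadcasts, barriers) run over gloo.
dist.init_process_group(
    backend="gloo",
    rank=int(os.environ["RANK"]),
    world_size=int(os.environ["WORLD_SIZE"]),
)

cpu_meta = torch.zeros(1)
dist.all_reduce(cpu_meta)  # goes through gloo

# Device-side collectives bypass torch.distributed and use xm directly.
x = torch.ones(2, 2, device=xm.xla_device())
x = xm.all_reduce(xm.REDUCE_SUM, x)
xm.mark_step()
```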