
Commit 7bd7dae

WoosukKwon authored and Alvant committed
[TPU] Support collective communications in XLA devices (vllm-project#6813)
Signed-off-by: Alvant <alvasian@yandex.ru>
1 parent bd303a9 commit 7bd7dae

File tree: 4 files changed, +70 -2 lines changed

vllm/distributed/device_communicators/tpu_communicator.py
vllm/distributed/parallel_state.py
vllm/lora/layers.py
vllm/model_executor/layers/logits_processor.py
vllm/distributed/device_communicators/tpu_communicator.py

Lines changed: 30 additions & 0 deletions

@@ -0,0 +1,30 @@
+import torch
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
+from vllm.platforms import current_platform
+
+if current_platform.is_tpu():
+    import torch_xla.core.xla_model as xm
+    from torch_xla._internal import pjrt
+
+
+class TpuCommunicator:
+
+    def __init__(self, group: ProcessGroup):
+        if not current_platform.is_tpu():
+            self.disabled = True
+            return
+        self.disabled = False
+
+        local_rank = dist.get_rank(group)
+        world_size = dist.get_world_size(group)
+        pjrt.initialize_multiprocess(local_rank, world_size)
+        xm._init_world_size_ordinal()
+
+    def all_reduce(self, x: torch.Tensor) -> torch.Tensor:
+        return xm.all_reduce(xm.REDUCE_SUM, x)
+
+    def all_gather(self, x: torch.Tensor, dim: int = -1) -> torch.Tensor:
+        assert dim == -1, "TPUs only support dim=-1 for all-gather."
+        return xm.all_gather(x, dim=dim)
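
For illustration only (not part of the commit): the disabled flag lets a caller construct a TpuCommunicator once, unconditionally, and fall back to a regular torch.distributed collective when the process is not running on a TPU. A minimal caller sketch, assuming the communicator was built elsewhere (as GroupCoordinator does below):

# Hypothetical caller, not code from this commit.
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup

from vllm.distributed.device_communicators.tpu_communicator import (
    TpuCommunicator)


def all_reduce_with_fallback(x: torch.Tensor, group: ProcessGroup,
                             comm: TpuCommunicator) -> torch.Tensor:
    if not comm.disabled:               # only False when running on a TPU
        return comm.all_reduce(x)       # xm.all_reduce(xm.REDUCE_SUM, x) underneath
    dist.all_reduce(x, group=group)     # default torch.distributed path
    return x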

vllm/distributed/parallel_state.py

Lines changed: 22 additions & 0 deletions
@@ -133,6 +133,7 @@ def __init__(
         torch_distributed_backend: Union[str, Backend],
         use_pynccl: bool,
         use_custom_allreduce: bool,
+        use_tpu_communicator: bool,
         use_message_queue_broadcaster: bool = False,
     ):

@@ -164,6 +165,7 @@ def __init__(
 
         self.use_pynccl = use_pynccl
         self.use_custom_allreduce = use_custom_allreduce
+        self.use_tpu_communicator = use_tpu_communicator
 
         # lazy import to avoid documentation build error
         from vllm.distributed.device_communicators.custom_all_reduce import (

@@ -190,6 +192,12 @@ def __init__(
         else:
             self.ca_comm = None
 
+        from vllm.distributed.device_communicators.tpu_communicator import (
+            TpuCommunicator)
+        self.tpu_communicator: Optional[TpuCommunicator]
+        if use_tpu_communicator and self.world_size > 1:
+            self.tpu_communicator = TpuCommunicator(group=self.cpu_group)
+
         from vllm.distributed.device_communicators.shm_broadcast import (
             MessageQueue)
         self.mq_broadcaster: Optional[MessageQueue] = None
@@ -289,6 +297,12 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         # Bypass the function if we are using only 1 GPU.
         if self.world_size == 1:
             return input_
+
+        # For TPUs, use TPU communicator.
+        tpu_comm = self.tpu_communicator
+        if tpu_comm is not None and not tpu_comm.disabled:
+            return tpu_comm.all_reduce(input_)
+
         if ca_comm is not None:
             out = ca_comm.custom_all_reduce(input_)
             if out is not None:
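
Condensed, illustrative summary of the all-reduce dispatch order after this change (hypothetical standalone function; the real method also has pyNCCL and CPU branches, folded here into the final fallback):

# Sketch of the dispatch priority: TPU communicator, then custom all-reduce,
# then the default torch.distributed backend. Not a drop-in for vLLM.
import torch
import torch.distributed as dist


def dispatch_all_reduce(x: torch.Tensor, world_size: int, tpu_comm, ca_comm,
                        group: dist.ProcessGroup) -> torch.Tensor:
    if world_size == 1:
        return x                               # single device: nothing to do
    if tpu_comm is not None and not tpu_comm.disabled:
        return tpu_comm.all_reduce(x)          # 1) XLA collective on TPU
    if ca_comm is not None:
        out = ca_comm.custom_all_reduce(x)
        if out is not None:
            return out                         # 2) custom CUDA all-reduce
    dist.all_reduce(x, group=group)            # 3) default backend fallback
    return x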
@@ -310,6 +324,12 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
             return input_
         assert -input_.dim() <= dim < input_.dim(), (
             f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
+
+        # For TPUs, use TPU communicator.
+        tpu_comm = self.tpu_communicator
+        if tpu_comm is not None and not tpu_comm.disabled:
+            return tpu_comm.all_gather(input_, dim)
+
         if dim < 0:
             # Convert negative dim to positive.
             dim += input_.dim()
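
The TPU path only handles dim=-1 (hence the assert in TpuCommunicator.all_gather). A single-process shape sketch, with no real collective involved, of what an all-gather along the last dimension produces:

# Simulated all-gather along dim=-1: every rank ends up with the
# concatenation of all per-rank shards on the last dimension.
import torch

world_size = 4
per_rank_shard = [torch.randn(2, 8) for _ in range(world_size)]

gathered = torch.cat(per_rank_shard, dim=-1)    # what each rank would receive
assert gathered.shape == (2, 8 * world_size)    # (2, 32)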
@@ -727,6 +747,7 @@ def init_world_group(ranks: List[int], local_rank: int,
         torch_distributed_backend=backend,
         use_pynccl=False,
         use_custom_allreduce=False,
+        use_tpu_communicator=False,
     )

@@ -745,6 +766,7 @@ def init_model_parallel_group(
         torch_distributed_backend=backend,
         use_pynccl=True,
         use_custom_allreduce=use_custom_allreduce,
+        use_tpu_communicator=True,
         use_message_queue_broadcaster=use_message_queue_broadcaster,
     )

vllm/lora/layers.py

Lines changed: 4 additions & 0 deletions
@@ -1067,6 +1067,10 @@ def scale(self):
     def soft_cap(self):
         return self.base_layer.soft_cap
 
+    @property
+    def use_gather(self):
+        return self.base_layer.use_gather
+
     @property
     def org_vocab_size(self):
         return self.base_layer.org_vocab_size
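
The change above forwards use_gather to the wrapped layer so the TPU-aware logits path keeps working when LoRA is enabled. A generic sketch of that delegation pattern (hypothetical class, not the vLLM one):

# Property delegation: code that only sees the wrapper still reads the
# base layer's gather/all-gather flag.
class WrappedLayer:

    def __init__(self, base_layer):
        self.base_layer = base_layer

    @property
    def use_gather(self) -> bool:
        return self.base_layer.use_gather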

vllm/model_executor/layers/logits_processor.py

Lines changed: 14 additions & 2 deletions
@@ -5,10 +5,12 @@
 import torch
 import torch.nn as nn
 
-from vllm.distributed import tensor_model_parallel_gather
+from vllm.distributed import (tensor_model_parallel_all_gather,
+                              tensor_model_parallel_gather)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.platforms import current_platform
 
 
 class LogitsProcessor(nn.Module):

@@ -39,6 +41,8 @@ def __init__(self,
         self.org_vocab_size = org_vocab_size or vocab_size
         # Soft cap the logits. Used in Gemma 2.
         self.soft_cap = soft_cap
+        # Whether to use gather or all-gather to gather the logits.
+        self.use_gather = not current_platform.is_tpu()
 
     def forward(
         self,

@@ -76,7 +80,15 @@ def _get_logits(self, hidden_states: torch.Tensor,
         logits = lm_head.linear_method.apply(lm_head,
                                              hidden_states,
                                              bias=embedding_bias)
-        logits = tensor_model_parallel_gather(logits)
+        if self.use_gather:
+            logits = tensor_model_parallel_gather(logits)
+        else:
+            # Gather is not supported for some devices such as TPUs.
+            # Use all-gather instead.
+            # NOTE(woosuk): Here, the outputs of every device should not be None
+            # because XLA requires strict SPMD among all devices. Every device
+            # should execute the same operations after gathering the logits.
+            logits = tensor_model_parallel_all_gather(logits)
         # Remove paddings in vocab (if any).
         if logits is not None:
             logits = logits[:, :self.org_vocab_size]
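
For intuition, a single-process simulation (hypothetical shapes, no distributed setup): with tensor parallelism each rank holds one vocab shard of the logits. A gather materializes the full logits only on the destination rank and returns None elsewhere, while an all-gather gives every rank the full tensor, which matches XLA's requirement that all devices execute the same post-gather operations:

# Gather vs. all-gather over vocab-sharded logits (simulated in one process).
import torch

tp_size, num_tokens, vocab_shard = 2, 4, 16
shards = [torch.randn(num_tokens, vocab_shard) for _ in range(tp_size)]

# gather: full logits only on the destination rank, None on the others.
logits_on_dst = torch.cat(shards, dim=-1)       # shape (4, 32) on rank 0
logits_elsewhere = None                         # other ranks see None

# all-gather: every rank holds the full logits and runs identical ops,
# which is what the TPU / SPMD branch above relies on.
logits_everywhere = torch.cat(shards, dim=-1)
assert logits_everywhere.shape == (num_tokens, vocab_shard * tp_size)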
