17 changes: 17 additions & 0 deletions vllm/distributed/parallel_state.py
@@ -34,6 +34,10 @@

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform

if current_platform.is_xla():
import torch_xla.core.xla_model as xm


@dataclass
@@ -125,6 +129,7 @@ class GroupCoordinator:
pynccl_comm: Optional[Any] # PyNccl communicator
ca_comm: Optional[Any] # Custom allreduce communicator
mq_broadcaster: Optional[Any] # shared memory broadcaster
use_xla: bool # Whether to use PyTorch XLA communicator
Member:
Does the TPU platform support NCCL? If not, creating these communicators might lead to errors.

Collaborator Author:
TPU doesn't support NCCL, but I didn't see any error with the other communicators.

Collaborator Author:
The TPU backend uses the gloo backend in addition to the distributed runtime in xm. Maybe that's the reason.
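For context, a minimal sketch of how the two can coexist, assuming a hypothetical init helper (the actual vLLM TPU worker setup may differ): torch.distributed handles CPU-side process-group collectives over gloo, while xm handles on-device TPU collectives.

```python
# Hypothetical sketch: gloo for CPU-side process-group collectives,
# torch_xla's xm for on-device TPU collectives. Not the actual vLLM code.
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm


def init_tpu_worker(rank: int, world_size: int, init_method: str) -> None:
    # CPU process group (gloo), used for control-plane communication.
    dist.init_process_group(backend="gloo",
                            init_method=init_method,
                            rank=rank,
                            world_size=world_size)
    # Device collectives go through the XLA runtime, not NCCL.
    t = torch.ones(1, device=xm.xla_device())
    t = xm.all_reduce(xm.REDUCE_SUM, t)
    xm.mark_step()
```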

Member:
use_xxx is an initialization parameter, and we usually hold the communicator inside the group coordinator.

Can you add a tpu_communicator under https://github.com/vllm-project/vllm/tree/main/vllm/distributed/device_communicators?

One additional benefit is that you can implement the gather logic with all-gather, without an intrusive change to logits_processor.py.
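A minimal sketch of what such a communicator could look like, following the pattern suggested above; the module path, class name, and constructor signature are assumptions, not part of this PR:

```python
# vllm/distributed/device_communicators/tpu_communicator.py (hypothetical)
import torch

from vllm.platforms import current_platform

if current_platform.is_xla():
    import torch_xla.core.xla_model as xm


class TpuCommunicator:
    """Thin wrapper around torch_xla collectives for GroupCoordinator."""

    def __init__(self, group: "torch.distributed.ProcessGroup"):
        # Disabled on non-XLA platforms so GroupCoordinator can skip it.
        self.disabled = not current_platform.is_xla()

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        return xm.all_reduce(xm.REDUCE_SUM, input_)

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        assert dim == -1, "TPUs only support dim=-1 for all-gather."
        return xm.all_gather(input_, dim)
```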

Collaborator Author:
> One additional benefit is that you can implement the gather logic with all-gather, without an intrusive change to logits_processor.py.

This is actually not the case, because the TPU backend explicitly requires all-gather, which means each device's output should not be None. If we implement gather by using all-gather and returning None for non-root ranks, XLA will raise an error.
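To illustrate the constraint with a minimal sketch (hypothetical helper, not part of this PR): under XLA's SPMD execution, every rank must run the same ops and keep a real tensor, so a gather emulated with all-gather cannot return None on non-root ranks.

```python
# Hypothetical illustration of the SPMD constraint described above.
import torch
import torch_xla.core.xla_model as xm


def gather_logits_xla(logits: torch.Tensor) -> torch.Tensor:
    # Every rank receives the full, concatenated logits.
    full_logits = xm.all_gather(logits, dim=-1)
    # A CPU/GPU-style gather would do:
    #   return full_logits if rank == dst else None
    # but that diverges across ranks and breaks XLA's SPMD execution,
    # so every rank keeps the all-gathered tensor.
    return full_logits
```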


def __init__(
self,
@@ -140,6 +145,7 @@ def __init__(
self.local_rank = local_rank
self.device_group = None
self.cpu_group = None
self.use_xla = current_platform.is_xla()

for ranks in group_ranks:
device_group = torch.distributed.new_group(
@@ -289,6 +295,11 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
# Bypass the function if we are using only 1 GPU.
if self.world_size == 1:
return input_

# For TPUs, use xm.all_reduce.
if self.use_xla:
return xm.all_reduce(xm.REDUCE_SUM, input_)

if ca_comm is not None:
out = ca_comm.custom_all_reduce(input_)
if out is not None:
@@ -307,6 +318,12 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")

# For TPUs, use xm.all_gather.
if self.use_xla:
assert dim == -1, "TPUs only support dim=-1 for all-gather."
return xm.all_gather(input_, dim)

if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
4 changes: 4 additions & 0 deletions vllm/lora/layers.py
@@ -1067,6 +1067,10 @@ def scale(self):
def soft_cap(self):
return self.base_layer.soft_cap

@property
def use_gather(self):
return self.base_layer.use_gather

@property
def org_vocab_size(self):
return self.base_layer.org_vocab_size
16 changes: 14 additions & 2 deletions vllm/model_executor/layers/logits_processor.py
@@ -5,10 +5,12 @@
import torch
import torch.nn as nn

from vllm.distributed import tensor_model_parallel_gather
from vllm.distributed import (tensor_model_parallel_all_gather,
tensor_model_parallel_gather)
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform


class LogitsProcessor(nn.Module):
@@ -39,6 +41,8 @@ def __init__(self,
self.org_vocab_size = org_vocab_size or vocab_size
# Soft cap the logits. Used in Gemma 2.
self.soft_cap = soft_cap
# Whether to use gather or all-gather to gather the logits.
self.use_gather = not current_platform.is_xla()

def forward(
self,
@@ -76,7 +80,15 @@ def _get_logits(self, hidden_states: torch.Tensor,
logits = lm_head.linear_method.apply(lm_head,
hidden_states,
bias=embedding_bias)
logits = tensor_model_parallel_gather(logits)
if self.use_gather:
logits = tensor_model_parallel_gather(logits)
else:
# Gather is not supported for some devices such as TPUs.
# Use all-gather instead.
# NOTE(woosuk): Here, the outputs of every device should not be None
# because XLA requires strict SPMD among all devices. Every device
# should execute the same operations after gathering the logits.
logits = tensor_model_parallel_all_gather(logits)
# Remove paddings in vocab (if any).
if logits is not None:
logits = logits[:, :self.org_vocab_size]
3 changes: 3 additions & 0 deletions vllm/platforms/interface.py
@@ -23,6 +23,9 @@ def is_rocm(self) -> bool:
def is_tpu(self) -> bool:
return self._enum == PlatformEnum.TPU

def is_xla(self) -> bool:
return self.is_tpu()

@staticmethod
def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
raise NotImplementedError