Remove allgather workaround in logits_processor (#76)
kzawora-intel authored Jul 1, 2024
1 parent aae39b1 commit 90f900c
Showing 1 changed file with 4 additions and 8 deletions.
vllm/model_executor/layers/logits_processor.py: 12 changes (4 additions, 8 deletions)
@@ -4,9 +4,9 @@
 import torch
 import torch.nn as nn
 
-from vllm.distributed import tensor_model_parallel_gather, tensor_model_parallel_all_gather
+from vllm.distributed import tensor_model_parallel_gather
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.utils import is_hpu
 
 
 class LogitsProcessor(nn.Module):
     """Process logits and apply logits processors from sampling metadata.
@@ -50,9 +50,7 @@ def forward(
         # Get the logits for the next tokens.
         logits = self._get_logits(hidden_states, embedding, embedding_bias)
 
-        # NOTE(kzawora): allgather on HPU will cause logits to be not None,
-        # and we need to guard against applying logits processors on non-driver worker
-        if logits is not None and sampling_metadata.seq_groups is not None:
+        if logits is not None:
             logits *= self.scale
 
             # Apply logits processors (if any).
@@ -66,9 +64,7 @@ def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
         logits = torch.matmul(hidden_states, embedding.t())
         if embedding_bias is not None:
             logits += embedding_bias
-        # NOTE(kzawora): HPU PT bridge is missing support for single-rank gather. We'll use all-gather on Gaudi for now.
-        gather_op = tensor_model_parallel_all_gather if is_hpu() else tensor_model_parallel_gather
-        logits = gather_op(logits)
+        logits = tensor_model_parallel_gather(logits)
         # Remove paddings in vocab (if any).
         if logits is not None:
             logits = logits[:, :self.org_vocab_size]
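
For context on why the removed guard existed: tensor_model_parallel_gather returns the gathered logits only on the driver rank (None elsewhere), while tensor_model_parallel_all_gather returns a tensor on every rank, which is why the old HPU path needed the extra sampling_metadata.seq_groups check. The sketch below is illustrative only; gather_to_driver, all_gather_everywhere, and the scale_logits_* helpers are hypothetical stand-ins that model just that return-value contract, not vLLM's actual implementation.

# Minimal illustrative sketch (not vLLM code): models only "which rank gets a
# tensor back" from the two collectives referenced in the diff above.
from typing import List, Optional

import torch


def gather_to_driver(logits: torch.Tensor, rank: int, dst: int = 0) -> Optional[torch.Tensor]:
    # Single-rank gather: only the destination (driver) rank receives logits;
    # every other rank gets None.
    return logits if rank == dst else None


def all_gather_everywhere(logits: torch.Tensor, rank: int) -> torch.Tensor:
    # All-gather: every rank receives the full logits tensor.
    return logits


def scale_logits_new(logits: Optional[torch.Tensor], scale: float) -> Optional[torch.Tensor]:
    # Path after this commit: logits is None on non-driver workers, so this
    # single check keeps logits processing on the driver only.
    if logits is not None:
        logits *= scale
    return logits


def scale_logits_old_hpu(logits: torch.Tensor, seq_groups: Optional[List], scale: float) -> torch.Tensor:
    # Removed HPU workaround: all-gather left logits non-None on every rank,
    # so the extra seq_groups check was needed to skip non-driver workers.
    if logits is not None and seq_groups is not None:
        logits *= scale
    return logits


if __name__ == "__main__":
    logits = torch.randn(2, 8)
    for rank in range(2):
        new = scale_logits_new(gather_to_driver(logits.clone(), rank), scale=0.5)
        old = scale_logits_old_hpu(all_gather_everywhere(logits.clone(), rank),
                                   seq_groups=["dummy"] if rank == 0 else None,
                                   scale=0.5)
        print(f"rank {rank}: gather gives logits -> {new is not None}, "
              f"all-gather gives logits -> {old is not None}")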
