This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

[Hardware][TPU] Raise errors for unsupported sampling params (vllm-pr…
WoosukKwon authored and robertgshaw2-neuralmagic committed Jul 1, 2024
1 parent ece7c7f commit e6935bd
Showing 1 changed file with 44 additions and 19 deletions.
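
The change below makes the TPU model runner reject, at input-preparation time, sampling parameters the backend cannot honor (top-p unless _ENABLE_TOP_P is set, top-k, best_of > 1, beam search, logprobs, and prompt_logprobs) instead of silently ignoring them. The following sketch illustrates that validation pattern in isolation; Params and check_tpu_sampling_params are hypothetical stand-ins for this example, not vLLM's SamplingParams class or the runner's actual method.

# Minimal, self-contained sketch of the validation pattern added to
# _prepare_sample. "Params" is a hypothetical stand-in for this example,
# not vLLM's SamplingParams class.
from dataclasses import dataclass
from typing import Optional

_ENABLE_TOP_P = False  # mirrors the module-level flag added in this diff


@dataclass
class Params:
    temperature: float = 1.0
    top_p: float = 1.0
    top_k: int = -1
    best_of: int = 1
    use_beam_search: bool = False
    logprobs: Optional[int] = None
    prompt_logprobs: Optional[int] = None


def check_tpu_sampling_params(sp: Params) -> None:
    # Raise early for features the TPU backend does not support.
    if sp.top_p != 1 and not _ENABLE_TOP_P:
        raise NotImplementedError("Top-p sampling is disabled for the TPU backend.")
    if sp.top_k != -1:
        raise NotImplementedError("Top-k sampling is disabled for the TPU backend.")
    if sp.best_of > 1:
        raise NotImplementedError("best_of > 1 is not supported by the TPU backend.")
    if sp.use_beam_search:
        raise NotImplementedError("Beam search is not supported by the TPU backend.")
    if sp.logprobs is not None or sp.prompt_logprobs is not None:
        raise NotImplementedError("logprobs are not supported by the TPU backend.")


check_tpu_sampling_params(Params())  # default (greedy-style) params pass
try:
    check_tpu_sampling_params(Params(top_k=40))
except NotImplementedError as exc:
    print(exc)  # the request is rejected before any TPU work is done

Failing fast here surfaces unsupported requests before any TPU compilation or execution time is spent on them.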
63 changes: 44 additions & 19 deletions vllm/worker/tpu_model_runner.py
@@ -20,6 +20,8 @@
logger = init_logger(__name__)

_PAD_SLOT_ID = 0 # FIXME(woosuk)
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
_ENABLE_TOP_P = False


class TPUModelRunner:
@@ -339,9 +341,34 @@ def _prepare_sample(
assert seq_group_metadata.sampling_params is not None
sampling_params = seq_group_metadata.sampling_params

# NOTE(woosuk): Here we mimic argmax sampling by applying a very
# low temperature. This is not accurate.
t.append(sampling_params.temperature
if sampling_params.temperature >= 1e-5 else 1e-5)
if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
raise NotImplementedError(
"Top-p sampling is currently disabled for the TPU backend "
"due to performance issues.")
p.append(sampling_params.top_p)
if sampling_params.top_k != -1:
raise NotImplementedError(
"Top-k sampling is currently disabled for the TPU backend "
"due to performance issues.")
if sampling_params.best_of > 1:
raise NotImplementedError(
"best_of > 1 is not currently supported by the TPU "
"backend.")
if sampling_params.use_beam_search:
raise NotImplementedError(
"Beam search is not supported by the TPU backend.")
if sampling_params.logprobs is not None:
raise NotImplementedError(
"logprobs is not currently supported by the TPU backend.")
if sampling_params.prompt_logprobs is not None:
raise NotImplementedError(
"prompt_logprobs is not currently supported by the TPU "
"backend.")

num_paddings = padded_batch_size - len(seq_group_metadata_list)
t += [1.0] * num_paddings
p += [1.0] * num_paddings
@@ -350,35 +377,32 @@
p = torch.tensor(p, dtype=torch.float32, device=self.device)
return t, p

def prepare_inputs(
def _execute_model(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
):
assert seq_group_metadata_list is not None
seq_group_metadata_list: List[SequenceGroupMetadata],
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> List[CompletionSequenceGroupOutput]:
# Prepare inputs.
assert len(seq_group_metadata_list) > 0
# NOTE: We assume that all sequences in the group are all prompts or
# all decodes.
if seq_group_metadata_list[0].is_prompt:
is_prompt = seq_group_metadata_list[0].is_prompt
if is_prompt:
inputs = self._prepare_prompt(seq_group_metadata_list)
else:
inputs = self._prepare_decode(seq_group_metadata_list)
padded_batch_size = inputs[0].shape[0]
sample_inputs = self._prepare_sample(seq_group_metadata_list,
padded_batch_size)
return inputs + sample_inputs
t, p = self._prepare_sample(seq_group_metadata_list, padded_batch_size)

def _execute_model(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> List[CompletionSequenceGroupOutput]:
inputs = self.prepare_inputs(seq_group_metadata_list)
# Execute the model.
next_token_ids = self.model(inputs[0], inputs[1], kv_caches,
*inputs[2:])
if not self.is_driver_worker:
return []
*inputs[2:], t, p)
# Retrieve the outputs to CPU.
next_token_ids = next_token_ids.cpu().tolist()

# NOTE(woosuk): Minimal code to construct the sampler outputs.
# The TPU backend does not reuse the sampler, since the TPU backend
# does not support the advanced sampling parameters such as logprobs.
i = 0
sampler_outputs = []
for seq_group_metadata in seq_group_metadata_list:
Expand All @@ -400,6 +424,7 @@ def execute_model(
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> SamplerOutput:
assert seq_group_metadata_list is not None
assert len(seq_group_metadata_list) > 0
if seq_group_metadata_list[0].is_prompt:
# NOTE(woosuk): To reduce the compilation time, we only compile the
# prefill inputs with batch size 1. Because the scheduler is not
@@ -492,8 +517,8 @@ def forward(
logits = self.model.compute_logits(hidden_states, sampling_metadata)

logits = logits / t.unsqueeze(dim=1)
# FIXME(woosuk): Disabled top-p sampling since it's too slow.
# logits = _apply_top_p(logits, p.unsqueeze(dim=1))
if _ENABLE_TOP_P:
logits = _apply_top_p(logits, p.unsqueeze(dim=1))
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
# FIXME(woosuk): best_of > 1 is not supported.
next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(dim=1)
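The forward pass above divides the logits by the per-sequence temperature and applies _apply_top_p only when _ENABLE_TOP_P is set. For reference, a top-p (nucleus) filter with the same call shape (logits of [batch_size, vocab_size], p of [batch_size, 1]) is commonly written as below; this is an illustrative sketch, not necessarily the _apply_top_p defined in vllm/worker/tpu_model_runner.py.

# Illustrative top-p (nucleus) filter matching the call shape used above:
# logits is [batch_size, vocab_size] and p is [batch_size, 1]. A generic
# sketch, not necessarily the _apply_top_p defined in this file.
import torch


def apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    sorted_logits, sorted_indices = logits.sort(dim=-1, descending=True)
    probs = sorted_logits.softmax(dim=-1)
    cum_probs = probs.cumsum(dim=-1)
    # Drop a token once the probability mass strictly before it exceeds p;
    # the highest-probability token is always kept.
    mask = (cum_probs - probs) > p
    sorted_logits = sorted_logits.masked_fill(mask, float("-inf"))
    # Scatter the filtered logits back to their original vocabulary order.
    out = torch.empty_like(logits)
    out.scatter_(dim=-1, index=sorted_indices, src=sorted_logits)
    return out


logits = torch.randn(2, 8)
p = torch.tensor([[0.9], [1.0]])  # p == 1.0 keeps the full distribution
filtered = apply_top_p(logits, p)
probs = torch.softmax(filtered, dim=-1, dtype=torch.float32)
next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(dim=1)

The full-vocabulary sort and cumulative sum are the expensive steps, which is consistent with the FIXME above about top-p sampling being too slow on the TPU backend.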
