diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 2d8fffe5ac164..2c70c1f917a0d 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -20,6 +20,8 @@ logger = init_logger(__name__) _PAD_SLOT_ID = 0 # FIXME(woosuk) +# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow. +_ENABLE_TOP_P = False class TPUModelRunner: @@ -339,9 +341,34 @@ def _prepare_sample( assert seq_group_metadata.sampling_params is not None sampling_params = seq_group_metadata.sampling_params + # NOTE(woosuk): Here we mimic argmax sampling by applying a very + # low temperature. This is not accurate. t.append(sampling_params.temperature if sampling_params.temperature >= 1e-5 else 1e-5) + if sampling_params.top_p != 1 and not _ENABLE_TOP_P: + raise NotImplementedError( + "Top-p sampling is currently disabled for the TPU backend " + "due to performance issues.") p.append(sampling_params.top_p) + if sampling_params.top_k != -1: + raise NotImplementedError( + "Top-k sampling is currently disabled for the TPU backend " + "due to performance issues.") + if sampling_params.best_of > 1: + raise NotImplementedError( + "best_of > 1 is not currently supported by the TPU " + "backend.") + if sampling_params.use_beam_search: + raise NotImplementedError( + "Beam search is not supported by the TPU backend.") + if sampling_params.logprobs is not None: + raise NotImplementedError( + "logprobs is not currently supported by the TPU backend.") + if sampling_params.prompt_logprobs is not None: + raise NotImplementedError( + "prompt_logprobs is not currently supported by the TPU " + "backend.") + num_paddings = padded_batch_size - len(seq_group_metadata_list) t += [1.0] * num_paddings p += [1.0] * num_paddings @@ -350,35 +377,32 @@ def _prepare_sample( p = torch.tensor(p, dtype=torch.float32, device=self.device) return t, p - def prepare_inputs( + def _execute_model( self, - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - ): - assert seq_group_metadata_list is not None + seq_group_metadata_list: List[SequenceGroupMetadata], + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + ) -> List[CompletionSequenceGroupOutput]: + # Prepare inputs. assert len(seq_group_metadata_list) > 0 # NOTE: We assume that all sequences in the group are all prompts or # all decodes. - if seq_group_metadata_list[0].is_prompt: + is_prompt = seq_group_metadata_list[0].is_prompt + if is_prompt: inputs = self._prepare_prompt(seq_group_metadata_list) else: inputs = self._prepare_decode(seq_group_metadata_list) padded_batch_size = inputs[0].shape[0] - sample_inputs = self._prepare_sample(seq_group_metadata_list, - padded_batch_size) - return inputs + sample_inputs + t, p = self._prepare_sample(seq_group_metadata_list, padded_batch_size) - def _execute_model( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], - ) -> List[CompletionSequenceGroupOutput]: - inputs = self.prepare_inputs(seq_group_metadata_list) + # Execute the model. next_token_ids = self.model(inputs[0], inputs[1], kv_caches, - *inputs[2:]) - if not self.is_driver_worker: - return [] + *inputs[2:], t, p) + # Retrieve the outputs to CPU. next_token_ids = next_token_ids.cpu().tolist() + # NOTE(woosuk): Minimal code to construct the sampler outputs. + # The TPU backend does not reuse the sampler, since the TPU backend + # does not support the advanced sampling parameters such as logprobs. i = 0 sampler_outputs = [] for seq_group_metadata in seq_group_metadata_list: @@ -400,6 +424,7 @@ def execute_model( kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], ) -> SamplerOutput: assert seq_group_metadata_list is not None + assert len(seq_group_metadata_list) > 0 if seq_group_metadata_list[0].is_prompt: # NOTE(woosuk): To reduce the compilation time, we only compile the # prefill inputs with batch size 1. Because the scheduler is not @@ -492,8 +517,8 @@ def forward( logits = self.model.compute_logits(hidden_states, sampling_metadata) logits = logits / t.unsqueeze(dim=1) - # FIXME(woosuk): Disabled top-p sampling since it's too slow. - # logits = _apply_top_p(logits, p.unsqueeze(dim=1)) + if _ENABLE_TOP_P: + logits = _apply_top_p(logits, p.unsqueeze(dim=1)) probs = torch.softmax(logits, dim=-1, dtype=torch.float32) # FIXME(woosuk): best_of > 1 is not supported. next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(dim=1)