From 83c644fe7ecee05d3ebe5057acb6e008d7e81eb8 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 4 Aug 2024 00:22:19 -0700 Subject: [PATCH] [core][misc] simply output processing with shortcut code path (#7117) --- vllm/engine/output_processor/single_step.py | 39 ++++++++++++++++----- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 59eb4bc439d1f..4a46c93f84256 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -81,6 +81,29 @@ def process_prompt_logprob(self, seq_group: SequenceGroup, def _process_sequence_group_outputs(self, seq_group: SequenceGroup, outputs: SequenceGroupOutput) -> None: + sampling_params = seq_group.sampling_params + if sampling_params.n == 1 and not sampling_params.use_beam_search: + # only have one output sample + sample = outputs.samples[0] + # only have one sequence + seq = seq_group.seqs[0] + seq.append_token_id(sample.output_token, sample.logprobs) + if sampling_params.detokenize and self.detokenizer: + new_char_count = self.detokenizer.decode_sequence_inplace( + seq, sampling_params) + else: + new_char_count = 0 + self.stop_checker.maybe_stop_sequence( + seq, + new_char_count, + sampling_params, + lora_req=seq_group.lora_request, + ) + if seq.is_finished(): + for scheduler in self.scheduler: + scheduler.free_seq(seq) + return + # Process samples samples = outputs.samples parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) @@ -127,20 +150,20 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, child_seqs.append((parent, parent)) for seq, _ in child_seqs: - if seq_group.sampling_params.detokenize and self.detokenizer: + if sampling_params.detokenize and self.detokenizer: new_char_count = self.detokenizer.decode_sequence_inplace( - seq, seq_group.sampling_params) + seq, sampling_params) else: new_char_count = 0 self.stop_checker.maybe_stop_sequence( seq, new_char_count, - seq_group.sampling_params, + sampling_params, lora_req=seq_group.lora_request, ) # Non-beam search case - if not seq_group.sampling_params.use_beam_search: + if not sampling_params.use_beam_search: # For newly created child sequences, add them to the sequence group # and fork them in block manager if they are not finished. for seq, parent in child_seqs: @@ -164,8 +187,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Select the child sequences to keep in the sequence group. selected_child_seqs: List[Tuple[Sequence, Optional[Sequence]]] = [] unselected_child_seqs: List[Tuple[Sequence, Optional[Sequence]]] = [] - beam_width = seq_group.sampling_params.best_of - length_penalty = seq_group.sampling_params.length_penalty + beam_width = sampling_params.best_of + length_penalty = sampling_params.length_penalty # Select the newly finished sequences with the highest scores # to replace existing finished sequences. @@ -219,8 +242,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, best_running_seq = running_child_seqs[0][0] current_worst_seq = all_finished_seqs[beam_width - 1][0] stop_beam_search = self._check_beam_search_early_stopping( - seq_group.sampling_params.early_stopping, - seq_group.sampling_params, best_running_seq, current_worst_seq) + sampling_params.early_stopping, sampling_params, + best_running_seq, current_worst_seq) if stop_beam_search: # Stop the beam search and remove all the running sequences from