diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 7ad84f51b7e4c..f95de56f39b57 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -386,18 +386,10 @@ def from_sampling_metadata( presence_penalties += [0] * prefill_len frequency_penalties += [0] * prefill_len repetition_penalties += [1] * prefill_len - if do_penalties: - prompt_tokens.extend([] for _ in range(prefill_len)) - output_tokens.extend([] for _ in range(prefill_len)) if seq_group.do_sample: sample_lens = len(seq_group.sample_indices) assert sample_lens == len(seq_ids) - for seq_id in seq_ids: - seq_data = seq_group.seq_data[seq_id] - if do_penalties: - prompt_tokens.append(seq_data.prompt_token_ids) - output_tokens.append(seq_data.output_token_ids) temperatures += [temperature] * len(seq_ids) top_ps += [top_p] * len(seq_ids) top_ks += [top_k] * len(seq_ids) @@ -424,6 +416,20 @@ def from_sampling_metadata( sampling_seeds.append(seq_seeds) sample_indices.extend(seq_group.sample_indices) + if do_penalties: + for seq_group in sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + if (seq_group.is_prompt + and sampling_params.prompt_logprobs is not None): + prefill_len = len(seq_group.prompt_logprob_indices) + prompt_tokens.extend([] for _ in range(prefill_len)) + output_tokens.extend([] for _ in range(prefill_len)) + if seq_group.do_sample: + for seq_id in seq_ids: + seq_data = seq_group.seq_data[seq_id] + prompt_tokens.append(seq_data.prompt_token_ids) + output_tokens.append(seq_data.output_token_ids) + sampling_tensors = SamplingTensors.from_lists( temperatures, top_ps, top_ks, min_ps, presence_penalties, frequency_penalties, repetition_penalties, sampling_seeds,