Skip to content

Commit

Permalink
[Bugfix] Fix for inconsistent behaviour related to sampling and repetition penalties (#5639)
Browse files Browse the repository at this point in the history

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
  • Loading branch information
tdoublep authored Jun 18, 2024
1 parent 07feecd commit 8a17338
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions vllm/model_executor/sampling_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,18 +386,10 @@ def from_sampling_metadata(
presence_penalties += [0] * prefill_len
frequency_penalties += [0] * prefill_len
repetition_penalties += [1] * prefill_len
if do_penalties:
prompt_tokens.extend([] for _ in range(prefill_len))
output_tokens.extend([] for _ in range(prefill_len))

if seq_group.do_sample:
sample_lens = len(seq_group.sample_indices)
assert sample_lens == len(seq_ids)
for seq_id in seq_ids:
seq_data = seq_group.seq_data[seq_id]
if do_penalties:
prompt_tokens.append(seq_data.prompt_token_ids)
output_tokens.append(seq_data.output_token_ids)
temperatures += [temperature] * len(seq_ids)
top_ps += [top_p] * len(seq_ids)
top_ks += [top_k] * len(seq_ids)
Expand All @@ -424,6 +416,20 @@ def from_sampling_metadata(
sampling_seeds.append(seq_seeds)
sample_indices.extend(seq_group.sample_indices)

if do_penalties:
for seq_group in sampling_metadata.seq_groups:
seq_ids = seq_group.seq_ids
if (seq_group.is_prompt
and sampling_params.prompt_logprobs is not None):
prefill_len = len(seq_group.prompt_logprob_indices)
prompt_tokens.extend([] for _ in range(prefill_len))
output_tokens.extend([] for _ in range(prefill_len))
if seq_group.do_sample:
for seq_id in seq_ids:
seq_data = seq_group.seq_data[seq_id]
prompt_tokens.append(seq_data.prompt_token_ids)
output_tokens.append(seq_data.output_token_ids)

sampling_tensors = SamplingTensors.from_lists(
temperatures, top_ps, top_ks, min_ps, presence_penalties,
frequency_penalties, repetition_penalties, sampling_seeds,
Expand Down

0 comments on commit 8a17338

Please sign in to comment.