[Bugfix] Fix prompt_logprobs when SamplingParams.detokenize is set to True (#5226)
zifeitong authored Jun 5, 2024
1 parent fee4dcc commit 974fc9b
Showing 2 changed files with 22 additions and 13 deletions.
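
For context, the behavior this change touches: a request can ask for prompt logprobs while opting out of detokenization, and the engine should still return the numeric logprobs (with decoded_token left unset). A minimal sketch of such a request against the public vLLM API; the model name and prompt are placeholders, not part of this commit:

from vllm import LLM, SamplingParams

# detokenize=False skips string decoding; logprobs are still requested.
sampling_params = SamplingParams(max_tokens=5,
                                 logprobs=5,
                                 prompt_logprobs=5,
                                 temperature=0.0,
                                 detokenize=False)

llm = LLM(model="facebook/opt-125m")  # placeholder model for illustration
result = llm.generate(["San Francisco is a"], sampling_params)[0]

print(result.prompt_logprobs[0])  # always None: the first prompt token has nothing preceding it
print(result.prompt_logprobs[1])  # dict of token id -> Logprob, populated even without detokenization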
27 changes: 18 additions & 9 deletions tests/samplers/test_logprobs.py
@@ -12,13 +12,15 @@
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16, -1])
 @pytest.mark.parametrize("num_top_logprobs", [6]) # 32000 == vocab_size
+@pytest.mark.parametrize("detokenize", [True, False])
 def test_get_prompt_logprobs(
     hf_runner,
     vllm_runner,
     model,
     dtype,
     chunked_prefill_token_size: int,
     num_top_logprobs: int,
+    detokenize: bool,
     example_prompts,
 ):
     max_num_seqs = 256
@@ -48,7 +50,8 @@ def test_get_prompt_logprobs(
     vllm_sampling_params = SamplingParams(max_tokens=max_tokens,
                                           logprobs=num_top_logprobs,
                                           prompt_logprobs=num_top_logprobs,
-                                          temperature=0.0)
+                                          temperature=0.0,
+                                          detokenize=detokenize)
     vllm_results = vllm_model.model.generate(
         example_prompts, sampling_params=vllm_sampling_params)

@@ -65,11 +68,16 @@
             top_logprob = next(iter(top_logprobs.values()))
             output_string_from_most_likely_tokens.append(
                 top_logprob.decoded_token)
-        output_string_from_most_likely_tokens = "".join(
-            output_string_from_most_likely_tokens)
-        assert output_text == output_string_from_most_likely_tokens, (
-            "The output text from the top logprob for each token position "
-            "should be the same as the output text in the result.")
+
+        if detokenize:
+            output_string_from_most_likely_tokens = "".join(
+                output_string_from_most_likely_tokens)
+            assert output_text == output_string_from_most_likely_tokens, (
+                "The output text from the top logprob for each token position "
+                "should be the same as the output text in the result.")
+        else:
+            assert output_text == ''
+            assert output_string_from_most_likely_tokens == [None] * max_tokens
 
         # The first prompt logprob is always None
         assert result.prompt_logprobs[0] is None
@@ -98,9 +106,10 @@ def test_get_prompt_logprobs(
                                            hf_logprob[i][-1][token_id].item(),
                                            atol=1e-2,
                                            rtol=1e-2)
-                assert isinstance(sample_logprob.decoded_token, str), (
-                    "The token should be decoded by the time it is returned "
-                    " to the user.")
+                if detokenize:
+                    assert isinstance(sample_logprob.decoded_token, str), (
+                        "The token should be decoded by the time it is returned"
+                        " to the user.")
 
     # Test if prompt logprobs are correctly set.
     for vllm_result in vllm_results:
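
The new else branch above spells out what results look like when detokenization is off: the generated text is empty and the per-token decoded strings are None, while the numeric logprobs are still present. A small inspection sketch along those lines, reusing a result object from a detokenize=False request like the one above (attribute names follow the test):

completion = result.outputs[0]
assert completion.text == ''  # nothing was decoded into a string
for top_logprobs in completion.logprobs:
    for token_id, sample_logprob in top_logprobs.items():
        # Only token ids and log-probabilities are available in this mode.
        assert sample_logprob.decoded_token is None
        print(token_id, sample_logprob.logprob)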
8 changes: 4 additions & 4 deletions vllm/engine/output_processor/single_step.py
@@ -60,10 +60,10 @@ def process_prompt_logprob(self, seq_group: SequenceGroup,
         assert len(outputs) == 1, ("Single step should only has 1 output.")
         output = outputs[0]
         prompt_logprobs = output.prompt_logprobs
-        if (prompt_logprobs is not None
-                and seq_group.sampling_params.detokenize and self.detokenizer):
-            self.detokenizer.decode_prompt_logprobs_inplace(
-                seq_group, prompt_logprobs)
+        if prompt_logprobs is not None:
+            if seq_group.sampling_params.detokenize and self.detokenizer:
+                self.detokenizer.decode_prompt_logprobs_inplace(
+                    seq_group, prompt_logprobs)
             if not seq_group.prompt_logprobs:
                 # The first prompt token's logprob is None because it doesn't
                 # have tokens that are precedent.
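
The shape of the fix above: decoding prompt logprobs into strings stays conditional on SamplingParams.detokenize, while the bookkeeping that follows now runs whenever prompt logprobs exist, instead of being nested under the combined condition and silently skipped when detokenization is off. A simplified, free-standing sketch of the fixed control flow; the accumulation lines after the shown context are assumed from the surrounding file, and this helper is an illustration, not the actual SingleStepOutputProcessor method:

def process_prompt_logprob_sketch(seq_group, prompt_logprobs, detokenizer):
    # Nothing to do if the sampler did not compute prompt logprobs.
    if prompt_logprobs is None:
        return
    # String decoding is optional and controlled by the request.
    if seq_group.sampling_params.detokenize and detokenizer:
        detokenizer.decode_prompt_logprobs_inplace(seq_group, prompt_logprobs)
    # Accumulation is no longer gated on the detokenize flag.
    if not seq_group.prompt_logprobs:
        seq_group.prompt_logprobs = [None]  # the first prompt token has no logprob
    seq_group.prompt_logprobs.extend(prompt_logprobs)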
