Skip to content

Commit

Permalink
[Test] Add ignore_eos test (vllm-project#4519)
Browse files Browse the repository at this point in the history
  • Loading branch information
rkooo567 authored May 1, 2024
1 parent d6f4bd7 commit 6f1df80
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions tests/samplers/test_ignore_eos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Make sure ignore_eos works.
Run `pytest tests/samplers/test_ignore_eos.py`.
"""

import pytest

from vllm import SamplingParams

MODELS = ["facebook/opt-125m"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [1024])
def test_beam_search_single_input(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
example_prompts = "1 + 1 is"

vllm_model = vllm_runner(model, dtype=dtype)
sampling_params = SamplingParams(max_tokens=max_tokens, ignore_eos=True)
ignore_eos_output = vllm_model.model.generate(
example_prompts, sampling_params=sampling_params)
print(len(ignore_eos_output[0].outputs[0].token_ids))
assert max_tokens - len(ignore_eos_output[0].outputs[0].token_ids) < 10
assert max_tokens - len(ignore_eos_output[0].outputs[0].token_ids) >= 0

0 comments on commit 6f1df80

Please sign in to comment.