Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions tests/models/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
from vllm.platforms import current_platform

from ..conftest import HfRunner, VllmRunner
from ..core.block.e2e.test_correctness_sliding_window import prep_prompts
from ..utils import multi_gpu_test
from ..utils import multi_gpu_test, prep_prompts
from .utils import check_logprobs_close


Expand Down
47 changes: 47 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import importlib
import json
import os
import random
import signal
import subprocess
import sys
Expand Down Expand Up @@ -1150,3 +1151,49 @@ def override_cutlass_fp8_supported(value: bool):
"vllm.model_executor.layers.quantization.utils.w8a8_utils.cutlass_fp8_supported",
return_value=value):
yield


def prep_prompts(batch_size: int, ln_range: tuple[int, int] = (800, 1100)):
"""
Generate prompts with a bunch of assignments,
then asking for the value of one of them.
The prompt is just under 10k tokens; sliding window is 4k
so the answer is outside sliding window, but should still be correct.
Args:
batch_size: number of prompts to generate
ln_range: an argument to control the length of the prompt
"""
prompts: list[str] = []
answer: list[int] = []
indices: list[int] = []
random.seed(1)
for _ in range(batch_size):
idx = random.randint(30, 90)
indices.append(idx)
prompt = "```python\n# We set a number of variables, " + \
f"x{idx} will be important later\n"
ln = random.randint(*ln_range)
for k in range(30, ln):
v = random.randint(10, 99)
Comment on lines +1169 to +1177
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using random.seed() has a global effect and can interfere with other tests that rely on the random number generator. It's better to use an isolated random.Random instance for deterministic generation within this function.

Additionally, the prompt string construction can be simplified into a single f-string for better readability.

Suggested change
random.seed(1)
for _ in range(batch_size):
idx = random.randint(30, 90)
indices.append(idx)
prompt = "```python\n# We set a number of variables, " + \
f"x{idx} will be important later\n"
ln = random.randint(*ln_range)
for k in range(30, ln):
v = random.randint(10, 99)
rng = random.Random(1)
for _ in range(batch_size):
idx = rng.randint(30, 90)
indices.append(idx)
prompt = f"```python\n# We set a number of variables, x{idx} will be important later\n"
ln = rng.randint(*ln_range)
for k in range(30, ln):
v = rng.randint(10, 99)

if k == idx:
answer.append(v)
prompt += f"x{k} = {v}\n"
prompt += f"# Now, we check the value of x{idx}:\n"
prompt += f"assert x{idx} == "
prompts.append(prompt)
return prompts, answer, indices


def check_answers(indices: list[int],
answer: list[int],
outputs: list[str],
accept_rate: float = 0.7):
answer2 = [int(text[0:2].strip()) for text in outputs]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current logic for parsing the model's output is brittle. If the output has leading whitespace, for example " 55" instead of "55", text[0:2] would be " 5", which incorrectly parses to 5.

A more robust approach is to strip whitespace from the whole string first, then split it and take the first part. This correctly handles leading/trailing whitespace.

Note that this can still raise an IndexError for empty outputs or a ValueError if the output is not a number, but this might be an acceptable risk for this test's context.

Suggested change
answer2 = [int(text[0:2].strip()) for text in outputs]
answer2 = [int(text.strip().split()[0]) for text in outputs]

print(list(zip(indices, zip(answer, answer2))))
numok = 0
for a1, a2 in zip(answer, answer2):
if a1 == a2:
numok += 1
frac_ok = numok / len(answer)
print(f"Num OK: {numok}/{len(answer)} {frac_ok}")
assert frac_ok >= accept_rate
3 changes: 1 addition & 2 deletions tests/v1/e2e/test_correctness_sliding_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@

from vllm import LLM, SamplingParams

from ...core.block.e2e.test_correctness_sliding_window import (check_answers,
prep_prompts)
from ...utils import check_answers, prep_prompts


@dataclass
Expand Down