|
8 | 8 | import importlib |
9 | 9 | import json |
10 | 10 | import os |
| 11 | +import random |
11 | 12 | import signal |
12 | 13 | import subprocess |
13 | 14 | import sys |
@@ -1150,3 +1151,49 @@ def override_cutlass_fp8_supported(value: bool): |
1150 | 1151 | "vllm.model_executor.layers.quantization.utils.w8a8_utils.cutlass_fp8_supported", |
1151 | 1152 | return_value=value): |
1152 | 1153 | yield |
| 1154 | + |
| 1155 | + |
| 1156 | +def prep_prompts(batch_size: int, ln_range: tuple[int, int] = (800, 1100)): |
| 1157 | + """ |
| 1158 | + Generate prompts which a bunch of assignments, |
| 1159 | + then asking for the value of one of them. |
| 1160 | + The prompt is just under 10k tokens; sliding window is 4k |
| 1161 | + so the answer is outside sliding window, but should still be correct. |
| 1162 | + Args: |
| 1163 | + batch_size: number of prompts to generate |
| 1164 | + ln_range: an argument to control the length of the prompt |
| 1165 | + """ |
| 1166 | + prompts: list[str] = [] |
| 1167 | + answer: list[int] = [] |
| 1168 | + indices: list[int] = [] |
| 1169 | + random.seed(1) |
| 1170 | + for _ in range(batch_size): |
| 1171 | + idx = random.randint(30, 90) |
| 1172 | + indices.append(idx) |
| 1173 | + prompt = "```python\n# We set a number of variables, " + \ |
| 1174 | + f"x{idx} will be important later\n" |
| 1175 | + ln = random.randint(*ln_range) |
| 1176 | + for k in range(30, ln): |
| 1177 | + v = random.randint(10, 99) |
| 1178 | + if k == idx: |
| 1179 | + answer.append(v) |
| 1180 | + prompt += f"x{k} = {v}\n" |
| 1181 | + prompt += f"# Now, we check the value of x{idx}:\n" |
| 1182 | + prompt += f"assert x{idx} == " |
| 1183 | + prompts.append(prompt) |
| 1184 | + return prompts, answer, indices |
| 1185 | + |
| 1186 | + |
| 1187 | +def check_answers(indices: list[int], |
| 1188 | + answer: list[int], |
| 1189 | + outputs: list[str], |
| 1190 | + accept_rate: float = 0.7): |
| 1191 | + answer2 = [int(text[0:2].strip()) for text in outputs] |
| 1192 | + print(list(zip(indices, zip(answer, answer2)))) |
| 1193 | + numok = 0 |
| 1194 | + for a1, a2 in zip(answer, answer2): |
| 1195 | + if a1 == a2: |
| 1196 | + numok += 1 |
| 1197 | + frac_ok = numok / len(answer) |
| 1198 | + print(f"Num OK: {numok}/{len(answer)} {frac_ok}") |
| 1199 | + assert frac_ok >= accept_rate |
0 commit comments