
Commit 748c687

wenlei03 authored and root committed
Addressing Cody's note
1 parent f6b8afe commit 748c687

File tree

4 files changed: +13 -58 lines changed


tests/spec_decode/e2e/test_compatibility.py

Lines changed: 2 additions & 2 deletions
@@ -90,7 +90,7 @@ def test_spec_decode_xfail_chunked_prefill(test_llm_generator):
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
-        "model": "NousResearch/Llama-2-7b-chat-hf",
+        "model": "meta-llama/Llama-2-7b-chat-hf",
         "speculative_model": "JackFram/llama-68m",
         "num_speculative_tokens": 5,
@@ -112,7 +112,7 @@ def test_spec_decode_xfail_chunked_prefill(test_llm_generator):
     },
     {
         # Speculative max model len > target max model len should raise.
-        # https://huggingface.co/NousResearch/Llama-2-7b-chat-hf/blob/37892f30c23786c0d5367d80481fa0d9fba93cf8/config.json#L11
+        # https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/f5db02db724555f92da89c216ac04704f23d4590/config.json#L12
         "speculative_max_model_len": 4096 + 1,
     },
 ])
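
The linked config line is the target model's maximum context length (`max_position_embeddings`, 4096 for Llama-2), so `"speculative_max_model_len": 4096 + 1` keeps this case just past the target's limit. A minimal sketch of the constraint that xfail case exercises; the helper name and error message below are illustrative assumptions, not vLLM's actual validation code:

def check_speculative_max_model_len(speculative_max_model_len: int,
                                    target_max_model_len: int) -> None:
    # Hypothetical helper: the draft model must not speculate past the
    # target model's context window.
    if speculative_max_model_len > target_max_model_len:
        raise ValueError(
            f"speculative_max_model_len ({speculative_max_model_len}) "
            f"exceeds the target model's max model len "
            f"({target_max_model_len}).")

try:
    check_speculative_max_model_len(4096 + 1, 4096)
except ValueError as err:
    print(err)  # the xfail test expects a failure of this kind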

tests/spec_decode/e2e/test_multistep_correctness.py

Lines changed: 2 additions & 2 deletions
@@ -264,7 +264,7 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs_diff_output_len(
     "common_llm_kwargs",
     [{
         # A "real" model (not tiny).
-        "model": "NousResearch/Llama-2-7b-chat-hf",
+        "model": "meta-llama/Llama-2-7b-chat-hf",
 
         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
@@ -308,7 +308,7 @@ def test_spec_decode_e2e_greedy_correctness_real_model_bs1(
     "common_llm_kwargs",
     [{
         # A "real" model (not tiny).
-        "model": "NousResearch/Llama-2-7b-chat-hf",
+        "model": "meta-llama/Llama-2-7b-chat-hf",
 
         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
tests/spec_decode/e2e/test_ngram_correctness.py

Lines changed: 8 additions & 52 deletions
@@ -64,13 +64,14 @@
     "output_len",
     [
         # Use long output len for the small model test.
-        1536,
+        256,
     ])
-@pytest.mark.parametrize("batch_size", [1])
+@pytest.mark.parametrize("batch_size", [1, 64])
 @pytest.mark.parametrize("seed", [1])
-def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
-        baseline_llm_generator, test_llm_generator, batch_size: int,
-        output_len: int):
+def test_spec_decode_e2e_greedy_correctness_tiny_model(baseline_llm_generator,
+                                                       test_llm_generator,
+                                                       batch_size: int,
+                                                       output_len: int):
     """Verify greedy equality on a tiny model with batch size of one.
 
     Since this test is cheaper than other e2e correctness tests, we generate
@@ -83,51 +84,6 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
                                          force_output_len=True)
 
 
-@pytest.mark.parametrize(
-    "common_llm_kwargs",
-    [{
-        # Skip cuda graph recording for fast test.
-        "enforce_eager": True,
-
-        # Required for spec decode.
-        "use_v2_block_manager": True,
-
-        # Print spec metrics.
-        "disable_log_stats": False,
-    }])
-@pytest.mark.parametrize("per_test_common_llm_kwargs", [
-    {
-        "model": "JackFram/llama-68m",
-    },
-])
-@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs", [
-    {
-        "speculative_model": "[ngram]",
-        "num_speculative_tokens": 5,
-        "ngram_prompt_lookup_max": 3,
-    },
-])
-@pytest.mark.parametrize(
-    "output_len",
-    [
-        # Use small output len for fast test.
-        256,
-    ])
-@pytest.mark.parametrize("batch_size", [64])
-@pytest.mark.parametrize("seed", [1])
-def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
-        baseline_llm_generator, test_llm_generator, batch_size: int,
-        output_len: int):
-    """Verify greedy equality on a tiny model and large batch size.
-    """
-    run_greedy_equality_correctness_test(baseline_llm_generator,
-                                         test_llm_generator,
-                                         batch_size,
-                                         max_output_len=output_len,
-                                         force_output_len=True)
-
-
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
@@ -198,15 +154,15 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
         "ngram_prompt_lookup_max": 3,
     }
     # Try a range of common k, as well as large speculation.
-    for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
+    for k in [1, 3, 5, 7, 10, 63]
 ] + [
     {
         "speculative_model": "[ngram]",
         "num_speculative_tokens": k,
         "ngram_prompt_lookup_max": 1,
     }
     # Try a range of common k, as well as large speculation.
-    for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
+    for k in [1, 3, 5, 7, 10, 63]
 ])
 @pytest.mark.parametrize("batch_size", [2])
 @pytest.mark.parametrize(
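
With `batch_size` now parametrized as `[1, 64]`, pytest expands stacked `parametrize` marks into their cross-product, which is what lets the former `_bs1` and `_large_bs` tests collapse into the single `test_spec_decode_e2e_greedy_correctness_tiny_model`. A minimal, self-contained sketch of that pattern (the test body here is a stand-in, not vLLM's correctness helper):

import pytest

@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize("output_len", [256])
@pytest.mark.parametrize("seed", [1])
def test_tiny_model_greedy_equality(batch_size: int, output_len: int,
                                    seed: int) -> None:
    # Stacked parametrize marks yield one test case per combination, so this
    # single test runs once with batch_size=1 and once with batch_size=64.
    assert batch_size in (1, 64)
    assert output_len == 256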

vllm/spec_decode/ngram_worker.py

Lines changed: 1 addition & 2 deletions
@@ -115,8 +115,7 @@ def sampler_output(
                                     ngram_size + sample_len]
                     res_len = len(res)
                     # pad 0 towards output as sample_len tokens required
-                    for i in range(res_len, sample_len):
-                        res.append(0)
+                    res += [0] * (sample_len - res_len)
 
                     break
             else:
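
The list-multiplication form pads in a single operation and matches the old loop: when `res` already holds `sample_len` tokens (or more), `sample_len - res_len` is zero or negative, `[0] * n` is empty, and nothing is appended. A quick standalone check of that equivalence on toy values, outside the worker:

def pad_loop(res: list, sample_len: int) -> list:
    # Old behaviour: append zeros one at a time.
    for _ in range(len(res), sample_len):
        res.append(0)
    return res

def pad_bulk(res: list, sample_len: int) -> list:
    # New behaviour: extend with a zero-filled list in one step.
    res += [0] * (sample_len - len(res))
    return res

for tokens in ([], [7, 8], [1, 2, 3, 4, 5]):
    assert pad_loop(list(tokens), 4) == pad_bulk(list(tokens), 4)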
