     "output_len",
     [
         # Use long output len for the small model test.
-        1536,
+        256,
     ])
-@pytest.mark.parametrize("batch_size", [1])
+@pytest.mark.parametrize("batch_size", [1, 64])
 @pytest.mark.parametrize("seed", [1])
-def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
-        baseline_llm_generator, test_llm_generator, batch_size: int,
-        output_len: int):
+def test_spec_decode_e2e_greedy_correctness_tiny_model(baseline_llm_generator,
+                                                       test_llm_generator,
+                                                       batch_size: int,
+                                                       output_len: int):
     """Verify greedy equality on a tiny model with batch size of one.
 
     Since this test is cheaper than other e2e correctness tests, we generate
@@ -83,51 +84,6 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
                                          force_output_len=True)
 
 
-@pytest.mark.parametrize(
-    "common_llm_kwargs",
-    [{
-        # Skip cuda graph recording for fast test.
-        "enforce_eager": True,
-
-        # Required for spec decode.
-        "use_v2_block_manager": True,
-
-        # Print spec metrics.
-        "disable_log_stats": False,
-    }])
-@pytest.mark.parametrize("per_test_common_llm_kwargs", [
-    {
-        "model": "JackFram/llama-68m",
-    },
-])
-@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs", [
-    {
-        "speculative_model": "[ngram]",
-        "num_speculative_tokens": 5,
-        "ngram_prompt_lookup_max": 3,
-    },
-])
-@pytest.mark.parametrize(
-    "output_len",
-    [
-        # Use small output len for fast test.
-        256,
-    ])
-@pytest.mark.parametrize("batch_size", [64])
-@pytest.mark.parametrize("seed", [1])
-def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
-        baseline_llm_generator, test_llm_generator, batch_size: int,
-        output_len: int):
-    """Verify greedy equality on a tiny model and large batch size.
-    """
-    run_greedy_equality_correctness_test(baseline_llm_generator,
-                                         test_llm_generator,
-                                         batch_size,
-                                         max_output_len=output_len,
-                                         force_output_len=True)
-
-
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
@@ -198,15 +154,15 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
             "ngram_prompt_lookup_max": 3,
         }
         # Try a range of common k, as well as large speculation.
-        for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
+        for k in [1, 3, 5, 7, 10, 63]
     ] + [
         {
            "speculative_model": "[ngram]",
            "num_speculative_tokens": k,
            "ngram_prompt_lookup_max": 1,
         }
         # Try a range of common k, as well as large speculation.
-        for k in [1, 2, 3, 4, 5, 6, 7, 8, 9, 63]
+        for k in [1, 3, 5, 7, 10, 63]
     ])
 @pytest.mark.parametrize("batch_size", [2])
 @pytest.mark.parametrize(
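The consolidation in the first hunk relies on pytest building the Cartesian product of stacked @pytest.mark.parametrize decorators, so parametrizing batch_size with [1, 64] lets a single test body replace the separate bs=1 and large-bs tests. A minimal standalone sketch of that mechanism (hypothetical test name and plain parameters, not the vLLM fixtures):

import pytest


@pytest.mark.parametrize("output_len", [256])
@pytest.mark.parametrize("batch_size", [1, 64])
@pytest.mark.parametrize("seed", [1])
def test_case_matrix(output_len: int, batch_size: int, seed: int):
    # pytest collects 1 * 2 * 1 = 2 cases here: batch_size 1 and 64,
    # each with output_len=256 and seed=1.
    assert output_len == 256
    assert batch_size in (1, 64)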