6464 "output_len" ,
6565 [
6666 # Use long output len for the small model test.
67- 1536 ,
67+ 256 ,
6868 ])
69- @pytest .mark .parametrize ("batch_size" , [1 ])
69+ @pytest .mark .parametrize ("batch_size" , [1 , 64 ])
7070@pytest .mark .parametrize ("seed" , [1 ])
71- def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1 (
72- baseline_llm_generator , test_llm_generator , batch_size : int ,
73- output_len : int ):
71+ def test_spec_decode_e2e_greedy_correctness_tiny_model (baseline_llm_generator ,
72+ test_llm_generator ,
73+ batch_size : int ,
74+ output_len : int ):
7475 """Verify greedy equality on a tiny model with batch size of one.
7576
7677 Since this test is cheaper than other e2e correctness tests, we generate
@@ -83,51 +84,6 @@ def test_spec_decode_e2e_greedy_correctness_tiny_model_bs1(
                                          force_output_len=True)


-@pytest.mark.parametrize(
-    "common_llm_kwargs",
-    [{
-        # Skip cuda graph recording for fast test.
-        "enforce_eager": True,
-
-        # Required for spec decode.
-        "use_v2_block_manager": True,
-
-        # Print spec metrics.
-        "disable_log_stats": False,
-    }])
-@pytest.mark.parametrize("per_test_common_llm_kwargs", [
-    {
-        "model": "JackFram/llama-68m",
-    },
-])
-@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
-@pytest.mark.parametrize("test_llm_kwargs", [
-    {
-        "speculative_model": "[ngram]",
-        "num_speculative_tokens": 5,
-        "ngram_prompt_lookup_max": 3,
-    },
-])
-@pytest.mark.parametrize(
-    "output_len",
-    [
-        # Use small output len for fast test.
-        256,
-    ])
-@pytest.mark.parametrize("batch_size", [64])
-@pytest.mark.parametrize("seed", [1])
-def test_spec_decode_e2e_greedy_correctness_tiny_model_large_bs(
-        baseline_llm_generator, test_llm_generator, batch_size: int,
-        output_len: int):
-    """Verify greedy equality on a tiny model and large batch size.
-    """
-    run_greedy_equality_correctness_test(baseline_llm_generator,
-                                         test_llm_generator,
-                                         batch_size,
-                                         max_output_len=output_len,
-                                         force_output_len=True)
-
-
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
@@ -198,15 +154,15 @@ def test_spec_decode_e2e_greedy_correctness_with_preemption(
198154 "ngram_prompt_lookup_max" : 3 ,
199155 }
200156 # Try a range of common k, as well as large speculation.
201- for k in [1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 63 ]
157+ for k in [1 , 3 , 5 , 7 , 10 , 63 ]
202158 ] + [
203159 {
204160 "speculative_model" : "[ngram]" ,
205161 "num_speculative_tokens" : k ,
206162 "ngram_prompt_lookup_max" : 1 ,
207163 }
208164 # Try a range of common k, as well as large speculation.
209- for k in [1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 63 ]
165+ for k in [1 , 3 , 5 , 7 , 10 , 63 ]
210166 ])
211167@pytest .mark .parametrize ("batch_size" , [2 ])
212168@pytest .mark .parametrize (
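For context, the kwargs dictionaries in these parametrizations (common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs) are merged and passed to vLLM's LLM constructor by the baseline_llm_generator / test_llm_generator fixtures. The sketch below shows roughly equivalent standalone usage of the ngram configuration exercised here; it is an illustration assuming a vLLM version contemporary with this change, and the prompt, output length, and printed field are placeholders rather than part of the test.

# Minimal sketch of the ngram speculative-decoding setup used by these tests.
# Assumes a vLLM release from around the time of this diff, where the LLM
# constructor forwards these keyword arguments to the engine.
from vllm import LLM, SamplingParams

llm = LLM(
    model="JackFram/llama-68m",
    # Skip cuda graph recording for fast startup, as in the tests.
    enforce_eager=True,
    # Required for spec decode at the time of this change.
    use_v2_block_manager=True,
    # "[ngram]" selects prompt-lookup (ngram) speculation instead of a
    # separate draft model.
    speculative_model="[ngram]",
    num_speculative_tokens=5,
    ngram_prompt_lookup_max=3,
)

# Greedy decoding with a forced output length, mirroring the
# force_output_len=True behavior of run_greedy_equality_correctness_test.
params = SamplingParams(temperature=0.0, max_tokens=256, ignore_eos=True)
outputs = llm.generate(["The quick brown fox jumps over the"], params)
print(outputs[0].outputs[0].text)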