@@ -29,7 +29,8 @@
 # speculative model
 SPEC_MODEL = "ibm-granite/granite-3b-code-instruct-accelerator"
 
-# max. number of speculative tokens
+# max. number of speculative tokens: this corresponds to
+# n_predict in the config.json of the speculator model.
 MAX_SPEC_TOKENS = 5
 
 # precision
@@ -50,17 +51,15 @@
 
         # Precision
         "dtype": PRECISION,
-    }])
-@pytest.mark.parametrize("per_test_common_llm_kwargs", [
-    {
+
+        # Main model
         "model": MAIN_MODEL,
-    },
-])
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
         "speculative_model": SPEC_MODEL,
-        "num_speculative_tokens": MAX_SPEC_TOKENS,
     },
 ])
 @pytest.mark.parametrize("output_len", [
@@ -94,17 +93,15 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
 
         # Precision
         "dtype": PRECISION,
-    }])
-@pytest.mark.parametrize("per_test_common_llm_kwargs", [
-    {
+
+        # Main model
         "model": MAIN_MODEL,
-    },
-])
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs", [
     {
         "speculative_model": SPEC_MODEL,
-        "num_speculative_tokens": MAX_SPEC_TOKENS,
     },
 ])
 @pytest.mark.parametrize(
@@ -132,8 +129,6 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
-        "model": MAIN_MODEL,
-
         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
 
@@ -142,6 +137,9 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
 
         # Precision
         "dtype": PRECISION,
+
+        # Main model
+        "model": MAIN_MODEL,
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -178,8 +176,6 @@ def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
 @pytest.mark.parametrize(
     "common_llm_kwargs",
     [{
-        "model": MAIN_MODEL,
-
         # Skip cuda graph recording for fast test.
         "enforce_eager": True,
 
@@ -188,13 +184,15 @@ def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
 
         # Precision
         "dtype": PRECISION,
+
+        # Main model
+        "model": MAIN_MODEL,
     }])
 @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
 @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
 @pytest.mark.parametrize("test_llm_kwargs",
                          [{
                              "speculative_model": SPEC_MODEL,
-                             "num_speculative_tokens": MAX_SPEC_TOKENS,
                              "speculative_disable_by_batch_size": 4
                          }])
 @pytest.mark.parametrize("batch_size", [1, 5])
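
The comment added above MAX_SPEC_TOKENS explains where the value comes from: an MLP speculator publishes its draft length as n_predict in its config.json, and the test pins MAX_SPEC_TOKENS = 5 to match. A minimal sketch of inspecting that value, assuming the huggingface_hub client and network access (this snippet is illustrative and not part of the commit):

    import json

    from huggingface_hub import hf_hub_download

    # Fetch the speculator's config and read the published draft length.
    config_path = hf_hub_download(
        repo_id="ibm-granite/granite-3b-code-instruct-accelerator",
        filename="config.json")
    with open(config_path) as f:
        n_predict = json.load(f)["n_predict"]
    print(n_predict)  # MAX_SPEC_TOKENS in the test mirrors this value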
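
The same refactor runs through every hunk: "model": MAIN_MODEL moves into common_llm_kwargs (leaving per_test_common_llm_kwargs as [{}]) and test_llm_kwargs drops num_speculative_tokens, so the engine now takes the draft length from the speculator itself. Roughly, the kwargs the fixtures merge amount to the offline-inference call below; a sketch under the assumption that vLLM's LLM entry point accepts these engine arguments, with MAIN_MODEL, SPEC_MODEL, and PRECISION being the module-level constants of the test file:

    from vllm import LLM

    # Effective configuration after the change (illustrative only).
    llm = LLM(
        model=MAIN_MODEL,              # from common_llm_kwargs
        dtype=PRECISION,               # from common_llm_kwargs
        enforce_eager=True,            # from common_llm_kwargs
        speculative_model=SPEC_MODEL,  # from test_llm_kwargs
        # num_speculative_tokens is no longer passed; it is inferred
        # from the speculator's n_predict.
    )

The last hunk additionally keeps "speculative_disable_by_batch_size": 4 while parametrizing batch_size over [1, 5], so the test still covers both an active speculator (batch size 1) and the path where speculation is disabled for larger batches (batch size 5).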