Skip to content

Commit ad42323

Browse files
committed
Addressed comments
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
1 parent a96d115 commit ad42323

File tree

1 file changed

+16
-18
lines changed

1 file changed

+16
-18
lines changed

tests/spec_decode/e2e/test_mlp_correctness.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929
# speculative model
3030
SPEC_MODEL = "ibm-granite/granite-3b-code-instruct-accelerator"
3131

32-
# max. number of speculative tokens
32+
# max. number of speculative tokens: this corresponds to
33+
# n_predict in the config.json of the speculator model.
3334
MAX_SPEC_TOKENS = 5
3435

3536
# precision
@@ -50,17 +51,15 @@
5051
5152
# Precision
5253
"dtype": PRECISION,
53-
}])
54-
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
55-
{
54+
55+
# Main model
5656
"model": MAIN_MODEL,
57-
},
58-
])
57+
}])
58+
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
5959
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
6060
@pytest.mark.parametrize("test_llm_kwargs", [
6161
{
6262
"speculative_model": SPEC_MODEL,
63-
"num_speculative_tokens": MAX_SPEC_TOKENS,
6463
},
6564
])
6665
@pytest.mark.parametrize("output_len", [
@@ -94,17 +93,15 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
9493
9594
# Precision
9695
"dtype": PRECISION,
97-
}])
98-
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
99-
{
96+
97+
# Main model
10098
"model": MAIN_MODEL,
101-
},
102-
])
99+
}])
100+
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
103101
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
104102
@pytest.mark.parametrize("test_llm_kwargs", [
105103
{
106104
"speculative_model": SPEC_MODEL,
107-
"num_speculative_tokens": MAX_SPEC_TOKENS,
108105
},
109106
])
110107
@pytest.mark.parametrize(
@@ -132,8 +129,6 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
132129
@pytest.mark.parametrize(
133130
"common_llm_kwargs",
134131
[{
135-
"model": MAIN_MODEL,
136-
137132
# Skip cuda graph recording for fast test.
138133
"enforce_eager": True,
139134
@@ -142,6 +137,9 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
142137
143138
# Precision
144139
"dtype": PRECISION,
140+
141+
# Main model
142+
"model": MAIN_MODEL,
145143
}])
146144
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
147145
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@@ -178,8 +176,6 @@ def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
178176
@pytest.mark.parametrize(
179177
"common_llm_kwargs",
180178
[{
181-
"model": MAIN_MODEL,
182-
183179
# Skip cuda graph recording for fast test.
184180
"enforce_eager": True,
185181
@@ -188,13 +184,15 @@ def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
188184
189185
# Precision
190186
"dtype": PRECISION,
187+
188+
# Main model
189+
"model": MAIN_MODEL,
191190
}])
192191
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
193192
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
194193
@pytest.mark.parametrize("test_llm_kwargs",
195194
[{
196195
"speculative_model": SPEC_MODEL,
197-
"num_speculative_tokens": MAX_SPEC_TOKENS,
198196
"speculative_disable_by_batch_size": 4
199197
}])
200198
@pytest.mark.parametrize("batch_size", [1, 5])

0 commit comments

Comments
 (0)