@@ -95,6 +95,10 @@ def test_perfect_match(rejection_sampler):
                              device=logits.device)
     assert torch.equal(output, expected)
 
+    assert rejection_sampler.stats.num_draft_tokens == 3
+    assert rejection_sampler.stats.num_accepted_tokens == 3
+    assert rejection_sampler.stats.num_emitted_tokens == 4
+
 
 def test_early_mismatch(rejection_sampler):
     """Test when there's an early mismatch in tokens"""
@@ -122,6 +126,10 @@ def test_early_mismatch(rejection_sampler):
     )
     assert torch.equal(output, expected)
 
+    assert rejection_sampler.stats.num_draft_tokens == 3
+    assert rejection_sampler.stats.num_accepted_tokens == 1
+    assert rejection_sampler.stats.num_emitted_tokens == 2
+
 
 def test_multiple_sequences(rejection_sampler):
     """Test handling multiple sequences of speculated tokens"""
@@ -148,6 +156,10 @@ def test_multiple_sequences(rejection_sampler):
                              device=logits.device)
     assert torch.equal(output, expected)
 
+    assert rejection_sampler.stats.num_draft_tokens == 3
+    assert rejection_sampler.stats.num_accepted_tokens == 3
+    assert rejection_sampler.stats.num_emitted_tokens == 5
+
 
 def test_single_token_sequence(rejection_sampler):
     """Test handling sequences with single token"""
@@ -171,6 +183,10 @@ def test_single_token_sequence(rejection_sampler):
     expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device)
     assert torch.equal(output, expected)
 
+    assert rejection_sampler.stats.num_draft_tokens == 1
+    assert rejection_sampler.stats.num_accepted_tokens == 1
+    assert rejection_sampler.stats.num_emitted_tokens == 2
+
 
 def test_empty_sequence(rejection_sampler):
     """Test handling empty sequence of speculated tokens"""
@@ -194,6 +210,10 @@ def test_empty_sequence(rejection_sampler):
     expected = torch.tensor([[5]], dtype=torch.int, device=logits.device)
     assert torch.equal(output, expected)
 
+    assert rejection_sampler.stats.num_draft_tokens == 0
+    assert rejection_sampler.stats.num_accepted_tokens == 0
+    assert rejection_sampler.stats.num_emitted_tokens == 1
+
 
 def test_multiple_mismatches(rejection_sampler):
     """Test handling multiple sequences with mismatches"""
@@ -223,17 +243,24 @@ def test_multiple_mismatches(rejection_sampler):
     )
     assert torch.equal(output, expected)
 
+    assert rejection_sampler.stats.num_draft_tokens == 6
+    assert rejection_sampler.stats.num_accepted_tokens == 3
+    assert rejection_sampler.stats.num_emitted_tokens == 5
+
 
 @pytest.mark.parametrize(
-    "spec_tokens,output_tokens,expected",
+    "spec_tokens,output_tokens,expected,expected_stats",
     [
-        ([[1, 2]], [[1, 2, 3]], [[1, 2, 3]]),  # Perfect match with bonus
-        ([[1]], [[2, 3]], [[2, PLACEHOLDER_TOKEN_ID]]),  # First mismatch
-        ([[1, 2], [3, 4]], [[1, 5, 6], [3, 4, 7]],
-         [[1, 5, PLACEHOLDER_TOKEN_ID], [3, 4, 7]]),  # Mixed matches
+        ([[1, 2]], [[1, 2, 3]], [[1, 2, 3]],
+         (2, 2, 3)),  # Perfect match with bonus
+        ([[1]], [[2, 3]], [[2, PLACEHOLDER_TOKEN_ID]],
+         (1, 0, 1)),  # First mismatch
+        ([[1, 2], [3, 4]], [[1, 5, 6], [3, 4, 7]],
+         [[1, 5, PLACEHOLDER_TOKEN_ID], [3, 4, 7]],
+         (4, 3, 5)),  # Mixed matches
     ])
 def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens,
-                            expected):
+                            expected, expected_stats):
     """Parametrized test for various matching scenarios"""
     metadata = create_sampling_metadata(all_greedy=True)
     logits = create_logits_tensor(output_tokens)
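As a sanity check on the new `expected_stats` tuples, the three counters can be reproduced by hand from `spec_tokens` and `output_tokens` under greedy verification. The sketch below does exactly that; the helper name `greedy_stats` is made up for illustration and is not part of the test suite.

```python
# Recompute (num_draft_tokens, num_accepted_tokens, num_emitted_tokens) for the
# parametrized cases, assuming greedy verification: a draft token is accepted
# while it matches the target's token at that position, then one recovery token
# is emitted at the first mismatch, or one bonus token if everything matched.
def greedy_stats(spec_tokens, output_tokens):
    drafted = accepted = emitted = 0
    for spec, out in zip(spec_tokens, output_tokens):
        drafted += len(spec)
        n_match = 0
        while n_match < len(spec) and spec[n_match] == out[n_match]:
            n_match += 1
        accepted += n_match
        emitted += n_match + 1  # recovery token on mismatch, bonus otherwise
    return drafted, accepted, emitted


assert greedy_stats([[1, 2]], [[1, 2, 3]]) == (2, 2, 3)
assert greedy_stats([[1]], [[2, 3]]) == (1, 0, 1)
assert greedy_stats([[1, 2], [3, 4]], [[1, 5, 6], [3, 4, 7]]) == (4, 3, 5)
```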
@@ -254,6 +281,10 @@ def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens,
                                     device=logits.device)
     assert torch.equal(output, expected_tensor)
 
+    assert rejection_sampler.stats.num_draft_tokens == expected_stats[0]
+    assert rejection_sampler.stats.num_accepted_tokens == expected_stats[1]
+    assert rejection_sampler.stats.num_emitted_tokens == expected_stats[2]
+
 
 ########################### Tests for Random Sampling ###################
 @pytest.mark.parametrize("k", [1, 3, 5])
@@ -314,6 +345,12 @@ def test_deterministic_when_seeded(
 
         results.append(rep_result)
 
+    stats = rejection_sampler.stats.take()
+    assert stats.num_draft_tokens == num_tokens
+    assert stats.num_emitted_tokens >= batch_size
+    assert (stats.num_emitted_tokens -
+            batch_size) == stats.num_accepted_tokens
+
     for i in range(batch_size):
         if seeded_mask[i]:
             for j in range(1, n_rep):
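The seeded test reads the counters through `stats.take()`, which the diff implies returns a consumable snapshot rather than the live counters. Its final assertion encodes the invariant that each request emits exactly one non-draft token per step (a bonus token on full acceptance, or a recovered token at the first rejection), so `num_emitted_tokens - batch_size == num_accepted_tokens`. A small illustrative check of that arithmetic, under the same assumption (the per-request numbers below are made up):

```python
# With one bonus/recovery token per request, emitted == accepted + batch_size
# regardless of how many draft tokens each request had accepted.
def check_invariant(accepted_per_request: list[int]) -> None:
    batch_size = len(accepted_per_request)
    num_accepted = sum(accepted_per_request)
    num_emitted = sum(a + 1 for a in accepted_per_request)  # +1 bonus/recovery
    assert num_emitted >= batch_size
    assert num_emitted - batch_size == num_accepted


# Consistent with test_multiple_mismatches above: two requests accepting
# (2, 1) draft tokens give num_accepted_tokens == 3 and num_emitted_tokens == 5.
check_invariant([2, 1])
```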