@@ -52,6 +52,19 @@ def get_draft_token_ids(batch_size: int, k: int, vocab_size: int,
     return draft_token_ids


+def get_acceptance_sampler(
+    posterior_threshold: float = 0.03,
+    posterior_alpha: float = 0.9,
+    disable_bonus_tokens: bool = False,
+    strict_mode: bool = False,
+) -> TypicalAcceptanceSampler:
+    """
+    Initializes and returns a TypicalAcceptanceSampler.
+    """
+    return TypicalAcceptanceSampler(posterior_threshold, posterior_alpha,
+                                    disable_bonus_tokens, strict_mode)
+
+
 @pytest.mark.parametrize("k", list(range(1, 6)))
 @pytest.mark.parametrize("vocab_size", [30_000, 50_000])
 @pytest.mark.parametrize("batch_size", list(range(1, 32)))
@@ -64,7 +77,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
     different combinations of k, vocab_size, batch_size and num devices.
     """
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler()
+    typical_acceptance_sampler = get_acceptance_sampler()
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
     bonus_token_ids = torch.randint(low=0,
@@ -76,7 +89,10 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
                                    size=(batch_size, k),
                                    dtype=torch.int64)
     # Verify that sampling succeeds for all cases.
-    typical_acceptance_sampler(target_probs, bonus_token_ids, draft_token_ids)
+    typical_acceptance_sampler(target_probs,
+                               bonus_token_ids,
+                               draft_probs=None,
+                               draft_token_ids=draft_token_ids)


 @pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@@ -94,7 +110,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
     batch_size = 5
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(strict_mode=True)
+    typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
     bonus_token_ids = torch.randint(low=0,
@@ -125,8 +141,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
     oob_token_ids[0][0] = rogue_token_id

     with pytest.raises(AssertionError):
-        typical_acceptance_sampler(target_probs, bonus_token_ids,
-                                   draft_token_ids)
+        typical_acceptance_sampler(target_probs,
+                                   bonus_token_ids,
+                                   draft_probs=None,
+                                   draft_token_ids=draft_token_ids)


 @pytest.mark.parametrize("seed", list(range(10)))
@@ -151,7 +169,7 @@ def test_uniform_target_distribution_accepts_all_tokens(
     batch_size = 5
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
@@ -163,9 +181,11 @@ def test_uniform_target_distribution_accepts_all_tokens(
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     # We are using a uniform target probability distribution.
     # For a uniform distribution the entropy is very high and it
     # should lead to all draft tokens being accepted. Verify that.
@@ -203,7 +223,7 @@ def test_temperature_zero_target_distribution(seed: int,
     vocab_size = 30_000
     torch.set_default_device(device)

-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     # Simulate temperature 0 probability distribution for target probabilities
@@ -224,9 +244,11 @@ def test_temperature_zero_target_distribution(seed: int,
     # 1.0 tokens in the target distribution we will reject all of them and
     # fallback to the greedy sampling for selecting 1 token for each sequence.
     # Verify the same.
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, -1] == -1)
@@ -261,7 +283,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
     batch_size = 4
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     # For sequences 0 and 2 set the distribution to a temperature
@@ -277,9 +299,11 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     # verify the shape of output_token_ids
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
@@ -326,7 +350,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
     batch_size = 1
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     # Create a temperature zero target probability distribution and ensure
@@ -339,9 +363,11 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
@@ -357,9 +383,11 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool,
         batch_size, k, vocab_size, zero_temperature_token_ids)
     draft_token_ids = torch.cat(
         (draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, :2] == draft_token_ids[:, :2])
@@ -384,7 +412,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
     batch_size = 1
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     # Simulate temperature 0 probability distribution for target
@@ -402,9 +430,11 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
                                    high=vocab_size,
                                    size=(batch_size, 1),
                                    dtype=torch.int64)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, 1:-1] == -1)
@@ -418,9 +448,11 @@ def test_accept_tokens_set_non_default_posteriors(seed: int,
         posterior_threshold=0.0,
         posterior_alpha=0.0)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
-    output_token_ids = typical_acceptance_sampler(target_probs,
-                                                  bonus_token_ids,
-                                                  draft_token_ids)
+    output_token_ids = typical_acceptance_sampler(
+        target_probs,
+        bonus_token_ids,
+        draft_probs=None,
+        draft_token_ids=draft_token_ids)
     assert output_token_ids.shape[0] == batch_size
     assert output_token_ids.shape[1] == (k + 1)
     assert torch.all(output_token_ids[:, 0:-1] == draft_token_ids)
@@ -451,7 +483,7 @@ def test_replacement_token_ids(seed: int, disable_bonus_tokens: bool,
     batch_size = 5
     vocab_size = 30_000
     torch.set_default_device(device)
-    typical_acceptance_sampler = TypicalAcceptanceSampler(
+    typical_acceptance_sampler = get_acceptance_sampler(
         strict_mode=True, disable_bonus_tokens=disable_bonus_tokens)
     typical_acceptance_sampler.init_gpu_tensors(rank=0)
     target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
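
For context, here is a minimal standalone sketch of the call pattern these hunks migrate to: every call site now passes `draft_probs` explicitly and `draft_token_ids` by keyword. The sampler construction and tensor shapes mirror the tests above; the `TypicalAcceptanceSampler` import path and the availability of a CUDA device for `init_gpu_tensors` are assumptions, not part of this diff.

```python
import torch

# Assumed import path; adjust to wherever TypicalAcceptanceSampler lives.
from vllm.model_executor.layers.typical_acceptance_sampler import (
    TypicalAcceptanceSampler)

batch_size, k, vocab_size = 2, 3, 100
torch.set_default_device("cuda")  # assumes a CUDA device, as in the tests

# Defaults match the get_acceptance_sampler() helper added in this diff.
sampler = TypicalAcceptanceSampler(posterior_threshold=0.03,
                                   posterior_alpha=0.9,
                                   disable_bonus_tokens=False,
                                   strict_mode=False)
sampler.init_gpu_tensors(rank=0)

target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
bonus_token_ids = torch.randint(0, vocab_size, size=(batch_size, 1))
draft_token_ids = torch.randint(0, vocab_size, size=(batch_size, k))

# The new call shape: draft_probs is an explicit keyword (None for this
# sampler) and draft_token_ids is passed by name.
output_token_ids = sampler(target_probs,
                           bonus_token_ids,
                           draft_probs=None,
                           draft_token_ids=draft_token_ids)
assert output_token_ids.shape == (batch_size, k + 1)
```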