_PAD_SLOT_ID = 0  # FIXME(woosuk)
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
_ENABLE_TOP_P = False
+# FIXME(woosuk): A temporary hack to support `n > 1`.
+# This can significantly affect the performance if too large.
+_MAX_NUM_SAMPLES = 128


class TPUModelRunner:
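
The new constant caps how many candidates a single prompt can request in one pass. As a rough illustration of the hack (plain Python, not the runner itself): the device always produces a fixed `_MAX_NUM_SAMPLES` candidates per prompt so the compiled XLA shapes stay constant, and each request later keeps only its own `best_of` prefix on the host.

    # Illustrative only; `candidates_per_prompt` stands in for the model output,
    # which under this change always holds _MAX_NUM_SAMPLES entries per prompt.
    _MAX_NUM_SAMPLES = 128

    candidates_per_prompt = [list(range(_MAX_NUM_SAMPLES)) for _ in range(2)]
    best_of = [4, 1]  # per-request values; anything above _MAX_NUM_SAMPLES is rejected
    kept = [cands[:n] for cands, n in zip(candidates_per_prompt, best_of)]
    print([len(k) for k in kept])  # [4, 1]
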
@@ -143,8 +146,9 @@ def _dummy_run(
        p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)

        # Dummy run.
+       num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
        self.model(token_ids, position_ids, kv_caches, attn_metadata,
-                  input_lens, t, p)
+                  input_lens, t, p, num_samples)

    def warmup_model(
        self,
@@ -268,14 +272,11 @@ def _prepare_decode(
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []
        context_lens: List[int] = []
-       num_seq_groups = len(seq_group_metadata_list)
-       batch_size = _get_padded_batch_size(num_seq_groups)

-       for i, seq_group_metadata in enumerate(seq_group_metadata_list):
+       batch_idx = 0
+       for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt
-
            seq_ids = list(seq_group_metadata.seq_data.keys())
-
            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
@@ -288,14 +289,16 @@ def _prepare_decode(

                assert seq_group_metadata.block_tables is not None
                block_table = seq_group_metadata.block_tables[seq_id]
-               self.block_tables[i, :len(block_table)] = block_table
+               self.block_tables[batch_idx, :len(block_table)] = block_table
+               batch_idx += 1

                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append([slot])

-       num_paddings = batch_size - num_seq_groups
+       batch_size = _get_padded_batch_size(batch_idx)
+       num_paddings = batch_size - batch_idx
        input_tokens = input_tokens + [[0]] * num_paddings
        input_positions = input_positions + [[0]] * num_paddings
        slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
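
With `n > 1`, one sequence group can own several decode rows, so the padding is now computed from the per-sequence count (`batch_idx`) instead of the number of groups, which could otherwise leave the padded batch inconsistent when a group contains several sequences. The body of `_get_padded_batch_size` is not part of this diff; the sketch below only assumes some bucketing rule (round up to a power of two with a floor of 8) to show why the distinction matters.

    # Illustrative only: hypothetical bucketing, not vLLM's actual helper.
    def padded_batch_size(num_seqs: int) -> int:
        size = 8
        while size < num_seqs:
            size *= 2
        return size

    # Two groups, one with n=3 and one with n=1, contribute 4 decode rows.
    # Padding must be derived from those 4 rows, not from the 2 groups.
    print(padded_batch_size(3 + 1))  # 8
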
@@ -333,14 +336,13 @@ def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        padded_batch_size: int,
-   ) -> Tuple[torch.Tensor, torch.Tensor]:
+   ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        assert len(seq_group_metadata_list) > 0
        t = []
        p = []
+       best_of = []
        for seq_group_metadata in seq_group_metadata_list:
-           assert seq_group_metadata.sampling_params is not None
            sampling_params = seq_group_metadata.sampling_params
-
            # NOTE(woosuk): Here we mimic argmax sampling by applying a very
            # low temperature. This is not accurate.
            t.append(sampling_params.temperature
@@ -354,10 +356,11 @@ def _prepare_sample(
                raise NotImplementedError(
                    "Top-k sampling is currently disabled for the TPU backend "
                    "due to performance issues.")
-           if sampling_params.best_of > 1:
+           if sampling_params.best_of > _MAX_NUM_SAMPLES:
                raise NotImplementedError(
-                   "best_of > 1 is not currently supported by the TPU "
+                   f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
                    "backend.")
+           best_of.append(sampling_params.best_of)
            if sampling_params.use_beam_search:
                raise NotImplementedError(
                    "Beam search is not supported by the TPU backend.")
@@ -369,13 +372,19 @@ def _prepare_sample(
                    "prompt_logprobs is not currently supported by the TPU "
                    "backend.")

-       num_paddings = padded_batch_size - len(seq_group_metadata_list)
+           # Repeat the sampling params if the seq group has multiple seqs.
+           num_seqs = len(seq_group_metadata.seq_data)
+           t += [t[-1]] * (num_seqs - 1)
+           p += [p[-1]] * (num_seqs - 1)
+           best_of += [best_of[-1]] * (num_seqs - 1)
+
+       num_paddings = padded_batch_size - len(t)
        t += [1.0] * num_paddings
        p += [1.0] * num_paddings

        t = torch.tensor(t, dtype=torch.float32, device=self.device)
        p = torch.tensor(p, dtype=torch.float32, device=self.device)
-       return t, p
+       return t, p, best_of

    def _execute_model(
        self,
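
`_prepare_sample` now emits one temperature/top-p/best_of entry per sequence rather than per group, then pads to the fixed batch size. A standalone sketch of that bookkeeping on plain lists (the `(temperature, top_p, best_of, num_seqs)` tuples stand in for `SequenceGroupMetadata`, which this sketch does not construct):

    # Illustrative only: expand per-group params to one entry per sequence,
    # then pad with neutral values so tensor shapes stay fixed for XLA.
    groups = [(0.7, 0.9, 4, 1), (1.0, 1.0, 1, 3)]  # (temperature, top_p, best_of, num_seqs)
    padded_batch_size = 8

    t, p, best_of = [], [], []
    for temp, top_p, n, num_seqs in groups:
        t.append(temp)
        p.append(top_p)
        best_of.append(n)
        # Repeat the last entry for the group's remaining sequences.
        t += [t[-1]] * (num_seqs - 1)
        p += [p[-1]] * (num_seqs - 1)
        best_of += [best_of[-1]] * (num_seqs - 1)

    num_paddings = padded_batch_size - len(t)
    t += [1.0] * num_paddings
    p += [1.0] * num_paddings
    print(len(t), best_of)  # 8 [4, 1, 1, 1]
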
@@ -392,28 +401,41 @@ def _execute_model(
        else:
            inputs = self._prepare_decode(seq_group_metadata_list)
        padded_batch_size = inputs[0].shape[0]
-       t, p = self._prepare_sample(seq_group_metadata_list, padded_batch_size)
+       t, p, best_of = self._prepare_sample(seq_group_metadata_list,
+                                            padded_batch_size)
+       num_samples = _MAX_NUM_SAMPLES if is_prompt else 1

        # Execute the model.
        next_token_ids = self.model(inputs[0], inputs[1], kv_caches,
-                                   *inputs[2:], t, p)
+                                   *inputs[2:], t, p, num_samples)
        # Retrieve the outputs to CPU.
        next_token_ids = next_token_ids.cpu().tolist()

        # NOTE(woosuk): Minimal code to construct the sampler outputs.
        # The TPU backend does not reuse the sampler, since the TPU backend
        # does not support the advanced sampling parameters such as logprobs.
-       i = 0
+       zero_logprob = Logprob(0.0)
+       batch_idx = 0
        sampler_outputs = []
        for seq_group_metadata in seq_group_metadata_list:
            seq_outputs = []
            seq_ids = list(seq_group_metadata.seq_data.keys())
-           for seq_id in seq_ids:
-               next_token_id = next_token_ids[i]
-               seq_outputs.append(
-                   SequenceOutput(seq_id, next_token_id,
-                                  {next_token_id: Logprob(0.0)}))
-               i += 1
+           if is_prompt:
+               assert len(seq_ids) == 1
+               seq_id = seq_ids[0]
+               for i in range(best_of[batch_idx]):
+                   next_token_id = next_token_ids[batch_idx][i]
+                   seq_outputs.append(
+                       SequenceOutput(seq_id, next_token_id,
+                                      {next_token_id: zero_logprob}))
+               batch_idx += 1
+           else:
+               for seq_id in seq_ids:
+                   next_token_id = next_token_ids[batch_idx][0]
+                   seq_outputs.append(
+                       SequenceOutput(seq_id, next_token_id,
+                                      {next_token_id: zero_logprob}))
+                   batch_idx += 1
            sampler_outputs.append(
                CompletionSequenceGroupOutput(seq_outputs, None))
        return sampler_outputs
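
The model now returns a 2-D list of token ids, `[padded_batch_size][num_samples]`, and the two phases consume it differently: a prompt group owns one row and reads its first `best_of` candidates, while each decode sequence owns one row and reads a single token. A toy decode-phase walk-through of the `batch_idx` bookkeeping (dummy ids, no vLLM objects):

    # Illustrative only: decode phase, one row per sequence, one sample per row.
    next_token_ids = [[7], [13], [21], [0]]  # padded to 4 rows; last row is padding
    groups = [["seq0"], ["seq1", "seq2"]]    # second group was started with n=2

    batch_idx = 0
    for seq_ids in groups:
        for seq_id in seq_ids:
            print(seq_id, next_token_ids[batch_idx][0])
            batch_idx += 1
    # seq0 7, seq1 13, seq2 21; the padded fourth row is never read.
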
@@ -458,6 +480,7 @@ def forward(
        input_lens: torch.Tensor,
        t: torch.Tensor,
        p: torch.Tensor,
+       num_samples: int,
    ) -> torch.Tensor:
        """Executes the forward pass of the model and samples the next token.

@@ -520,8 +543,9 @@ def forward(
        if _ENABLE_TOP_P:
            logits = _apply_top_p(logits, p.unsqueeze(dim=1))
        probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
-       # FIXME(woosuk): best_of > 1 is not supported.
-       next_token_ids = torch.multinomial(probs, num_samples=1).squeeze(dim=1)
+       next_token_ids = torch.multinomial(probs,
+                                          num_samples,
+                                          replacement=True)
        return next_token_ids
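
Finally, the sampler draws all candidates in one call. `replacement=True` matters here: without replacement, `torch.multinomial` requires `num_samples` to be at most the number of non-zero entries in each row, so drawing 128 candidates would need at least 128 tokens with non-zero probability. A minimal, standalone check (dummy probabilities, not model output):

    import torch

    probs = torch.tensor([[0.25, 0.25, 0.25, 0.25, 0.0]])
    # torch.multinomial(probs, 5, replacement=False) would raise a RuntimeError,
    # since only four categories have non-zero probability.
    samples = torch.multinomial(probs, 5, replacement=True)  # repeats are allowed
    print(samples.shape)  # torch.Size([1, 5])
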