 
 logger = init_logger(__name__)
 
-_PAD_SLOT_ID = -1  # NOTE(woosuk): In PyTorch XLA, index -1 is ignored.
+# Here we utilize the behavior that out-of-bound index is ignored.
+# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
+_PAD_SLOT_ID = 1_000_000_000
 # FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
 _ENABLE_TOP_P = False
 # FIXME(woosuk): A temporary hack to support `n > 1`.
@@ -414,10 +416,7 @@ def _prepare_sample(
         best_of = []
         for seq_group_metadata in seq_group_metadata_list:
             sampling_params = seq_group_metadata.sampling_params
-            # NOTE(woosuk): Here we mimic argmax sampling by applying a very
-            # low temperature. This is not accurate.
-            t.append(sampling_params.temperature
-                     if sampling_params.temperature >= 1e-5 else 1e-5)
+            t.append(sampling_params.temperature)
             if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
                 raise NotImplementedError(
                     "Top-p sampling is currently disabled for the TPU backend "
@@ -678,13 +677,23 @@ def forward(
         hidden_states = hidden_states.flatten(0, 1)
         logits = self.model.compute_logits(hidden_states, sampling_metadata)
 
-        logits = logits / t.unsqueeze(dim=1)
+        # Argmax sampling.
+        argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
+        argmax_token_ids = argmax_token_ids.repeat(1, num_samples)
+
+        # Zero temperature means greedy decoding. Avoid division by zero.
+        nonzero_t = torch.where(t != 0, t, 1.0)
+        logits = logits / nonzero_t.unsqueeze(dim=1)
         if _ENABLE_TOP_P:
             logits = _apply_top_p(logits, p.unsqueeze(dim=1))
+
+        # Random sampling.
         probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
-        next_token_ids = torch.multinomial(probs,
-                                           num_samples,
-                                           replacement=True)
+        sampled_token_ids = torch.multinomial(probs,
+                                              num_samples,
+                                              replacement=True)
+        next_token_ids = torch.where(t != 0, sampled_token_ids,
+                                     argmax_token_ids)
         return next_token_ids
 
 
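In the new sampling path above, both the greedy (argmax) candidates and the temperature-scaled multinomial samples are computed for every sequence, and the final token is chosen per sequence with torch.where depending on whether its temperature is zero; zero temperatures are first replaced by 1.0 so the division cannot blow up, since those rows are overwritten by the argmax result anyway (computing both branches unconditionally also keeps the graph free of data-dependent control flow). A minimal standalone sketch of the same pattern, using a hypothetical helper sample_next_tokens and illustrative shapes rather than the model runner's actual variables:

import torch

def sample_next_tokens(logits: torch.Tensor,
                       temperature: torch.Tensor,
                       num_samples: int = 1) -> torch.Tensor:
    """Pick greedy tokens where temperature == 0, sampled tokens elsewhere.

    logits:      [batch_size, vocab_size]
    temperature: [batch_size]
    returns:     [batch_size, num_samples]
    """
    # Greedy candidates, computed for every row.
    argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
    argmax_token_ids = argmax_token_ids.repeat(1, num_samples)

    # Swap zero temperatures for 1.0 to avoid dividing by zero; those rows
    # are overwritten by the argmax result below anyway.
    nonzero_t = torch.where(temperature != 0, temperature,
                            torch.ones_like(temperature))
    probs = torch.softmax(logits / nonzero_t.unsqueeze(dim=1), dim=-1,
                          dtype=torch.float32)
    sampled_token_ids = torch.multinomial(probs, num_samples, replacement=True)

    # Per-row selection: zero temperature means greedy decoding.
    return torch.where(temperature.unsqueeze(dim=1) != 0,
                       sampled_token_ids, argmax_token_ids)

# Example: the first sequence decodes greedily, the second samples at t=0.8.
logits = torch.randn(2, 32)
temperature = torch.tensor([0.0, 0.8])
print(sample_next_tokens(logits, temperature))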