28 | 28 |
29 | 29 | logger = init_logger(__name__)
30 | 30 |
31 | | -_PAD_SLOT_ID = -1 # NOTE(woosuk): In PyTorch XLA, index -1 is ignored.
| 31 | +# Here we utilize the behavior that out-of-bound index is ignored.
| 32 | +# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
| 33 | +_PAD_SLOT_ID = 1_000_000_000
32 | 34 | # FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
33 | 35 | _ENABLE_TOP_P = False
34 | 36 | # FIXME(woosuk): A temporary hack to support `n > 1`.
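The first hunk swaps the old `-1` padding slot for an index that is intentionally far out of bounds, relying (per the new comment) on the XLA backend dropping out-of-bound writes rather than treating `-1` as "last slot" the way eager PyTorch indexing would. As a hedged illustration only (the `pad_slot_mapping` helper below is hypothetical, not code from this PR), padded slot-mapping entries would look like this:

```python
import torch

# Hypothetical sketch, not from this PR: pad a per-token slot mapping to a
# fixed length with _PAD_SLOT_ID. The value is far beyond any real KV-cache
# slot; per the comment in the hunk above, the XLA backend ignores writes to
# out-of-bound slots, whereas -1 would be interpreted as the last slot under
# eager PyTorch indexing.
_PAD_SLOT_ID = 1_000_000_000

def pad_slot_mapping(slot_mapping, padded_len):
    num_padding = padded_len - len(slot_mapping)
    return torch.tensor(slot_mapping + [_PAD_SLOT_ID] * num_padding,
                        dtype=torch.int64)

print(pad_slot_mapping([7, 8, 9], padded_len=6))
# tensor([7, 8, 9, 1000000000, 1000000000, 1000000000])
```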
@@ -414,10 +416,7 @@ def _prepare_sample(
414 | 416 | best_of = []
415 | 417 | for seq_group_metadata in seq_group_metadata_list:
416 | 418 | sampling_params = seq_group_metadata.sampling_params
417 | | - # NOTE(woosuk): Here we mimic argmax sampling by applying a very
418 | | - # low temperature. This is not accurate.
419 | | - t.append(sampling_params.temperature
420 | | - if sampling_params.temperature >= 1e-5 else 1e-5)
| 419 | + t.append(sampling_params.temperature)
421 | 420 | if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
422 | 421 | raise NotImplementedError(
423 | 422 | "Top-p sampling is currently disabled for the TPU backend "
@@ -678,13 +677,23 @@ def forward(
678 | 677 | hidden_states = hidden_states.flatten(0, 1)
679 | 678 | logits = self.model.compute_logits(hidden_states, sampling_metadata)
680 | 679 |
681 | | - logits = logits / t.unsqueeze(dim=1)
| 680 | + # Argmax sampling.
| 681 | + argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
| 682 | + argmax_token_ids = argmax_token_ids.repeat(1, num_samples)
| 683 | +
| 684 | + # Zero temperature means greedy decoding. Avoid division by zero.
| 685 | + nonzero_t = torch.where(t != 0, t, 1.0)
| 686 | + logits = logits / nonzero_t.unsqueeze(dim=1)
682 | 687 | if _ENABLE_TOP_P:
683 | 688 | logits = _apply_top_p(logits, p.unsqueeze(dim=1))
| 689 | +
| 690 | + # Random sampling.
684 | 691 | probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
685 | | - next_token_ids = torch.multinomial(probs,
686 | | - num_samples,
687 | | - replacement=True)
| 692 | + sampled_token_ids = torch.multinomial(probs,
| 693 | + num_samples,
| 694 | + replacement=True)
| 695 | + next_token_ids = torch.where(t != 0, sampled_token_ids,
| 696 | + argmax_token_ids)
688 | 697 | return next_token_ids
689 | 698 |
690 | 699 |
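Taken together with the `_prepare_sample` change above, the `forward` hunk replaces the old "mimic argmax with a tiny temperature" trick with an exact split: argmax tokens are always computed, multinomial tokens are drawn from the temperature-scaled distribution, and `torch.where` picks per request depending on whether its temperature is zero. Below is a minimal standalone sketch of that selection in plain PyTorch; the `sample` helper and toy shapes are made up for illustration, and `t` is unsqueezed here so the broadcast is explicit:

```python
import torch

def sample(logits: torch.Tensor, t: torch.Tensor,
           num_samples: int = 1) -> torch.Tensor:
    """Greedy decoding where t == 0, multinomial sampling elsewhere."""
    # Argmax sampling: always computed, repeated to [batch, num_samples].
    argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
    argmax_token_ids = argmax_token_ids.repeat(1, num_samples)

    # Zero temperature means greedy decoding. Substitute 1.0 only to avoid a
    # division by zero; the sampled result is discarded by torch.where below.
    nonzero_t = torch.where(t != 0, t, torch.ones_like(t))
    scaled_logits = logits / nonzero_t.unsqueeze(dim=1)

    # Random sampling from the temperature-scaled distribution.
    probs = torch.softmax(scaled_logits, dim=-1, dtype=torch.float32)
    sampled_token_ids = torch.multinomial(probs, num_samples, replacement=True)

    # Per request: sampled tokens where t != 0, argmax tokens where t == 0.
    return torch.where(t.unsqueeze(dim=1) != 0, sampled_token_ids,
                       argmax_token_ids)

logits = torch.randn(4, 32)             # 4 requests over a toy 32-token vocab
t = torch.tensor([0.0, 0.7, 1.0, 0.0])  # requests 0 and 3 want greedy decoding
print(sample(logits, t, num_samples=1))
```

Computing both branches unconditionally and selecting with `torch.where` avoids data-dependent control flow, so the compiled XLA graph stays the same regardless of the mix of greedy and random requests in a batch.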