@@ -21,6 +21,7 @@
 import paddle.nn.functional as F
 from paddle.fluid.data_feeder import convert_dtype
 from paddle.fluid.layers.utils import map_structure
+from paddlenlp.utils.log import logger

 __all__ = ["GenerationMixin"]

@@ -170,7 +171,7 @@ def process(self,
                 # add to generated hypotheses if end of sentence
                 if (eos_token_id is not None) and (
                         next_token.numpy().item() == eos_token_id):
-                    # If beam_token does not belong to top num_beams tokens,
+                    # If beam_token does not belong to top num_beams tokens,
                     # it should not be added
                     is_beam_token_worse_than_top_num_beams = (
                         beam_token_rank >= self.group_size)
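
Note: a minimal sketch of the rank check in this hunk, assuming the usual setup where beam search scores 2 * num_beams candidates per batch item and group_size equals num_beams:

    # Only an EOS candidate ranked within the top num_beams may finalize a
    # hypothesis; lower-ranked EOS candidates are skipped.
    group_size = 4  # == num_beams for single-group beam search
    for beam_token_rank in range(2 * group_size):
        is_beam_token_worse_than_top_num_beams = beam_token_rank >= group_size
        # ranks 0..3 -> False (may be added), ranks 4..7 -> True (skipped)
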
@@ -357,10 +358,10 @@ def expand_inputs_for_generation(input_ids,
     def update_model_kwargs_for_generation(outputs,
                                            model_kwargs,
                                            is_encoder_decoder=False):
-        # Update the model inputs during generation.
-        # Note that If `token_type_ids` and `attention_mask` in `model_kwargs`
-        # and they contain pad value, the result vectors updated by this method
-        # may be different from expected. In this case, you need to rewrite the
+        # Update the model inputs during generation.
+        # Note that if `token_type_ids` and `attention_mask` in `model_kwargs`
+        # contain pad values, the result vectors updated by this method may
+        # be different from what is expected. In this case, you need to rewrite the
         # method.

         # update cache
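
Note: a minimal sketch of why pad values break the default update. This is illustrative, not the library's exact code; it assumes the common scheme of appending a 1 to `attention_mask` and repeating the last `token_type_ids` column after each step:

    import paddle

    attention_mask = paddle.to_tensor([[1, 1, 0]])  # 0 marks a pad position
    token_type_ids = paddle.to_tensor([[0, 0, 0]])

    # one generation step: extend both tensors for the newly generated token
    attention_mask = paddle.concat(
        [attention_mask, paddle.ones([1, 1], dtype=attention_mask.dtype)],
        axis=-1)
    token_type_ids = paddle.concat(
        [token_type_ids, token_type_ids[:, -1:]], axis=-1)

    print(attention_mask.numpy())  # [[1 1 0 1]]: the pad 0 now sits inside
                                   # the mask, so padded inputs may need a
                                   # custom rewrite of this method
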
@@ -433,11 +434,37 @@ def prepare_inputs_for_generation(self, input_ids, **kwargs):
         return {"input_ids": input_ids}

     def adjust_logits_during_generation(self, logits):
-        # Implement in subclasses for custom behavior to adjust the logits in
+        # Implement in subclasses for custom behavior to adjust the logits in
         # the generate method.

         return logits

+    def prepare_faster_entry(self, kwargs):
+        pass
+
+    def _convert_to_faster(self, kwargs):
+        # try the general conversion
+        pass
+
+    def _build_faster(self, kwargs):
+        self._faster_entry = False
+
+        # common checks for FasterTransformer
+        if kwargs['min_length'] != 0:
+            # min_length is not supported yet in the faster version
+            return
+        if kwargs['repetition_penalty'] != 1:
+            # repetition_penalty is not supported yet in the faster version
+            # (1.0 is the no-op value, so only a non-default penalty blocks
+            # the faster path)
+            return
+        if kwargs['temperature'] != 1:
+            # temperature is not supported yet in the faster version
+            return
+
+        # 1. custom conversion
+        if not self.prepare_faster_entry(kwargs):
+            # 2. fall back to the general conversion
+            self._convert_to_faster(kwargs)
+
     @paddle.no_grad()
     def generate(self,
                  input_ids=None,
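
Note: a hypothetical sketch of the contract `prepare_faster_entry` is meant to fulfil in a subclass: set `self._faster_entry` to a callable and return it, or return a falsy value so `_build_faster` falls back to `_convert_to_faster`. `FasterMyModel` and its `forward` are illustrative assumptions, not PaddleNLP APIs:

    class MyModelForGeneration(GenerationMixin):
        def prepare_faster_entry(self, kwargs):
            # hypothetical: only plain sampling is handled by this wrapper
            if kwargs['decode_strategy'] != 'sampling':
                return False  # falsy -> _build_faster tries _convert_to_faster
            faster_model = FasterMyModel(self)  # assumed FT-backed wrapper
            self._faster_entry = faster_model.forward
            return self._faster_entry
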
@@ -610,6 +637,42 @@ def generate(self,
                 print(response)
                 # ['是的', '嗯嗯']
         """
+        # Switch to FasterTransformer automatically if it is supported.
+        if getattr(self, '_faster_entry', None) is not False:
+            # TODO(guosheng): need a better way to avoid recursive building
+            if not self.__class__.__module__.endswith('faster_transformer'):
+                args = locals()
+                args.pop('self')
+                args.pop("__class__", None)
+                try:
+                    if not hasattr(self, '_faster_entry'):
+                        self._build_faster(args)
+                    if self._faster_entry:
+                        model_kwargs = args.pop('model_kwargs')
+                        output_ids = self._faster_entry(**args, **model_kwargs)
+                        # transpose to batch major to be consistent with the
+                        # original results
+                        if len(output_ids.shape) == 2:  # sampling
+                            output_ids = paddle.transpose(output_ids, [1, 0])
+                        else:  # beam search
+                            output_ids = paddle.transpose(output_ids,
+                                                          [1, 2, 0])
+                            output_ids = output_ids[:, :num_return_sequences].reshape(
+                                [-1, output_ids.shape[-1]])
+                        # use dummy scores (None) to be consistent with the
+                        # original results
+                        scores = None
+                        return output_ids, scores
+                    else:
+                        # TODO(guosheng): Maybe we can report the unsupported
+                        # reasons to help users enable FasterTransformer when
+                        # it is not supported.
+                        pass
+                except Exception:
+                    logger.warning(
+                        "FasterTransformer is not available, "
+                        "and the original version would be used instead.")

         # params check
         bos_token_id = bos_token_id if bos_token_id is not None else getattr(
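
Note: a shape walk-through of the transposes in the hunk above (a sketch; it assumes the time-major layout implied by the code, i.e. sampling returns [seq_len, batch] and beam search returns [seq_len, batch, beam]):

    import paddle

    seq_len, batch, beam, num_return_sequences = 5, 2, 4, 3
    beam_out = paddle.zeros([seq_len, batch, beam], dtype='int64')

    out = paddle.transpose(beam_out, [1, 2, 0])  # -> [batch, beam, seq_len]
    out = out[:, :num_return_sequences].reshape([-1, seq_len])
    print(out.shape)  # [6, 5], i.e. [batch * num_return_sequences, seq_len]
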
@@ -778,7 +841,7 @@ def TopPProcess(probs, top_p, min_tokens_to_keep):
     sorted_indices = paddle.argsort(probs, descending=True)
     cumulative_probs = paddle.cumsum(sorted_probs, axis=-1)

-    # Remove tokens with cumulative probs above the top_p, But keep at
+    # Remove tokens with cumulative probs above top_p, but keep at
     # least min_tokens_to_keep tokens
     sorted_indices_to_remove = cumulative_probs > top_p
     if min_tokens_to_keep > 1:
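
Note: a numeric sketch of the mask computed above, using probabilities that are already sorted for clarity (the lines that follow in the function go on to keep at least min_tokens_to_keep tokens):

    import paddle

    sorted_probs = paddle.to_tensor([[0.4, 0.3, 0.2, 0.1]])
    cumulative_probs = paddle.cumsum(sorted_probs, axis=-1)  # 0.4, 0.7, 0.9, 1.0

    top_p = 0.75
    sorted_indices_to_remove = cumulative_probs > top_p
    print(sorted_indices_to_remove.numpy())  # [[False False  True  True]]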