
Commit 20acd16

Integrate FT into generation api (#1154)
* Add the first draft for integrating FT into the generation api.
* Add try-catch around the FT path in the generation api.
* Refine FasterTransformer integration into the generation api.
* Update some checks in the FT integration.

Co-authored-by: smallv0221 <33639025+smallv0221@users.noreply.github.com>
1 parent 27e0c34 commit 20acd16
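
Taken together, the three files below let a plain BART model reach FasterTransformer (FT) directly from the high-level `generate()` call, falling back to the original decoding loop with a logged warning when FT cannot be used. A minimal usage sketch, assuming the FT decoding custom ops are built and available; the checkpoint name and decoding arguments are illustrative, and `mem_seq_lens` is the extra input the BART entry checks for:

import paddle
from paddlenlp.transformers import BartForConditionalGeneration, BartTokenizer

# Illustrative checkpoint; any PaddleNLP BART checkpoint follows the same path.
tokenizer = BartTokenizer.from_pretrained('bart-base')
model = BartForConditionalGeneration.from_pretrained('bart-base')
model.eval()

ids = tokenizer('I love that girl, but <mask> does not <mask> me.')['input_ids']
input_ids = paddle.to_tensor([ids], dtype='int64')
mem_seq_lens = paddle.to_tensor([len(ids)], dtype='int32')

# No explicit FasterBART wrapper: generate() now tries to build an FT entry on
# first use (prepare_faster_entry/_build_faster below) and falls back to the
# original Python decoding loop if FT is unavailable or an argument is unsupported.
output_ids, scores = model.generate(
    input_ids=input_ids,
    mem_seq_lens=mem_seq_lens,
    decode_strategy='sampling',
    top_k=4,
    max_length=20)
print(output_ids.shape)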

3 files changed: +104 −12 lines


paddlenlp/ops/faster_transformer/sample/bart_decoding_sample.py

Lines changed: 6 additions & 5 deletions
@@ -138,11 +138,12 @@ def do_predict(args):
     input_ids, mem_seq_lens = prepare_input(tokenizer, sentences, pad_id)

     # Define model
-    faster_bart = FasterBART(
-        model=model,
-        decoding_strategy=args.decoding_strategy,
-        decoding_lib=args.decoding_lib,
-        use_fp16_decoding=args.use_fp16_decoding)
+    faster_bart = model
+    # faster_bart = FasterBART(
+    #     model=model,
+    #     decoding_strategy=args.decoding_strategy,
+    #     decoding_lib=args.decoding_lib,
+    #     use_fp16_decoding=args.use_fp16_decoding)

     # Set evaluate mode
     faster_bart.eval()

paddlenlp/transformers/bart/modeling.py

Lines changed: 28 additions & 0 deletions
@@ -759,6 +759,34 @@ def get_encoder(self):
     def get_decoder(self):
         return self.bart.get_decoder()

+    def prepare_faster_entry(self, kwargs):
+        from paddlenlp.ops import FasterBART
+        decoding_strategy = kwargs.get('decode_strategy')
+        model_kwargs = kwargs['model_kwargs']
+        use_fp16_decoding = model_kwargs.get('use_fp16_decoding', False)
+        # TODO(guosheng): Currently, beam_search_v2 in FasterTransformer uses
+        # t2t beam search, which differs from beam search in the generation
+        # api on finished-queue addition and the early-stop criterion, and it seems
+        # to lead to poor performance on the bart cnn-sum model, thus we disable it temporarily.
+        if decoding_strategy == 'beam_search':
+            return False
+        # Some checks on kwargs. For example, FasterBART needs `mem_seq_lens` as
+        # one input while BART does not, thus check whether `mem_seq_lens` is in kwargs.
+        if model_kwargs.get('mem_seq_lens', None) is None:
+            return False
+        # Assume no args change among multi-turn runs to convert parameters only
+        # once. Additionally, use some converted args as default values instead of
+        # converting args to allow overriding.
+        # TODO(guosheng): maybe use weakref for the model in the faster model
+        self._faster_entry = partial(
+            FasterBART(
+                self,
+                decoding_strategy=decoding_strategy,
+                use_fp16_decoding=use_fp16_decoding).generate,
+            alpha=kwargs.get('length_penalty'),
+            rel_len=False)
+        return self._faster_entry
+
     def forward(self,
                 input_ids,
                 attention_mask=None,
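
One detail worth noting in the block above: `partial` (presumably already imported from `functools` elsewhere in modeling.py) binds `alpha` and `rel_len` as keyword defaults, which is what the "default values ... allow overriding" comment refers to, since keyword arguments bound in a `functools.partial` can still be overridden at call time. A tiny self-contained illustration of that behaviour; `fake_generate` is a stand-in, not FasterBART:

from functools import partial


def fake_generate(input_ids=None, alpha=0.6, rel_len=True):
    # Stand-in for FasterBART.generate, used only to show partial's semantics.
    return alpha, rel_len


entry = partial(fake_generate, alpha=1.2, rel_len=False)

print(entry(input_ids=[1, 2, 3]))             # (1.2, False) -> the bound defaults
print(entry(input_ids=[1, 2, 3], alpha=0.9))  # (0.9, False) -> caller override wins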

paddlenlp/transformers/generation_utils.py

Lines changed: 70 additions & 7 deletions
@@ -21,6 +21,7 @@
 import paddle.nn.functional as F
 from paddle.fluid.data_feeder import convert_dtype
 from paddle.fluid.layers.utils import map_structure
+from paddlenlp.utils.log import logger

 __all__ = ["GenerationMixin"]

@@ -170,7 +171,7 @@ def process(self,
                 # add to generated hypotheses if end of sentence
                 if (eos_token_id is not None) and (
                         next_token.numpy().item() == eos_token_id):
-                    # If beam_token does not belong to top num_beams tokens,
+                    # If beam_token does not belong to top num_beams tokens,
                     # it should not be added
                     is_beam_token_worse_than_top_num_beams = (
                         beam_token_rank >= self.group_size)
@@ -357,10 +358,10 @@ def expand_inputs_for_generation(input_ids,
     def update_model_kwargs_for_generation(outputs,
                                            model_kwargs,
                                            is_encoder_decoder=False):
-        # Update the model inputs during generation.
-        # Note that If `token_type_ids` and `attention_mask` in `model_kwargs`
-        # and they contain pad value, the result vectors updated by this method
-        # may be different from expected. In this case, you need to rewrite the
+        # Update the model inputs during generation.
+        # Note that If `token_type_ids` and `attention_mask` in `model_kwargs`
+        # and they contain pad value, the result vectors updated by this method
+        # may be different from expected. In this case, you need to rewrite the
         # method.

         # update cache
@@ -433,11 +434,37 @@ def prepare_inputs_for_generation(self, input_ids, **kwargs):
         return {"input_ids": input_ids}

     def adjust_logits_during_generation(self, logits):
-        # Implement in subclasses for custom behavior to adjust the logits in
+        # Implement in subclasses for custom behavior to adjust the logits in
         # the generate method.

         return logits

+    def prepare_faster_entry(self, kwargs):
+        pass
+
+    def _convert_to_faster(self, kwargs):
+        # try general convert
+        pass
+
+    def _build_faster(self, kwargs):
+        self._faster_entry = False
+
+        # common checks for FasterTransformer
+        if kwargs['min_length'] != 0:
+            # min_length is not supported yet in the faster version
+            return
+        if kwargs['repetition_penalty'] != 0:
+            # repetition_penalty is not supported yet in the faster version
+            return
+        if kwargs['temperature'] != 1:
+            # temperature is not supported yet in the faster version
+            return
+
+        # 1. custom convert
+        if not self.prepare_faster_entry(kwargs):
+            # 2. try general convert
+            self._convert_to_faster(kwargs)
+
     @paddle.no_grad()
     def generate(self,
                  input_ids=None,
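
The hooks above define a small protocol around `self._faster_entry` that the `generate()` change in the next hunk consumes: an absent attribute means conversion has not been attempted yet, `False` means it was attempted and rejected (so it is never retried), and a callable means FT is ready to be dispatched to. A toy sketch of those three states, using a stand-in object rather than a real model:

class _Stub:
    """Stand-in for a GenerationMixin subclass; only used to illustrate states."""


model = _Stub()

# State 1: attribute missing -> generate() will call _build_faster() exactly once.
assert not hasattr(model, '_faster_entry')

# State 2: _build_faster() hit an unsupported argument -> keep the original loop.
model._faster_entry = False
assert getattr(model, '_faster_entry', None) is False

# State 3: prepare_faster_entry() installed a callable -> generate() dispatches to it.
model._faster_entry = lambda **kwargs: 'faster output ids'
if getattr(model, '_faster_entry', None):
    print(model._faster_entry())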
@@ -610,6 +637,42 @@ def generate(self,
                 print(response)
                 # ['是的', '嗯嗯']
         """
+        # Switch to FasterTransformer automatically if supported.
+        if getattr(self, '_faster_entry', None) is not False:
+            # TODO(guosheng): need better way to avoid recursive building
+            if not self.__class__.__module__.endswith('faster_transformer'):
+                args = locals()
+                args.pop('self')
+                args.pop("__class__", None)
+                try:
+                    if not hasattr(self, '_faster_entry'):
+                        self._build_faster(args)
+                    if self._faster_entry:
+                        model_kwargs = args.pop('model_kwargs')
+                        # transpose to batch major to be consistent with original results
+                        output_ids = self._faster_entry(**args, **model_kwargs)
+                        if len(output_ids.shape) == 2:  # sampling
+                            output_ids = paddle.transpose(output_ids, [1, 0])
+                        else:  # beam search
+                            output_ids = paddle.transpose(output_ids, [1, 2, 0])
+                            output_ids = output_ids[:, :
+                                                    num_return_sequences].reshape(
+                                                        [
+                                                            -1,
+                                                            output_ids.shape[-1]
+                                                        ])
+                        # append dummy scores to be consistent with original results
+                        scores = None
+                        return output_ids, scores
+                    else:
+                        # TODO(guosheng): Maybe we can report the unsupported
+                        # reasons to help users enable FasterTransformer when not
+                        # supported.
+                        pass
+                except Exception:
+                    logger.warning(
+                        "FasterTransformer is not available, "
+                        "and the original version would be used instead.")

         # params check
         bos_token_id = bos_token_id if bos_token_id is not None else getattr(
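
A note on the two transposes above: judging from the "transpose to batch major" comment, the FT kernels return time-major ids, so sampling output is assumed to be `[max_length, batch]` and beam-search output `[max_length, batch, num_beams]`; the slicing then keeps the first `num_return_sequences` beams per example and folds them into the batch dimension. A toy shape check of that handling (the layouts are assumptions inferred from the code, not taken from the commit):

import paddle

batch, beams, max_len, num_return_sequences = 2, 4, 5, 2

# sampling: assumed FT layout [max_len, batch] -> [batch, max_len]
sample_out = paddle.zeros([max_len, batch], dtype='int64')
sample_out = paddle.transpose(sample_out, [1, 0])
assert sample_out.shape == [batch, max_len]

# beam search: assumed FT layout [max_len, batch, beams] -> [batch, beams, max_len],
# then keep the top num_return_sequences beams and flatten them into the batch dim.
beam_out = paddle.zeros([max_len, batch, beams], dtype='int64')
beam_out = paddle.transpose(beam_out, [1, 2, 0])
beam_out = beam_out[:, :num_return_sequences].reshape([-1, max_len])
assert beam_out.shape == [batch * num_return_sequences, max_len]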
@@ -778,7 +841,7 @@ def TopPProcess(probs, top_p, min_tokens_to_keep):
     sorted_indices = paddle.argsort(probs, descending=True)
     cumulative_probs = paddle.cumsum(sorted_probs, axis=-1)

-    # Remove tokens with cumulative probs above the top_p, But keep at
+    # Remove tokens with cumulative probs above the top_p, But keep at
     # least min_tokens_to_keep tokens
     sorted_indices_to_remove = cumulative_probs > top_p
     if min_tokens_to_keep > 1:
