11 changes: 6 additions & 5 deletions paddlenlp/ops/faster_transformer/sample/bart_decoding_sample.py
@@ -138,11 +138,12 @@ def do_predict(args):
input_ids, mem_seq_lens = prepare_input(tokenizer, sentences, pad_id)

# Define model
faster_bart = FasterBART(
model=model,
decoding_strategy=args.decoding_strategy,
decoding_lib=args.decoding_lib,
use_fp16_decoding=args.use_fp16_decoding)
faster_bart = model
# faster_bart = FasterBART(
# model=model,
# decoding_strategy=args.decoding_strategy,
# decoding_lib=args.decoding_lib,
# use_fp16_decoding=args.use_fp16_decoding)

# Set evaluate mode
faster_bart.eval()
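With this change the sample calls generate() on the plain BART model and relies on the automatic FasterTransformer switch added below in generation_utils.py. A minimal, self-contained sketch of that usage (the checkpoint name, the sentence, and the way mem_seq_lens is built here are illustrative assumptions, not code from the sample):

import paddle
from paddlenlp.transformers import BartForConditionalGeneration, BartTokenizer

# Placeholder checkpoint name; the real sample reads it from command-line args.
model = BartForConditionalGeneration.from_pretrained("bart-base")
tokenizer = BartTokenizer.from_pretrained("bart-base")
model.eval()

ids = tokenizer("I love that girl, but she does not love me.")["input_ids"]
input_ids = paddle.to_tensor([ids], dtype="int64")
# FasterBART expects the encoder memory sequence lengths as an extra kwarg.
mem_seq_lens = paddle.to_tensor([len(ids)], dtype="int32")

output_ids, _ = model.generate(
    input_ids=input_ids,
    mem_seq_lens=mem_seq_lens,
    decode_strategy="sampling",
    top_k=4,
    max_length=50)
print(output_ids.shape)  # batch-major: [batch * num_return_sequences, max_len]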
28 changes: 28 additions & 0 deletions paddlenlp/transformers/bart/modeling.py
@@ -759,6 +759,34 @@ def get_encoder(self):
def get_decoder(self):
return self.bart.get_decoder()

def prepare_faster_entry(self, kwargs):
from paddlenlp.ops import FasterBART
decoding_strategy = kwargs.get('decode_strategy')
model_kwargs = kwargs['model_kwargs']
use_fp16_decoding = model_kwargs.get('use_fp16_decoding', False)
# TODO(guosheng): Currently, beam_search_v2 in FasterTransformer uses
# t2t beam search, which differs from the beam search in the generation
# api in finish-queue addition and the early-stop criterion, and it seems
# to lead to poor performance on the BART cnn-sum model, so we disable it
# temporarily.
Contributor: Isn't the accelerated beam_search going to be switched to the HF implementation later? Then beam_search shouldn't need to be disabled here.

if decoding_strategy == 'beam_search':
return False
# Some checks on kwargs. For example, FasterBART needs `mem_seq_lens` as
# one input while BART does not, so check whether `mem_seq_lens` is in kwargs.
if model_kwargs.get('mem_seq_lens', None) is None:
Contributor: Obtaining mem_seq_lens could be moved into FasterBART, so that the parameters accepted by generate stay consistent.

return False
# Assume the args do not change across multiple runs, so that parameters
# are converted only once. Additionally, use some converted args as default
# values instead of converting them again, which allows overriding.
# TODO(guosheng): maybe use weakref for the model in faster model
self._faster_entry = partial(
FasterBART(
self,
decoding_strategy=decoding_strategy,
use_fp16_decoding=use_fp16_decoding).generate,
alpha=kwargs.get('length_penalty'),
rel_len=False)
return self._faster_entry

def forward(self,
input_ids,
attention_mask=None,
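On the review comment about mem_seq_lens above: if FasterBART computed it internally instead of requiring callers to pass it, the lengths could be derived from the padded inputs. A hypothetical helper sketch (the name infer_mem_seq_lens and its placement are assumptions, not part of this PR):

import paddle

def infer_mem_seq_lens(input_ids, pad_token_id):
    # Hypothetical helper: count non-pad tokens per sample to get encoder lengths.
    non_pad = paddle.cast(input_ids != pad_token_id, "int32")
    return paddle.sum(non_pad, axis=-1)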
77 changes: 70 additions & 7 deletions paddlenlp/transformers/generation_utils.py
@@ -21,6 +21,7 @@
import paddle.nn.functional as F
from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.layers.utils import map_structure
from paddlenlp.utils.log import logger

__all__ = ["GenerationMixin"]

@@ -170,7 +171,7 @@ def process(self,
# add to generated hypotheses if end of sentence
if (eos_token_id is not None) and (
next_token.numpy().item() == eos_token_id):
# If beam_token does not belong to top num_beams tokens,
# it should not be added
is_beam_token_worse_than_top_num_beams = (
beam_token_rank >= self.group_size)
@@ -357,10 +358,10 @@ def expand_inputs_for_generation(input_ids,
def update_model_kwargs_for_generation(outputs,
model_kwargs,
is_encoder_decoder=False):
# Update the model inputs during generation.
# Note that if `token_type_ids` and `attention_mask` are in `model_kwargs`
# and they contain pad values, the result vectors updated by this method
# may differ from what is expected. In this case, you need to rewrite the
# method.

# update cache
@@ -433,11 +434,37 @@ def prepare_inputs_for_generation(self, input_ids, **kwargs):
return {"input_ids": input_ids}

def adjust_logits_during_generation(self, logits):
# Implement in subclasses for custom behavior to adjust the logits in
# the generate method.

return logits

def prepare_faster_entry(self, kwargs):
pass

def _convert_to_faster(self, kwargs):
# try general convert
pass

def _build_faster(self, kwargs):
self._faster_entry = False

# common checks for FasterTransformer
if kwargs['min_length'] != 0:
# min_length is not supported yet in the faster version
return
if kwargs['repetition_penalty'] != 1:
# repetition_penalty is not supported yet in the faster version
return
if kwargs['temperature'] != 1:
# temperature is not supported yet in the faster version
return

# 1. custom convert
if not self.prepare_faster_entry(kwargs):
# 2. try general convert
self._convert_to_faster(kwargs)

@paddle.no_grad()
def generate(self,
input_ids=None,
Expand Down Expand Up @@ -610,6 +637,42 @@ def generate(self,
print(response)
# ['是的', '嗯嗯']
"""
# Switch to FasterTransformer automatically if supported.
if getattr(self, '_faster_entry', None) is not False:
# TODO(guosheng): need better way to avoid recursive building
if not self.__class__.__module__.endswith('faster_transformer'):
args = locals()
args.pop('self')
args.pop("__class__", None)
try:
if not hasattr(self, '_faster_entry'):
self._build_faster(args)
if self._faster_entry:
model_kwargs = args.pop('model_kwargs')
# transpose to batch major to be consistent with original results
output_ids = self._faster_entry(**args, **model_kwargs)
if len(output_ids.shape) == 2: # sampling
output_ids = paddle.transpose(output_ids, [1, 0])
else: # beam search
output_ids = paddle.transpose(output_ids, [1, 2, 0])
output_ids = output_ids[:, :num_return_sequences].reshape(
[-1, output_ids.shape[-1]])
# Return dummy scores (None) to be consistent with the original results.
scores = None
return output_ids, scores
else:
# TODO(guosheng): Maybe we can report the unsupported
# reasons to help users enable FasterTransformer when it
# is not supported.
pass
except Exception:
logger.warning(
"FasterTransformer is not available; "
"the original version will be used instead.")

# params check
bos_token_id = bos_token_id if bos_token_id is not None else getattr(
@@ -778,7 +841,7 @@ def TopPProcess(probs, top_p, min_tokens_to_keep):
sorted_indices = paddle.argsort(probs, descending=True)
cumulative_probs = paddle.cumsum(sorted_probs, axis=-1)

# Remove tokens with cumulative probs above the top_p, but keep at
# least min_tokens_to_keep tokens
sorted_indices_to_remove = cumulative_probs > top_p
if min_tokens_to_keep > 1: