Complete merge Seq-2-Seq generation into default generation #3225

Merged
9 changes: 7 additions & 2 deletions examples/summarization/bart/evaluate_cnn.py
@@ -20,18 +20,23 @@ def generate_summaries(lns, out_file, batch_size=8, device=DEFAULT_DEVICE):
     fout = Path(out_file).open("w")
     model = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True,).to(device)
     tokenizer = BartTokenizer.from_pretrained("bart-large")
+
+    max_length = 140
+    min_length = 55
+
     for batch in tqdm(list(chunks(lns, batch_size))):
         dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True)
         summaries = model.generate(
             input_ids=dct["input_ids"].to(device),
             attention_mask=dct["attention_mask"].to(device),
             num_beams=4,
             length_penalty=2.0,
-            max_length=142, # +2 from original because we start at step=1 and stop before max_length
-            min_length=56, # +1 from original because we start at step=1
+            max_length=max_length + 2, # +2 from original because we start at step=1 and stop before max_length
+            min_length=min_length + 1, # +1 from original because we start at step=1
             no_repeat_ngram_size=3,
             early_stopping=True,
             do_sample=False,
+            decoder_start_token_id=model.config.eos_token_ids[0],
         )
         dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries]
         for hypothesis in dec:
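Note on the length arguments above: generation now begins at step=1 with the decoder start token and stops before max_length, so the script passes max_length + 2 and min_length + 1 to keep the effective content budget at fairseq's 140/55 tokens. A condensed, self-contained sketch of the updated call (model and tokenizer names are the ones used in the script; the eos_token_ids config attribute is as of this PR, not necessarily later versions):

import torch
from transformers import BartForConditionalGeneration, BartTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
model = BartForConditionalGeneration.from_pretrained("bart-large-cnn", output_past=True).to(device)
tokenizer = BartTokenizer.from_pretrained("bart-large")

max_length, min_length = 140, 55  # fairseq CNN/DM summary budget
batch = tokenizer.batch_encode_plus(
    ["Some article text to summarize."], max_length=1024, return_tensors="pt", pad_to_max_length=True
)
summary_ids = model.generate(
    input_ids=batch["input_ids"].to(device),
    attention_mask=batch["attention_mask"].to(device),
    num_beams=4,
    length_penalty=2.0,
    max_length=max_length + 2,  # +2: decoding starts at step=1 and stops before max_length
    min_length=min_length + 1,  # +1: decoding starts at step=1
    no_repeat_ngram_size=3,
    early_stopping=True,
    do_sample=False,
    decoder_start_token_id=model.config.eos_token_ids[0],  # seed the decoder with </s>, as in the script
)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False))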
13 changes: 10 additions & 3 deletions src/transformers/modeling_utils.py
@@ -628,6 +628,7 @@ def generate(
         no_repeat_ngram_size=None,
         num_return_sequences=None,
         attention_mask=None,
+        decoder_start_token_id=None,
     ):
         r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
         and beam-search.
@@ -739,6 +740,10 @@
         num_return_sequences = (
             num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
         )
+        # TODO: think about how to make this cleaner
+        decoder_start_token_id = (
+            decoder_start_token_id if decoder_start_token_id is not None else self.config.bos_token_id
+        )
 
         if input_ids is not None:
             batch_size = input_ids.shape[0] # overriden by the input batch_size
@@ -765,6 +770,9 @@
         assert (eos_token_ids is None) or (
             isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids)
         ), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
+        assert (
+            decoder_start_token_id is not None or self.config.is_encoder_decoder is False
+        ), "`decoder_start_token_id` has to be defined if model is encoder-decoder model"
         assert length_penalty > 0, "`length_penalty` should be strictly positive."
         assert (
             isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
@@ -845,7 +853,7 @@
             encoder_inputs = input_ids
             input_ids = torch.full(
                 (effective_batch_size * num_beams, 1),
-                bos_token_id,
+                decoder_start_token_id, # TODO: see whether this is the best result
                 dtype=torch.long,
                 device=next(self.parameters()).device,
             )
@@ -1082,7 +1090,7 @@ def _generate_beam_search(
 
             scores = F.log_softmax(next_token_logits, dim=-1) # (batch_size * num_beams, vocab_size)
             if self.config.is_encoder_decoder and do_sample is False:
-                # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search? The prepare scores fn is ugly here
+                # TODO: maybe give better naming
                 scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
 
             # set eos token prob to zero if min_length is not reached
@@ -1276,7 +1284,6 @@ def _generate_beam_search(
         decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)
 
         if self.config.is_encoder_decoder:
-            # do not return first <EOS> token
             return decoded[:, 1:]
         return decoded
 
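Taken together, the modeling_utils.py changes mean that for encoder-decoder models generate() now seeds the decoder input with an explicit decoder_start_token_id (falling back to config.bos_token_id when none is passed) and trims that seeded token from the returned sequences. A minimal standalone sketch of that control flow, with simplified names; it illustrates the logic in the diff and is not the library's exact code:

import torch

def seed_decoder_input(config, effective_batch_size, num_beams, decoder_start_token_id=None, device="cpu"):
    # Fall back to the model's BOS token when no explicit decoder start token is given.
    start_token = decoder_start_token_id if decoder_start_token_id is not None else config.bos_token_id
    if config.is_encoder_decoder and start_token is None:
        raise ValueError("`decoder_start_token_id` has to be defined if model is encoder-decoder model")
    # Every beam of every example starts decoding from a single start token at step 1.
    return torch.full((effective_batch_size * num_beams, 1), start_token, dtype=torch.long, device=device)

def finalize_sequences(decoded, is_encoder_decoder):
    # Encoder-decoder outputs drop the seeded start token before being returned.
    return decoded[:, 1:] if is_encoder_decoder else decoded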
11 changes: 8 additions & 3 deletions tests/test_modeling_bart.py
@@ -61,7 +61,7 @@ def __init__(
         self.hidden_dropout_prob = 0.1
         self.attention_probs_dropout_prob = 0.1
         self.max_position_embeddings = 20
-        self.eos_token_id = 2
+        self.eos_token_ids = [2]
         self.pad_token_id = 1
         self.bos_token_id = 0
         torch.manual_seed(0)
@@ -82,7 +82,7 @@ def prepare_config_and_inputs_for_common(self):
             dropout=self.hidden_dropout_prob,
             attention_dropout=self.attention_probs_dropout_prob,
             max_position_embeddings=self.max_position_embeddings,
-            eos_token_ids=[self.eos_token_id],
+            eos_token_ids=[2],
             bos_token_id=self.bos_token_id,
             pad_token_id=self.pad_token_id,
         )
@@ -432,7 +432,11 @@ def test_cnn_summarization_same_as_fairseq_easy(self):
         tokens = tok.encode(text, return_tensors="pt").to(torch_device)
         extra_len = 20
         gen_tokens = hf.generate(
-            tokens, num_beams=4, max_length=extra_len + 2, do_sample=False
+            tokens,
+            num_beams=4,
+            max_length=extra_len + 2,
+            do_sample=False,
+            decoder_start_token_id=hf.config.eos_token_ids[0],
         ) # repetition_penalty=10.,
         expected_result = "<s>The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday."
         generated = [tok.decode(g,) for g in gen_tokens]
@@ -477,6 +481,7 @@ def test_cnn_summarization_same_as_fairseq_hard(self):
             no_repeat_ngram_size=3,
             do_sample=False,
             early_stopping=True,
+            decoder_start_token_id=hf.config.eos_token_ids[0],
         )
 
         decoded = [
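The switch from a scalar self.eos_token_id to the list self.eos_token_ids in the test helper mirrors the eos_token_ids config field that generate() consumes, and its first entry doubles as the decoder start token in the summarization tests. A toy sketch using only the keyword arguments visible in the diff (other BartConfig parameters are left at their defaults; this is illustrative, not the full test setup):

from transformers import BartConfig

# Toy config mirroring the updated test helper.
config = BartConfig(
    dropout=0.1,
    attention_dropout=0.1,
    max_position_embeddings=20,
    eos_token_ids=[2],
    bos_token_id=0,
    pad_token_id=1,
)
decoder_start = config.eos_token_ids[0]  # == 2, the value passed as decoder_start_token_id in the tests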