Skip to content

Commit

Permalink
Re-add EOS token to get good BART results
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickvonplaten committed Mar 12, 2020
1 parent c111601 commit 6047f46
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
10 changes: 9 additions & 1 deletion src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,7 @@ def generate(
no_repeat_ngram_size=None,
num_return_sequences=None,
attention_mask=None,
decoder_start_token_id=None,
):
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
and beam-search.
Expand Down Expand Up @@ -739,6 +740,10 @@ def generate(
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
)
# TODO: think about how to make this cleaner
decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id is not None else self.config.bos_token_id
)

if input_ids is not None:
batch_size = input_ids.shape[0] # overridden by the input batch_size
Expand All @@ -765,6 +770,9 @@ def generate(
assert (eos_token_ids is None) or (
isinstance(eos_token_ids, (list, tuple)) and ((isinstance(e, int) and e >= 0) for e in eos_token_ids)
), "`eos_token_ids` should be a positive integer or a list/tuple of positive integers."
assert (
decoder_start_token_id is not None or self.config.is_encoder_decoder is False
), "`decoder_start_token_id` has to be defined if model is encoder-decoder model"
assert length_penalty > 0, "`length_penalty` should be strictly positive."
assert (
isinstance(no_repeat_ngram_size, int) and no_repeat_ngram_size >= 0
Expand Down Expand Up @@ -845,7 +853,7 @@ def generate(
encoder_inputs = input_ids
input_ids = torch.full(
(effective_batch_size * num_beams, 1),
bos_token_id, # TODO: wait for results of Bart CNN summarization
decoder_start_token_id, # TODO: see whether this is the best result
dtype=torch.long,
device=next(self.parameters()).device,
)
Expand Down
7 changes: 6 additions & 1 deletion tests/test_modeling_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,11 @@ def test_cnn_summarization_same_as_fairseq_easy(self):
tokens = tok.encode(text, return_tensors="pt").to(torch_device)
extra_len = 20
gen_tokens = hf.generate(
tokens, num_beams=4, max_length=extra_len + 2, do_sample=False
tokens,
num_beams=4,
max_length=extra_len + 2,
do_sample=False,
decoder_start_token_id=hf.config.eos_token_id,
) # repetition_penalty=10.,
expected_result = "<s>The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday."
generated = [tok.decode(g,) for g in gen_tokens]
Expand Down Expand Up @@ -477,6 +481,7 @@ def test_cnn_summarization_same_as_fairseq_hard(self):
no_repeat_ngram_size=3,
do_sample=False,
early_stopping=True,
decoder_start_token_id=hf.config.eos_token_id,
)

decoded = [
Expand Down

0 comments on commit 6047f46

Please sign in to comment.