add draft version of proposed changes for ROUGE score
patrickvonplaten committed Mar 9, 2020
1 parent a9a885a commit 4f9d1f2
Showing 2 changed files with 14 additions and 9 deletions.
17 changes: 9 additions & 8 deletions src/transformers/modeling_utils.py
@@ -846,7 +846,8 @@ def generate(
 encoder_inputs = input_ids
 input_ids = torch.full(
     (effective_batch_size * num_beams, 1),
-    eos_token_id,  # TODO (PVP): to check if this is the only solution -> quite hacky to do this
+    # eos_token_id,  # TODO (PVP): to check if this is the only solution -> quite hacky to do this
+    bos_token_id,
     dtype=torch.long,
     device=next(self.parameters()).device,
 )
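The edit above swaps the decoder's start token: beam search for encoder-decoder models is now seeded with `bos_token_id` instead of the admittedly hacky `eos_token_id`. A minimal standalone sketch of the seeding (batch size, beam count, and token id are illustrative assumptions, not values from the diff):

```python
import torch

effective_batch_size, num_beams = 2, 4
bos_token_id = 0  # matches the bos_token_id=0 pinned in the test config below

# one BOS start token per beam, mirroring the torch.full call in the diff
decoder_input_ids = torch.full(
    (effective_batch_size * num_beams, 1),
    bos_token_id,
    dtype=torch.long,
)
print(decoder_input_ids.shape)  # torch.Size([8, 1])
```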
@@ -1079,10 +1080,10 @@ def _generate_beam_search(
 next_token_logits = next_token_logits / temperature

 scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
-if (
-    self.config.is_encoder_decoder and do_sample is False
-):  # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search? The prepare scores fn is ugly here
-    scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
+# if (
+#     self.config.is_encoder_decoder and do_sample is False
+# ):  # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search? The prepare scores fn is ugly here
+#     scores = self.prepare_scores_for_generation(scores, cur_len, max_length)

 # set eos token prob to zero if min_length is not reached
 if eos_token_ids is not None and cur_len < min_length:
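The `prepare_scores_for_generation` hook is disabled pending the refactor flagged in the TODO, while the surviving context lines keep the standard minimum-length constraint: until `cur_len` reaches `min_length`, the EOS log-probability is forced to negative infinity so no beam can finish early. A self-contained sketch of that masking (shapes and token ids are illustrative):

```python
import torch
import torch.nn.functional as F

batch_size_times_beams, vocab_size = 8, 50  # illustrative sizes
eos_token_ids = [2]                         # matches eos_token_id=2 in the test config
cur_len, min_length = 1, 5

scores = F.log_softmax(torch.randn(batch_size_times_beams, vocab_size), dim=-1)

# set eos token prob to zero (log-prob to -inf) if min_length is not reached
if eos_token_ids is not None and cur_len < min_length:
    for eos_token_id in eos_token_ids:
        scores[:, eos_token_id] = -float("inf")
```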
@@ -1271,9 +1272,9 @@ def _generate_beam_search(
 assert (len(hypo) == max_length for hypo in best)
 decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)

-if self.config.is_encoder_decoder:
-    # do not return first <EOS> token
-    return decoded[:, 1:]
+# if self.config.is_encoder_decoder:
+#     # do not return first <EOS> token
+#     return decoded[:, 1:]
 return decoded

 # force one of token_ids to be generated by setting prob of all other tokens to 0.
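Because the decoder is now seeded with a genuine BOS token rather than an EOS placeholder, the leading token no longer has to be sliced off before returning. A toy illustration of the old versus new return value (the token ids are made up):

```python
import torch

decoded = torch.tensor([[0, 31, 17, 2]])  # [bos, tok, tok, eos] - illustrative ids

old_output = decoded[:, 1:]  # previous behaviour: drop the leading seed token
new_output = decoded         # this commit: return the full sequence
```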
6 changes: 5 additions & 1 deletion tests/test_modeling_bart.py
@@ -214,6 +214,9 @@ def _get_config_and_data(self, output_past=False):
     decoder_ffn_dim=32,
     max_position_embeddings=48,
     output_past=output_past,
+    eos_token_id=2,
+    pad_token_id=1,
+    bos_token_id=0,
 )
 return config, input_ids, batch_size
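Pinning `eos_token_id`, `pad_token_id`, and `bos_token_id` in the test config makes the BOS-seeded generation above deterministic instead of relying on defaults. A hedged sketch of a comparably tiny config (only the three special-token ids come from the diff; the other sizes are assumptions for a fast test model):

```python
from transformers import BartConfig

config = BartConfig(
    vocab_size=99,               # assumed tiny vocab for testing
    d_model=24,                  # assumed
    encoder_layers=2,            # assumed
    decoder_layers=2,            # assumed
    encoder_ffn_dim=32,          # assumed
    decoder_ffn_dim=32,          # from the diff context
    max_position_embeddings=48,  # from the diff context
    eos_token_id=2,
    pad_token_id=1,
    bos_token_id=0,
)
```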

@@ -468,7 +471,8 @@ def test_cnn_summarization_same_as_fairseq_hard(self):
     attention_mask=dct["attention_mask"].to(torch_device),
     num_beams=4,
     length_penalty=2.0,
-    max_length=max_length + 2,
+    # max_length=max_length + 2,
+    max_length=max_length + 1,
     min_length=min_length + 1,
     no_repeat_ngram_size=3,
     do_sample=False,
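The `max_length` offset shrinks from +2 to +1, consistent with the output no longer carrying a stripped leading token. Since the commit message frames these draft changes as groundwork for ROUGE scoring, here is a hedged illustration of scoring a decoded summary against a reference with the separate `rouge_score` package (not part of this diff; the strings are hypothetical):

```python
from rouge_score import rouge_scorer

prediction = "the cat sat on the mat"       # stand-in for a decoded summary
reference = "a cat was sitting on the mat"  # stand-in for the gold summary

scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
scores = scorer.score(reference, prediction)  # target first, prediction second
print(scores["rougeL"].fmeasure)
```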
