diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 26b6cb39682233..a69f89167d7a06 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -846,7 +846,8 @@ def generate(
             encoder_inputs = input_ids
             input_ids = torch.full(
                 (effective_batch_size * num_beams, 1),
-                eos_token_id,  # TODO (PVP): to check if this is the only solution -> quite hacky to do this
+                # eos_token_id,  # TODO (PVP): to check if this is the only solution -> quite hacky to do this
+                bos_token_id,
                 dtype=torch.long,
                 device=next(self.parameters()).device,
             )
@@ -1079,10 +1080,10 @@ def _generate_beam_search(
                 next_token_logits = next_token_logits / temperature

             scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)
-            if (
-                self.config.is_encoder_decoder and do_sample is False
-            ):  # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search? The prepare scores fn is ugly here
-                scores = self.prepare_scores_for_generation(scores, cur_len, max_length)
+            # if (
+            #     self.config.is_encoder_decoder and do_sample is False
+            # ):  # TODO(PVP) to be refactored later - do we need this boolean flag here? Also Only add for beam_search or also for no_beam_search? The prepare scores fn is ugly here
+            #     scores = self.prepare_scores_for_generation(scores, cur_len, max_length)

             # set eos token prob to zero if min_length is not reached
             if eos_token_ids is not None and cur_len < min_length:
@@ -1271,9 +1272,9 @@ def _generate_beam_search(
         assert (len(hypo) == max_length for hypo in best)
         decoded = torch.stack(best).type(torch.long).to(next(self.parameters()).device)

-        if self.config.is_encoder_decoder:
-            # do not return first token
-            return decoded[:, 1:]
+        # if self.config.is_encoder_decoder:
+        # do not return first token
+        # return decoded[:, 1:]
         return decoded

     # force one of token_ids to be generated by setting prob of all other tokens to 0.
diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py
index 6f1676ddce900b..d377c1194c20bd 100644
--- a/tests/test_modeling_bart.py
+++ b/tests/test_modeling_bart.py
@@ -214,6 +214,9 @@ def _get_config_and_data(self, output_past=False):
             decoder_ffn_dim=32,
             max_position_embeddings=48,
             output_past=output_past,
+            eos_token_id=2,
+            pad_token_id=1,
+            bos_token_id=0,
         )
         return config, input_ids, batch_size

@@ -468,7 +471,8 @@ def test_cnn_summarization_same_as_fairseq_hard(self):
             attention_mask=dct["attention_mask"].to(torch_device),
             num_beams=4,
             length_penalty=2.0,
-            max_length=max_length + 2,
+            # max_length=max_length + 2,
+            max_length=max_length + 1,
             min_length=min_length + 1,
             no_repeat_ngram_size=3,
             do_sample=False,
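
A minimal sketch of the decoder-seeding behaviour this patch introduces. The batch and beam sizes below are illustrative assumptions; only bos_token_id=0 comes from the BartConfig values added in the test diff:

# Illustrative sketch, not part of the patch: mirrors the torch.full(...) call
# that generate() now uses to seed the decoder of an encoder-decoder model.
import torch

bos_token_id = 0          # from the BartConfig values added in the test diff
effective_batch_size = 2  # assumed example value, not from the patch
num_beams = 4             # same beam count as the summarization test

# After the patch, every beam hypothesis starts from BOS instead of EOS:
input_ids = torch.full(
    (effective_batch_size * num_beams, 1),
    bos_token_id,  # previously eos_token_id was passed here
    dtype=torch.long,
)
print(input_ids.shape)  # torch.Size([8, 1]) -- one BOS token per beam

Because _generate_beam_search no longer strips the first token (the decoded[:, 1:] slice is commented out), the returned sequences are one token longer than before, which is why the test now passes max_length + 1 where it previously passed max_length + 2.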