Commit 121dd43 changes DialoGPT generation behavior #8032

Closed
abisee opened this issue Oct 25, 2020 · 5 comments

abisee commented Oct 25, 2020

Environment info

  • transformers version: 3.3.1
  • Platform: Linux-4.4.0-127-generic-x86_64-with-debian-stretch-sid
  • Python version: 3.7.3
  • PyTorch version (GPU?): 1.6.0+cu101 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Using GPU in script?: yes (1 TITAN-XP)
  • Using distributed or parallel set-up in script?: no

Who can help

@cccntu @patrickvonplaten @LysandreJik

Information

Model I am using (Bert, XLNet ...): DialoGPT-large

The problem arises when using:

  • the official example scripts: (give details below)
  • my own modified scripts: (give details below)

The tasks I am working on is:

  • an official GLUE/SQUaD task: (give the name)
  • my own task or dataset: (give details below)

To reproduce

Steps to reproduce the behavior:

  1. Check out commit 121dd43.

  2. Run the DialoGPT "How to use" code given here, but change DialoGPT-medium to DialoGPT-large:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch


tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-large")

# Let's chat for 5 lines
for step in range(5):
    # encode the new user input, add the eos_token and return a tensor in PyTorch
    new_user_input_ids = tokenizer.encode(input(">> User:") + tokenizer.eos_token, return_tensors='pt')

    # append the new user input tokens to the chat history
    bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids

    # generate a response while limiting the total chat history to 1000 tokens
    chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)

    # pretty print the last output tokens from the bot
    print("DialoGPT: {}".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))

  3. For the user's first utterance, type "Hello, how are you?". I get this output:
>> User:Hello, how are you?
DialoGPT: 're you a fan of the show?

Note: this problem is still present in the current version of master (5148f43).

Expected behavior

With the previous commit, 0c64b18, I get this output:

>> User:Hello, how are you?
DialoGPT: I'm good, you?

Possible cause

The issue seems to be related to the <|endoftext|> token, which is used at the end of every utterance. This is being regarded as a padding token, and thus it's attention-masked, which also seems to affect the position ids.
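
To make the suspected mechanism concrete, here is a minimal sketch of my own (reusing only the model name from the reproduction above); it rebuilds the mask the same way generate() does when pad_token_id equals eos_token_id:

# Minimal sketch (mine, not part of the original report): rebuild the attention
# mask the same way generate() does when pad_token_id == eos_token_id.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-large")
input_ids = tokenizer.encode("Hello, how are you?" + tokenizer.eos_token, return_tensors="pt")

# Same logic as in generate(): every position equal to pad_token_id is masked out.
attention_mask = input_ids.ne(tokenizer.eos_token_id).long()
print(attention_mask)  # the final position (the <|endoftext|> token) is 0, i.e. ignored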


cccntu commented Oct 26, 2020

Hi @abisee, sorry for the inconvenience.

Even though you did not pass in an attention mask, one is created here (first two lines):

if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
    attention_mask = input_ids.ne(pad_token_id).long()
elif attention_mask is None:
    attention_mask = input_ids.new_ones(input_ids.shape)
# set pad_token_id to eos_token_id if not set. Important that this is done after
# attention_mask is created
if pad_token_id is None and eos_token_id is not None:
    logger.warning(
        "Setting `pad_token_id` to {} (first `eos_token_id`) to generate sequence".format(eos_token_id)
    )
    pad_token_id = eos_token_id

Changing this

chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)

to

chat_history_ids = model.generate(bot_input_ids, max_length=1000)

seems to solve the problem. (pad_token_id will still be set to tokenizer.eos_token_id, but only after attention_mask has been set to all ones.)
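
An alternative workaround (my own suggestion, not verified against this exact commit) is to keep pad_token_id=tokenizer.eos_token_id but hand generate() an explicit all-ones mask, continuing from the names in the reproduction snippet above:

# Hedged alternative: build the all-ones mask yourself so the trailing
# <|endoftext|> token is not treated as padding. `bot_input_ids`, `model` and
# `tokenizer` are the variables from the reproduction script above.
import torch

attention_mask = torch.ones_like(bot_input_ids)
chat_history_ids = model.generate(
    bot_input_ids,
    attention_mask=attention_mask,
    max_length=1000,
    pad_token_id=tokenizer.eos_token_id,
)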

Here is how the bug can happen. It occurs if someone:

  • uses eos_token_id inside the input sentences,
  • also sets pad_token_id=eos_token_id,
  • and lets the attention mask be created as above (from the positions of pad_token_id). There is no problem when the tokenizer is used to create the attention mask.

I don't have a better solution for now; I'll think about it.
@patrickvonplaten @LysandreJik What do you think?
Maybe generate() should not create the attention mask for users, but that could break other code, too.


abisee commented Oct 27, 2020

Thanks for the response @cccntu!

My understanding is that both GPT2 and DialoGPT were trained without a pad token; i.e. neither model has a pad token embedding. In that case, why does the DialoGPT example code contain pad_token_id=tokenizer.eos_token_id? What's the purpose of doing this, if a pad token isn't needed for generation, and the EOS token was never used as a pad token during training?


abisee commented Oct 27, 2020

For generation, it seems that attention masks are created automatically (if there's an assigned pad token that appears in the input). See GenerationMixin.generate():

# create attention mask if necessary
        # TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
        if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
            attention_mask = input_ids.ne(pad_token_id).long()
        elif attention_mask is None:
            attention_mask = input_ids.new_ones(input_ids.shape)

However, for training (at least for GPT2 models), as far as I can tell, the attention mask is not created automatically, even if there's an assigned pad token that appears in the input.

This seems like an unexpected discrepancy, and another reason to put the attention mask creation in the model's forward as proposed by @thomwolf in PR 3140.
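
To illustrate the discrepancy, a small sketch of my own (assuming only the standard GPT2LMHeadModel forward signature): the training-style forward pass attends to every position, pad tokens included, unless an attention_mask is passed explicitly.

# Illustration (mine, not from the library docs): the forward pass does not infer
# an attention mask from pad_token_id, so padded positions are attended to unless
# a mask is passed in explicitly.
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token
model = GPT2LMHeadModel.from_pretrained("gpt2")

batch = tokenizer(["Hello, how are you?", "Hi"], padding=True, return_tensors="pt")

# No mask is created for us here, even though pad tokens are present in input_ids:
out_unmasked = model(batch["input_ids"], labels=batch["input_ids"])
# The mask has to be supplied explicitly for the padding to be ignored:
out_masked = model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["input_ids"])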


patrickvonplaten commented Oct 27, 2020

That's a super interesting issue! Thanks for posting it here!

So in short: in order to be able to do batch generation with GPT2 (or beam search), we have to use some kind of token as the pad_token_id in case one sequence in the batch finishes early. We decided a while back that for GPT2 we would just use the eos_token_id as the pad_token_id in this case.

Just as you noticed, the problem lies in generate() automatically creating the attention_mask and falsely assuming the eos_token_id is a pad_token_id.

IMO, it was a mistake to automatically create the attention_mask in generate(), as it can lead to unexpected problems such as this one!

I'm currently doing a big generate() refactor and in this refactor the problem should be solved (see comment in PR linked below).

I hope that I'll be able to merge the PR in ~1 week.
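
For context, a hedged sketch (my own illustration, not the refactor itself) of the situation described above: with beam search returning several hypotheses, beams that hit EOS early are filled out with pad_token_id, which for GPT2 is conventionally the EOS token; passing the attention mask explicitly sidesteps the pitfall from this issue.

# Illustration (mine): why generate() needs some pad_token_id for GPT-2 at all.
# Returned beams that finish early are padded out to the longest sequence.
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer.encode("Hello, how are you?", return_tensors="pt")
outputs = model.generate(
    input_ids,
    max_length=30,
    num_beams=4,
    num_return_sequences=4,
    pad_token_id=tokenizer.eos_token_id,        # fills finished beams past their EOS
    attention_mask=torch.ones_like(input_ids),  # explicit mask avoids the masking pitfall
)
for seq in outputs:
    print(tokenizer.decode(seq, skip_special_tokens=True))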


stale bot commented Jan 2, 2021

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale bot added the wontfix label Jan 2, 2021
stale bot closed this as completed Jan 10, 2021