Updated ConversationalPipeline to work with encoder-decoder models #8207

Merged
21 changes: 18 additions & 3 deletions src/transformers/pipelines.py
@@ -2430,18 +2430,31 @@ def __call__(
                 **generate_kwargs,
             )
 
-            cleaned_history = self._clean_padding_history(generated_responses)
+            if self.model.config.is_encoder_decoder:
+                if self.framework == "pt":
+                    history = torch.cat((inputs["input_ids"], generated_responses[:, 1:]), 1)
+                elif self.framework == "tf":
+                    history = tf.concat([inputs["input_ids"], generated_responses[:, 1:]], 1)
+            else:
+                history = generated_responses
+
+            history = self._clean_padding_history(history)
+            if self.model.config.is_encoder_decoder:
+                start_position = 1
+            else:
+                start_position = input_length
 
             output = []
             for conversation_index, conversation in enumerate(conversations):
                 conversation.mark_processed()
                 conversation.generated_responses.append(
                     self.tokenizer.decode(
-                        cleaned_history[conversation_index][input_length:],
+                        generated_responses[conversation_index][start_position:],
                         skip_special_tokens=True,
                         clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                     )
                 )
-                conversation.set_history(cleaned_history[conversation_index])
+                conversation.set_history(history[conversation_index])
                 output.append(conversation)
             if len(output) == 1:
                 return output[0]
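
Note: for encoder-decoder models, generate() prepends the decoder start token to its output, which is why the new code drops the first position (generated_responses[:, 1:] when building the history, start_position = 1 when decoding the reply). A minimal sketch of the concatenation with toy token ids (not the pipeline's real tensors):

import torch

input_ids = torch.tensor([[11, 12, 13]])              # encoded user input (toy ids)
generated_responses = torch.tensor([[1, 21, 22, 2]])  # [decoder start, reply tokens..., eos]

# Strip the decoder start token before appending the reply to the history
history = torch.cat((input_ids, generated_responses[:, 1:]), 1)
print(history)  # expected: tensor([[11, 12, 13, 21, 22, 2]])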
@@ -2475,6 +2488,8 @@ def _clean_padding_history(self, generated_tensor) -> List[List[int]]:
             is_previous_pad = False
             for token in sequence:
                 if token == self.tokenizer.pad_token_id:
+                    if self.tokenizer.pad_token_id != self.tokenizer.eos_token_id:
+                        continue
                     if is_previous_pad:
                         continue
                     else:
Contributor:

is this related to this issue: #8032?

Contributor Author:

I believe it is not. The previous code in the conversational pipeline assumed that eos_token_id is always equal to pad_token_id, and therefore only deleted eos_token_id starting from the 2nd consecutive occurrence (these are the "padding" eos tokens). In general, eos_token_id is not pad_token_id, and in that case the padding needs to be removed from the first pad_token on.

The other issue seems to be related to the generation process itself (creation of incorrect attention masks) - this specific statement is a post-processing step after generation finishes.
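
For illustration, here is a minimal standalone sketch of that cleanup rule (hypothetical helper and toy token ids, not the pipeline's actual method):

from typing import List

def clean_padding(sequence: List[int], pad_token_id: int, eos_token_id: int) -> List[int]:
    # Mirrors the post-processing rule discussed above.
    cleaned, is_previous_pad = [], False
    for token in sequence:
        if token == pad_token_id:
            if pad_token_id != eos_token_id:
                continue  # a true pad token: drop every occurrence
            if is_previous_pad:
                continue  # eos doubles as pad: drop only repeated occurrences
            is_previous_pad = True
        else:
            is_previous_pad = False
        cleaned.append(token)
    return cleaned

print(clean_padding([5, 7, 1, 1, 1], pad_token_id=1, eos_token_id=2))  # [5, 7]
print(clean_padding([5, 7, 0, 0, 0], pad_token_id=0, eos_token_id=0))  # [5, 7, 0]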

54 changes: 51 additions & 3 deletions tests/test_pipelines_conversational.py
@@ -1,6 +1,6 @@
 import unittest
 
-from transformers import Conversation, pipeline
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, Conversation, ConversationalPipeline, pipeline
 from transformers.testing_utils import require_torch, slow, torch_device
 
 from .test_pipelines_common import MonoInputPipelineCommonMixin
@@ -15,8 +15,9 @@ class ConversationalPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
     large_models = ["microsoft/DialoGPT-medium"]  # Models tested with the @slow decorator
     invalid_inputs = ["Hi there!", Conversation()]
 
-    def _test_pipeline(self, nlp):
-        # e overide the default test method to check that the output is a `Conversation` object
+    def _test_pipeline(
+        self, nlp
+    ):  # override the default test method to check that the output is a `Conversation` object
         self.assertIsNotNone(nlp)
 
         # We need to recreate conversation for successive tests to pass as
@@ -95,3 +96,50 @@ def test_integration_torch_conversation_truncated_history(self):
         self.assertEqual(len(result.generated_responses), 2)
         self.assertEqual(result.past_user_inputs[1], "Is it an action movie?")
         self.assertEqual(result.generated_responses[1], "It's a comedy.")
+
+    @require_torch
+    @slow
+    def test_integration_torch_conversation_encoder_decoder(self):
+        # When
+        tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot-90M")
+        model = AutoModelForSeq2SeqLM.from_pretrained("facebook/blenderbot-90M")
+        nlp = ConversationalPipeline(model=model, tokenizer=tokenizer, device=DEFAULT_DEVICE_NUM)
+
+        conversation_1 = Conversation("My name is Sarah and I live in London")
+        conversation_2 = Conversation("Going to the movies tonight, What movie would you recommend? ")
+        # Then
+        self.assertEqual(len(conversation_1.past_user_inputs), 0)
+        self.assertEqual(len(conversation_2.past_user_inputs), 0)
+        # When
+        result = nlp([conversation_1, conversation_2], do_sample=False, max_length=1000)
+        # Then
+        self.assertEqual(result, [conversation_1, conversation_2])
+        self.assertEqual(len(result[0].past_user_inputs), 1)
+        self.assertEqual(len(result[1].past_user_inputs), 1)
+        self.assertEqual(len(result[0].generated_responses), 1)
+        self.assertEqual(len(result[1].generated_responses), 1)
+        self.assertEqual(result[0].past_user_inputs[0], "My name is Sarah and I live in London")
+        self.assertEqual(
+            result[0].generated_responses[0],
+            "hi sarah, i live in london as well. do you have any plans for the weekend?",
+        )
+        self.assertEqual(
+            result[1].past_user_inputs[0], "Going to the movies tonight, What movie would you recommend? "
+        )
+        self.assertEqual(
+            result[1].generated_responses[0], "i don't know... i'm not really sure. what movie are you going to see?"
+        )
+        # When
+        conversation_1.add_user_input("Not yet, what about you?")
+        conversation_2.add_user_input("What's your name?")
+        result = nlp([conversation_1, conversation_2], do_sample=False, max_length=1000)
+        # Then
+        self.assertEqual(result, [conversation_1, conversation_2])
+        self.assertEqual(len(result[0].past_user_inputs), 2)
+        self.assertEqual(len(result[1].past_user_inputs), 2)
+        self.assertEqual(len(result[0].generated_responses), 2)
+        self.assertEqual(len(result[1].generated_responses), 2)
+        self.assertEqual(result[0].past_user_inputs[1], "Not yet, what about you?")
+        self.assertEqual(result[0].generated_responses[1], "i don't have any plans yet. i'm not sure what to do yet.")
+        self.assertEqual(result[1].past_user_inputs[1], "What's your name?")
+        self.assertEqual(result[1].generated_responses[1], "i don't have a name, but i'm going to see a horror movie.")