Skip to content

Commit

Permalink
Add batch inferencing support for GPT2LMHeadModel (#7552)
Browse files Browse the repository at this point in the history
* Add support for gpt2 batch inferencing

* add test

* remove typo

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
  • Loading branch information
cccntu and patrickvonplaten authored Oct 14, 2020
1 parent 0c64b18 commit 121dd43
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 0 deletions.
13 changes: 13 additions & 0 deletions src/transformers/modeling_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,10 +701,23 @@ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)

attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)

if attention_mask is not None and position_ids is None:
# create postion_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
return {
"input_ids": input_ids,
"past_key_values": past,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
}

@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
Expand Down
45 changes: 45 additions & 0 deletions tests/test_modeling_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
GPT2ForSequenceClassification,
GPT2LMHeadModel,
GPT2Model,
GPT2Tokenizer,
)


Expand Down Expand Up @@ -425,6 +426,50 @@ def test_gpt2_gradient_checkpointing(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs(gradient_checkpointing=True)
self.model_tester.create_and_check_forward_and_backwards(*config_and_inputs)

@slow
def test_batch_generation(self):
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.to(torch_device)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

tokenizer.padding_side = "left"

# Define PAD Token = EOS Token = 50256
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id

# use different length sentences to test batching
sentences = [
"Hello, my dog is a little",
"Today, I",
]

inputs = tokenizer(sentences, return_tensors="pt", padding=True)

torch.manual_seed(0)
outputs = model.generate(
input_ids=inputs["input_ids"].to(torch_device),
attention_mask=inputs["attention_mask"].to(torch_device),
)

inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device)
output_non_padded = model.generate(input_ids=inputs_non_padded)

num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().cpu().item()
inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device)
output_padded = model.generate(input_ids=inputs_padded, max_length=model.config.max_length - num_paddings)

batch_out_sentence = tokenizer.batch_decode(outputs, skip_special_tokens=True)
non_padded_sentence = tokenizer.decode(output_non_padded[0], skip_special_tokens=True)
padded_sentence = tokenizer.decode(output_padded[0], skip_special_tokens=True)

expected_output_sentence = [
"Hello, my dog is a little bit of a mess. I'm not sure if he's going",
"Today, I'm going to be doing a lot of research on this. I",
]
self.assertListEqual(expected_output_sentence, batch_out_sentence)
self.assertListEqual(expected_output_sentence, [non_padded_sentence, padded_sentence])

@slow
def test_model_from_pretrained(self):
for model_name in GPT2_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
Expand Down

0 comments on commit 121dd43

Please sign in to comment.