Add codegen unittests #3348

Merged
merged 4 commits into from
Sep 29, 2022
Changes from 1 commit
fix codegen
FrostML committed Sep 28, 2022
commit 648c01d6af79d2bbcefafc84f4bd03bd50fb4243
30 changes: 19 additions & 11 deletions paddlenlp/transformers/codegen/modeling.py
@@ -42,7 +42,7 @@ def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
if seq_len is None:
seq_len = x.shape[seq_dim]
inv_freq = 1.0 / (10000**(paddle.arange(0, dim, 2) / dim))
sinusoid_inp = (paddle.einsum("i , j -> i j",
sinusoid_inp = (paddle.einsum("i,j->ij",
paddle.arange(seq_len, dtype="float32"),
inv_freq))
return paddle.sin(sinusoid_inp), paddle.cos(sinusoid_inp)
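
Note: the whitespace-free einsum spec computes the same outer product as before. A minimal, self-contained sketch of the sinusoid-table construction (the dimension sizes below are illustrative, not taken from the model config):

import paddle

# "i,j->ij" is an outer product: table[i, j] = pos[i] * inv_freq[j].
pos = paddle.arange(6, dtype="float32")  # positions (i)
inv_freq = 1.0 / (10000 ** (paddle.arange(0, 4, 2, dtype="float32") / 4))  # frequencies (j)
table = paddle.einsum("i,j->ij", pos, inv_freq)  # shape [6, 2]
# The same table via plain broadcasting:
same = pos.unsqueeze(1) * inv_freq.unsqueeze(0)
assert bool(paddle.allclose(table, same))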
@@ -74,13 +74,10 @@ def __init__(self, embed_dim, rotary_dim, num_attention_heads,
max_positions, attn_pdrop, resid_pdrop):
super().__init__()

self.register_buffer(
"causal_mask",
paddle.tril(
paddle.ones((max_positions, max_positions),
dtype=paddle.get_default_dtype())).reshape(
(1, 1, max_positions, max_positions)),
)
self.causal_mask = paddle.tril(
paddle.ones((max_positions, max_positions),
dtype=paddle.get_default_dtype())).reshape(
(1, 1, max_positions, max_positions))

self.attn_dropout = nn.Dropout(attn_pdrop)
self.resid_dropout = nn.Dropout(resid_pdrop)
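
Note: the change above keeps the causal mask as a plain attribute instead of a registered buffer; its content is unchanged. A small sketch of how a precomputed lower-triangular mask of this shape is typically sliced during attention (the slicing pattern is illustrative, not copied from this file):

import paddle

max_positions = 8
causal_mask = paddle.tril(
    paddle.ones((max_positions, max_positions),
                dtype=paddle.get_default_dtype())).reshape(
                    (1, 1, max_positions, max_positions))

# q_len new query tokens attending over k_len total key positions
# (including cached ones); values are 1 where attention is allowed.
q_len, k_len = 2, 5
mask = causal_mask[:, :, k_len - q_len:k_len, :k_len]  # shape [1, 1, 2, 5]
print(mask.shape)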
@@ -490,9 +487,9 @@ def forward(
"specified when generating attention_mask"
if batch_size == 1 and past_length != 0:
batch_size, seq_len = input_shape
attention_mask = paddle.ones(
[batch_size, 1, 1, seq_len + past_length],
dtype=paddle.get_default_dtype())
attention_mask = (
1.0 - paddle.ones([batch_size, 1, 1, seq_len + past_length],
dtype=paddle.get_default_dtype())) * -1e4
else:
attention_mask = paddle.cast(
input_ids == self.pad_token_id,
@@ -503,6 +500,12 @@
attention_mask, axis=[1, 2]).astype(paddle.get_default_dtype())
attention_mask = (1.0 - attention_mask) * -1e4
attention_mask.stop_gradient = True
# TODO: The CodeGen attention mask handling is confusing.
# A 2D mask must be an int tensor denoted by 1/0.
# When calling model.generate() without an attention mask, or when passing
# a 4D mask, the mask must be a float tensor denoted by 0/-inf (additive).
# 3D attention masks are not supported.

inputs_embeds = self.wte(input_ids)
if token_type_ids is not None:
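
To make the TODO above concrete, a minimal sketch of the two conventions it describes (pad_token_id and the example ids are made up):

import paddle

pad_token_id = 0
input_ids = paddle.to_tensor([[5, 6, 7, 0, 0]])

# 2D convention: integer mask, 1 = attend, 0 = padding.
mask_2d = (input_ids != pad_token_id).astype("int64")

# 4D additive convention: float mask broadcastable to
# [batch, num_heads, query_len, key_len]; 0.0 = attend, -1e4 (~ -inf) = masked.
mask_4d = paddle.unsqueeze(mask_2d, axis=[1, 2]).astype(
    paddle.get_default_dtype())
mask_4d = (1.0 - mask_4d) * -1e4
print(mask_4d)  # 0.0 at real tokens, -1e4 at the two padding positions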
@@ -587,8 +590,12 @@ def prepare_faster_entry(self, kwargs):

def prepare_inputs_for_generation(self, input_ids, cache=None, **kwargs):
# only keep the last token of input_ids when a cache is passed
token_type_ids = kwargs.get("token_type_ids", None)

if cache:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
@@ -600,6 +607,7 @@ def prepare_inputs_for_generation(self, input_ids, cache=None, **kwargs):
"cache": cache,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}

def forward(self,
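A self-contained sketch of the cached-step slicing added to prepare_inputs_for_generation above (the cache object is only a stand-in for whatever the previous forward pass returned):

import paddle

input_ids = paddle.to_tensor([[10, 11, 12, 13]])
token_type_ids = paddle.zeros_like(input_ids)
cache = object()  # stand-in for a real cache from a previous decoding step

# Once a cache exists, only the last position of input_ids (and, with this
# commit, of token_type_ids) is fed to the next forward pass.
if cache:
    input_ids = input_ids[:, -1].unsqueeze(-1)
    if token_type_ids is not None:
        token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

print(input_ids.shape, token_type_ids.shape)  # [1, 1] [1, 1]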
66 changes: 51 additions & 15 deletions tests/transformers/codegen/test_modeling.py
@@ -15,6 +15,8 @@

import datetime
import unittest
import numpy as np
import random

import paddle
from paddlenlp.transformers import (CODEGEN_PRETRAINED_MODEL_ARCHIVE_LIST,
@@ -78,6 +80,10 @@ def __init__(
self.eos_token_id = vocab_size - 1
self.pad_token_id = vocab_size - 1

paddle.seed(128)
np.random.seed(128)
random.seed(128)

def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length],
self.vocab_size,
@@ -382,6 +388,31 @@ class CodeGenModelTest(ModelTesterMixin, GenerationTesterMixin,
test_model_parallel = False
test_head_masking = False

# Overridden because of the attention mask issue noted in modeling.py:
# generate() here is exercised with a float, 4D (additive) attention mask.
def _get_input_ids_and_config(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common(
)

input_ids = inputs_dict[self.input_name]
attention_mask = paddle.zeros_like(input_ids, dtype=paddle.float32)

max_batch_size = 2
sequence_length = input_ids.shape[-1] // 2
input_ids = input_ids[:max_batch_size, :sequence_length]
attention_mask = attention_mask[:max_batch_size, :
sequence_length].unsqueeze([1, 2])

# generate max 3 tokens
max_length = 3

if config.get(
"eos_token_id",
None) is not None and config.get("pad_token_id", None) is None:
# hack to allow generate for models such as GPT2 as is done in `generate()`
config["pad_token_id"] = config["eos_token_id"]

return config, input_ids, attention_mask, max_length
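
The override returns an all-zeros float mask already in the additive 4D form, presumably to satisfy the float/4D requirement noted in the modeling TODO; adding such a mask to the attention scores is a no-op, i.e. nothing is masked. A tiny sketch (sizes are illustrative):

import paddle

scores = paddle.randn([2, 1, 3, 3])  # [batch, heads, q, k]
attention_mask = paddle.zeros([2, 1, 1, 3], dtype="float32")
assert bool(paddle.allclose(scores + attention_mask, scores))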

# special case for DoubleHeads model
def _prepare_for_class(self, inputs_dict, model_class):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class)
@@ -415,8 +446,7 @@ def test_codegen_lm_head_model(self):

@slow
def test_batch_generation(self):
# tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
tokenizer = CodeGenTokenizer.from_pretrained(
tokenizer = AutoTokenizer.from_pretrained(
"Salesforce/codegen-350M-mono")
model = CodeGenForCausalLM.from_pretrained(
"Salesforce/codegen-350M-mono")
@@ -477,13 +507,17 @@ def test_model_from_pretrained(self):
def test_model_name_list(self):
pass

@slow
def test_auto_tokenizer(self):
for model_name in CODEGEN_PRETRAINED_MODEL_ARCHIVE_LIST:
tokenizer = AutoTokenizer.from_pretrained(model_name)


class CodeGenModelLanguageGenerationTest(unittest.TestCase):

@slow
def test_lm_generate_codegen(self):
# tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
tokenizer = CodeGenTokenizer.from_pretrained(
tokenizer = AutoTokenizer.from_pretrained(
"Salesforce/codegen-350M-mono")
model = CodeGenForCausalLM.from_pretrained(
"Salesforce/codegen-350M-mono")
@@ -504,15 +538,12 @@

@slow
def test_codegen_sample(self):
# NOTE: Only codegen-350M-mono supports AutoTokenizer.
tokenizer = AutoTokenizer.from_pretrained(
"Salesforce/codegen-350M-mono")
# tokenizer = CodeGenTokenizer.from_pretrained("Salesforce/codegen-350M-mono")
model = CodeGenForCausalLM.from_pretrained(
"Salesforce/codegen-350M-mono")
model.eval()

# NOTE: PaddleNLP do not support token_type_ids.
tokenized = tokenizer("def hello_world():",
return_tensors="pd",
return_token_type_ids=True,
@@ -523,21 +554,26 @@
top_k=1)
output_str = tokenizer.decode(output_ids[0], skip_special_tokens=True)

# token_type_ids = tokenized.token_type_ids
token_type_ids = tokenized.token_type_ids
output_seq, _ = model.generate(input_ids=input_ids,
decode_strategy="sampling",
top_k=1,
num_return_sequences=5)
# output_seq_tt, _ = model.generate(
# input_ids=input_ids, token_type_ids=token_type_ids, decode_strategy="sampling", top_k=1, num_return_sequences=5
# )
output_seq_tt, _ = model.generate(input_ids=input_ids,
token_type_ids=token_type_ids,
decode_strategy="sampling",
top_k=1,
num_return_sequences=5)
output_seq_strs = tokenizer.batch_decode(output_seq,
skip_special_tokens=True)
# output_seq_tt_strs = tokenizer.batch_decode(output_seq_tt, skip_special_tokens=True)
output_seq_tt_strs = tokenizer.batch_decode(output_seq_tt,
skip_special_tokens=True)

EXPECTED_OUTPUT_STR = '\n print("Hello World")\n\nhello_world()\n\n#'

self.assertEqual(output_str, EXPECTED_OUTPUT_STR)
# self.assertTrue(
# all([output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs))])
# ) # token_type_ids should change output
self.assertTrue(
all([
output_seq_strs[idx] != output_seq_tt_strs[idx]
for idx in range(len(output_seq_tt_strs))
])) # token_type_ids should change output