Skip to content

Commit 298d2e8

Browse files
committed
[test] added transformers models to test model zoo
1 parent 1216d1e commit 298d2e8

File tree

12 files changed

+339
-193
lines changed

12 files changed

+339
-193
lines changed

tests/kit/model_zoo/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from . import diffusers, timm
1+
from . import diffusers, timm, transformers
22
from .registry import model_zoo
33

44
__all__ = ['model_zoo']
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# Importing each sub-module runs its `model_zoo.register(...)` calls as an
# import-time side effect; the star-imports additionally re-export the
# sub-modules' public names from this package.
from .albert import *
from .bert import *
from .gpt import *
from .opt import *
from .t5 import *
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import torch
2+
import transformers
3+
4+
from ..registry import ModelAttribute, model_zoo
5+
6+
# ===============================
7+
# Register single-sentence ALBERT
8+
# ===============================
9+
# Dummy-batch dimensions shared by every ALBERT variant registered below.
BATCH_SIZE = 2
SEQ_LENGTH = 16


def data_gen_fn():
    """Build a deterministic all-zeros batch of ALBERT inputs.

    Returns:
        dict with ``input_ids``, ``token_type_ids`` and ``attention_mask``,
        each an int64 tensor of shape ``(BATCH_SIZE, SEQ_LENGTH)``.
    """
    shape = (BATCH_SIZE, SEQ_LENGTH)
    keys = ('input_ids', 'token_type_ids', 'attention_mask')
    return {name: torch.zeros(shape, dtype=torch.int64) for name in keys}
18+
19+
20+
def output_transform_fn(x):
    """Identity output transform for zoo entries that need no munging.

    Written as a ``def`` instead of a lambda assignment (PEP 8, E731) so the
    callable gets a proper ``__name__`` in tracebacks and debug output.
    """
    return x
21+
22+
# Deliberately tiny ALBERT config so the zoo models build and trace quickly
# in tests (real ALBERT defaults are far larger).
config = transformers.AlbertConfig(embedding_size=128,
                                   hidden_size=128,
                                   num_hidden_layers=2,
                                   num_attention_heads=4,
                                   intermediate_size=256)

# Register each single-sentence ALBERT task head under a unique zoo name.
# All variants share the same config and dummy-batch generator; `model_fn`
# is lazy, so no model is instantiated until a test asks for it.
model_zoo.register(name='transformers_albert',
                   model_fn=lambda: transformers.AlbertModel(config),
                   data_gen_fn=data_gen_fn,
                   output_transform_fn=output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_albert_for_pretraining',
                   model_fn=lambda: transformers.AlbertForPreTraining(config),
                   data_gen_fn=data_gen_fn,
                   output_transform_fn=output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_albert_for_masked_lm',
                   model_fn=lambda: transformers.AlbertForMaskedLM(config),
                   data_gen_fn=data_gen_fn,
                   output_transform_fn=output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_albert_for_sequence_classification',
                   model_fn=lambda: transformers.AlbertForSequenceClassification(config),
                   data_gen_fn=data_gen_fn,
                   output_transform_fn=output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_albert_for_token_classification',
                   model_fn=lambda: transformers.AlbertForTokenClassification(config),
                   data_gen_fn=data_gen_fn,
                   output_transform_fn=output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
53+
54+
# ===============================
55+
# Register multi-sentence ALBERT
56+
# ===============================
57+
58+
59+
def data_gen_for_qa():
    """Tokenize a question/context pair for AlbertForQuestionAnswering.

    NOTE(review): downloads the "bert-base-uncased" tokenizer from the
    HuggingFace hub on first call, so this generator needs network access.
    It also uses a BERT (not ALBERT) tokenizer — works mechanically, but the
    vocabulary may not match ALBERT's; confirm this is intentional.
    """
    question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
    tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
    inputs = tokenizer(question, text, return_tensors="pt")
    return inputs
64+
65+
66+
def data_gen_for_mcq():
    """Tokenize a (prompt, two-choice) example for AlbertForMultipleChoice.

    NOTE(review): downloads the "bert-base-uncased" tokenizer from the
    HuggingFace hub on first call (network access required) and reuses a BERT
    tokenizer for an ALBERT model — confirm this is intentional.
    """
    prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
    choice0 = "It is eaten with a fork and a knife."
    choice1 = "It is eaten while held in the hand."
    tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
    encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
    # unsqueeze(0) prepends a batch dimension so each tensor becomes
    # (1, num_choices, seq_len), the layout multiple-choice heads expect.
    encoding = {k: v.unsqueeze(0) for k, v in encoding.items()}
    return encoding
74+
75+
76+
# Register the multi-sentence ALBERT heads; these use the tokenizer-backed
# generators above instead of the all-zeros batch.
model_zoo.register(name='transformers_albert_for_question_answering',
                   model_fn=lambda: transformers.AlbertForQuestionAnswering(config),
                   data_gen_fn=data_gen_for_qa,
                   output_transform_fn=output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_albert_for_multiple_choice',
                   model_fn=lambda: transformers.AlbertForMultipleChoice(config),
                   data_gen_fn=data_gen_for_mcq,
                   output_transform_fn=output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import torch
2+
import transformers
3+
4+
from ..registry import ModelAttribute, model_zoo
5+
6+
# ===============================
7+
# Register single-sentence BERT
8+
# ===============================
9+
# Dummy-batch dimensions shared by every BERT variant registered below.
BATCH_SIZE = 2
SEQ_LENGTH = 16


def data_gen_fn():
    """Build a deterministic all-zeros batch of BERT inputs.

    Returns:
        dict with ``input_ids``, ``token_type_ids`` and ``attention_mask``,
        each an int64 tensor of shape ``(BATCH_SIZE, SEQ_LENGTH)``.
    """
    shape = (BATCH_SIZE, SEQ_LENGTH)
    keys = ('input_ids', 'token_type_ids', 'attention_mask')
    return {name: torch.zeros(shape, dtype=torch.int64) for name in keys}
18+
19+
20+
def output_transform_fn(x):
    """Identity output transform for zoo entries that need no munging.

    A ``def`` instead of a lambda assignment (PEP 8, E731) so the callable
    has a real ``__name__`` for debugging.
    """
    return x
21+
22+
# Deliberately tiny BERT config so the zoo models build and trace quickly.
config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)

# register the BERT variants
# Each single-sentence task head shares the tiny config and the all-zeros
# batch generator; `model_fn` is lazy, so nothing is built until requested.
model_zoo.register(name='transformers_bert',
                   model_fn=lambda: transformers.BertModel(config),
                   data_gen_fn=data_gen_fn,
                   output_transform_fn=output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_pretraining',
                   model_fn=lambda: transformers.BertForPreTraining(config),
                   data_gen_fn=data_gen_fn,
                   output_transform_fn=output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_lm_head_model',
                   model_fn=lambda: transformers.BertLMHeadModel(config),
                   data_gen_fn=data_gen_fn,
                   output_transform_fn=output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_masked_lm',
                   model_fn=lambda: transformers.BertForMaskedLM(config),
                   data_gen_fn=data_gen_fn,
                   output_transform_fn=output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_sequence_classification',
                   model_fn=lambda: transformers.BertForSequenceClassification(config),
                   data_gen_fn=data_gen_fn,
                   output_transform_fn=output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_token_classification',
                   model_fn=lambda: transformers.BertForTokenClassification(config),
                   data_gen_fn=data_gen_fn,
                   output_transform_fn=output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
55+
56+
57+
# ===============================
58+
# Register multi-sentence BERT
59+
# ===============================
60+
def data_gen_for_next_sentence():
    """Tokenize a sentence pair for BertForNextSentencePrediction.

    NOTE(review): downloads the "bert-base-uncased" tokenizer from the
    HuggingFace hub on first call — network access required.
    """
    tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
    prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
    next_sentence = "The sky is blue due to the shorter wavelength of blue light."
    encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
    return encoding
66+
67+
68+
def data_gen_for_mcq():
    """Tokenize a (prompt, two-choice) example for BertForMultipleChoice.

    NOTE(review): downloads the "bert-base-uncased" tokenizer from the
    HuggingFace hub on first call — network access required.
    """
    tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
    prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
    choice0 = "It is eaten with a fork and a knife."
    choice1 = "It is eaten while held in the hand."
    encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
    # unsqueeze(0) prepends a batch dimension so each tensor becomes
    # (1, num_choices, seq_len), the layout multiple-choice heads expect.
    encoding = {k: v.unsqueeze(0) for k, v in encoding.items()}
    return encoding
76+
77+
78+
# register the following models
# Multi-sentence BERT heads; these use the tokenizer-backed generators above
# instead of the all-zeros batch.
model_zoo.register(name='transformers_bert_for_next_sentence',
                   model_fn=lambda: transformers.BertForNextSentencePrediction(config),
                   data_gen_fn=data_gen_for_next_sentence,
                   output_transform_fn=output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_bert_for_mcq',
                   model_fn=lambda: transformers.BertForMultipleChoice(config),
                   data_gen_fn=data_gen_for_mcq,
                   output_transform_fn=output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import torch
2+
import transformers
3+
4+
from ..registry import ModelAttribute, model_zoo
5+
6+
# ===============================
7+
# Register single-sentence GPT
8+
# ===============================
9+
# Dummy-batch dimensions shared by every GPT-2 variant registered below.
BATCH_SIZE = 2
SEQ_LENGTH = 16


def data_gen():
    """Build a deterministic all-zeros batch of GPT-2 inputs.

    Returns:
        dict with ``input_ids``, ``token_type_ids`` and ``attention_mask``,
        each an int64 tensor of shape ``(BATCH_SIZE, SEQ_LENGTH)``.
    """
    shape = (BATCH_SIZE, SEQ_LENGTH)
    keys = ('input_ids', 'token_type_ids', 'attention_mask')
    return {name: torch.zeros(shape, dtype=torch.int64) for name in keys}
18+
19+
20+
def output_transform_fn(x):
    """Identity output transform for zoo entries that need no munging.

    A ``def`` instead of a lambda assignment (PEP 8, E731) so the callable
    has a real ``__name__`` for debugging.
    """
    return x
21+
22+
# Tiny GPT-2 config for fast test builds.
# FIX: the original passed `n_position=64` — a typo. GPT2Config silently
# accepts unknown kwargs as attributes, so `n_positions` (the max sequence
# length) kept its 1024 default and the intended cap was never applied.
config = transformers.GPT2Config(n_positions=64, n_layer=2, n_head=4)
23+
24+
# register the following models
# Each GPT-2 task head shares the tiny config and the all-zeros batch
# generator; `model_fn` is lazy, so nothing is built until requested.
model_zoo.register(name='transformers_gpt',
                   model_fn=lambda: transformers.GPT2Model(config),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_lm',
                   model_fn=lambda: transformers.GPT2LMHeadModel(config),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_double_heads',
                   model_fn=lambda: transformers.GPT2DoubleHeadsModel(config),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_gpt_for_token_classification',
                   model_fn=lambda: transformers.GPT2ForTokenClassification(config),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
# NOTE(review): GPT2ForSequenceClassification locates the last non-padding
# token via `config.pad_token_id`, which this config leaves unset; with a
# batch size > 1 the forward pass may refuse to run — confirm the zoo
# consumers handle or expect this.
model_zoo.register(name='transformers_gpt_for_sequence_classification',
                   model_fn=lambda: transformers.GPT2ForSequenceClassification(config),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import torch
2+
import transformers
3+
4+
from ..registry import ModelAttribute, model_zoo
5+
6+
# ===============================
7+
# Register single-sentence OPT
8+
# ===============================
9+
# Dummy-batch dimensions shared by both OPT variants registered below.
BATCH_SIZE = 2
SEQ_LENGTH = 16


def data_gen():
    """Build a deterministic all-zeros batch of OPT inputs.

    Returns:
        dict with ``input_ids`` and ``attention_mask``, each an int64 tensor
        of shape ``(BATCH_SIZE, SEQ_LENGTH)``. OPT takes no token type ids.
    """
    shape = (BATCH_SIZE, SEQ_LENGTH)
    return {
        'input_ids': torch.zeros(shape, dtype=torch.int64),
        'attention_mask': torch.zeros(shape, dtype=torch.int64),
    }
17+
18+
19+
def output_transform_fn(x):
    """Identity output transform for zoo entries that need no munging.

    A ``def`` instead of a lambda assignment (PEP 8, E731) so the callable
    has a real ``__name__`` for debugging.
    """
    return x
20+
21+
# Deliberately tiny OPT config so the zoo models build quickly in tests.
config = transformers.OPTConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4)

# register the following models
# transformers.OPTModel,
# transformers.OPTForCausalLM,
# Both share the tiny config and the all-zeros batch generator; `model_fn`
# is lazy, so nothing is built until a test requests it.
model_zoo.register(name='transformers_opt',
                   model_fn=lambda: transformers.OPTModel(config),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_opt_for_causal_lm',
                   model_fn=lambda: transformers.OPTForCausalLM(config),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import torch
2+
import transformers
3+
4+
from ..registry import ModelAttribute, model_zoo
5+
6+
# ===============================
7+
# Register single-sentence T5
8+
# ===============================
9+
# Dummy-batch dimensions shared by every T5 variant registered below.
BATCH_SIZE = 2
SEQ_LENGTH = 16


def data_gen():
    """Build a deterministic all-zeros batch for encoder-decoder T5 models.

    Returns:
        dict with ``input_ids`` and ``decoder_input_ids``, each an int64
        tensor of shape ``(BATCH_SIZE, SEQ_LENGTH)``.
    """
    shape = (BATCH_SIZE, SEQ_LENGTH)
    return {
        'input_ids': torch.zeros(shape, dtype=torch.int64),
        'decoder_input_ids': torch.zeros(shape, dtype=torch.int64),
    }
17+
18+
19+
def data_gen_for_encoder_only():
    """Build an all-zeros batch for T5EncoderModel.

    Only ``input_ids`` is produced — the encoder-only model takes no
    ``decoder_input_ids``.
    """
    return {'input_ids': torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)}
22+
23+
24+
def output_transform_fn(x):
    """Identity output transform for zoo entries that need no munging.

    A ``def`` instead of a lambda assignment (PEP 8, E731) so the callable
    has a real ``__name__`` for debugging.
    """
    return x
25+
26+
# Deliberately tiny T5 config so the zoo models build quickly in tests.
config = transformers.T5Config(d_model=128, num_layers=2)

# register the following models
# transformers.T5Model,
# transformers.T5ForConditionalGeneration,
# transformers.T5EncoderModel,
# The encoder-decoder variants use the full (input_ids + decoder_input_ids)
# generator; `model_fn` is lazy, so nothing is built until requested.
model_zoo.register(name='transformers_t5',
                   model_fn=lambda: transformers.T5Model(config),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_t5_for_conditional_generation',
                   model_fn=lambda: transformers.T5ForConditionalGeneration(config),
                   data_gen_fn=data_gen,
                   output_transform_fn=output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
# T5EncoderModel has no decoder, so it gets the encoder-only generator
# (input_ids only).
model_zoo.register(name='transformers_t5_encoder_model',
                   model_fn=lambda: transformers.T5EncoderModel(config),
                   data_gen_fn=data_gen_for_encoder_only,
                   output_transform_fn=output_transform_fn,
                   model_attribute=ModelAttribute(has_control_flow=True))
Lines changed: 8 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,18 @@
1-
import pytest
2-
import torch
3-
import transformers
41
from hf_tracer_utils import trace_model_and_compare_output
52

3+
from tests.kit.model_zoo import model_zoo
4+
65
BATCH_SIZE = 2
76
SEQ_LENGTH = 16
87

98

10-
def test_single_sentence_albert():
11-
MODEL_LIST = [
12-
transformers.AlbertModel,
13-
transformers.AlbertForPreTraining,
14-
transformers.AlbertForMaskedLM,
15-
transformers.AlbertForSequenceClassification,
16-
transformers.AlbertForTokenClassification,
17-
]
18-
19-
config = transformers.AlbertConfig(embedding_size=128,
20-
hidden_size=128,
21-
num_hidden_layers=2,
22-
num_attention_heads=4,
23-
intermediate_size=256)
24-
25-
def data_gen():
26-
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
27-
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
28-
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
29-
meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
30-
return meta_args
31-
32-
for model_cls in MODEL_LIST:
33-
model = model_cls(config=config)
34-
trace_model_and_compare_output(model, data_gen)
35-
36-
37-
def test_multi_sentence_albert():
38-
config = transformers.AlbertConfig(hidden_size=128,
39-
num_hidden_layers=2,
40-
num_attention_heads=4,
41-
intermediate_size=256)
42-
tokenizer = transformers.BertTokenizer.from_pretrained("bert-base-uncased")
43-
44-
def data_gen_for_qa():
45-
question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
46-
inputs = tokenizer(question, text, return_tensors="pt")
47-
return inputs
48-
49-
model = transformers.AlbertForQuestionAnswering(config)
50-
trace_model_and_compare_output(model, data_gen_for_qa)
51-
52-
def data_gen_for_mcq():
53-
prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
54-
choice0 = "It is eaten with a fork and a knife."
55-
choice1 = "It is eaten while held in the hand."
56-
encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
57-
encoding = {k: v.unsqueeze(0) for k, v in encoding.items()}
58-
return encoding
9+
def test_albert():
10+
sub_registry = model_zoo.get_sub_registry('transformers_albert')
5911

60-
model = transformers.AlbertForMultipleChoice(config)
61-
trace_model_and_compare_output(model, data_gen_for_mcq)
12+
for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
13+
model = model_fn()
14+
trace_model_and_compare_output(model, data_gen_fn)
6215

6316

6417
if __name__ == '__main__':
65-
test_single_sentence_albert()
66-
test_multi_sentence_albert()
18+
test_albert()

0 commit comments

Comments
 (0)