# BERT-pytorch/test_bert.py (forked from codertimo/BERT-pytorch)
import argparse
from torch.utils.data import DataLoader
from bert_pytorch.model import BERT
from bert_pytorch.trainer import BERTTrainer
from bert_pytorch.dataset import BERTDataset, WordVocab
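
# Example invocation (file paths are placeholders, not data shipped with the
# repo; build the vocab first with the repo's bert-vocab command):
#   python test_bert.py -c data/corpus.small -v data/vocab.small -o output/bert.model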
def str2bool(v):
    # argparse's type=bool treats any non-empty string (including "false")
    # as True; parse the flag text explicitly instead.
    return str(v).lower() in ("yes", "true", "t", "1")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("-c", "--train_dataset", required=True,
                        type=str, help="train dataset for training BERT")
    parser.add_argument("-t", "--test_dataset", type=str,
                        default=None, help="test set for evaluating the trained model")
    parser.add_argument("-v", "--vocab_path", required=True,
                        type=str, help="path of the vocab model built with bert-vocab")
    parser.add_argument("-o", "--output_path", required=True,
                        type=str, help="e.g. output/bert.model")

    parser.add_argument("-hs", "--hidden", type=int,
                        default=256, help="hidden size of the transformer model")
    parser.add_argument("-l", "--layers", type=int,
                        default=8, help="number of transformer layers")
    parser.add_argument("-a", "--attn_heads", type=int,
                        default=8, help="number of attention heads")
    parser.add_argument("-s", "--seq_len", type=int,
                        default=20, help="maximum sequence length")

    parser.add_argument("-b", "--batch_size", type=int,
                        default=64, help="batch size")
    parser.add_argument("-e", "--epochs", type=int,
                        default=10, help="number of epochs")
    parser.add_argument("-w", "--num_workers", type=int,
                        default=5, help="dataloader worker count")

    parser.add_argument("--with_cuda", type=str2bool, default=True,
                        help="train with CUDA: true or false")
    parser.add_argument("--log_freq", type=int, default=10,
                        help="print the loss every n iterations")
    parser.add_argument("--corpus_lines", type=int,
                        default=None, help="total number of lines in the corpus")
    parser.add_argument("--cuda_devices", type=int, nargs='+',
                        default=None, help="CUDA device ids")
    parser.add_argument("--on_memory", type=str2bool, default=True,
                        help="load the whole corpus into memory: true or false")

    parser.add_argument("--lr", type=float, default=1e-3,
                        help="learning rate of adam")
    parser.add_argument("--adam_weight_decay", type=float,
                        default=0.01, help="weight decay of adam")
    parser.add_argument("--adam_beta1", type=float,
                        default=0.9, help="adam first beta value")
    parser.add_argument("--adam_beta2", type=float,
                        default=0.999, help="adam second beta value")

    args = parser.parse_args()
print("Loading Vocab", args.vocab_path)
vocab = WordVocab.load_vocab(args.vocab_path)
print("Vocab Size: ", len(vocab))
print("Loading Train Dataset", args.train_dataset)
train_dataset = BERTDataset(args.train_dataset, vocab, seq_len=args.seq_len,
corpus_lines=args.corpus_lines, on_memory=args.on_memory)
print("Loading Test Dataset", args.test_dataset)
test_dataset = BERTDataset(args.test_dataset, vocab, seq_len=args.seq_len, on_memory=args.on_memory) \
if args.test_dataset is not None else None
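
    # Wrap the datasets in PyTorch DataLoaders. Note that no shuffle flag is
    # passed, so batches follow corpus order (DataLoader's default).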
print("Creating Dataloader")
train_data_loader = DataLoader(
train_dataset, batch_size=args.batch_size, num_workers=args.num_workers)
test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=args.num_workers) \
if test_dataset is not None else None
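
    # BERT here is the bare transformer encoder stack; the pre-training
    # heads live in the trainer below.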
print("Building BERT model")
bert = BERT(len(vocab), hidden=args.hidden,
n_layers=args.layers, attn_heads=args.attn_heads)
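
    # BERTTrainer wraps the encoder with the masked-LM and next-sentence
    # prediction heads and optimizes them jointly with Adam (upstream behavior).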
print("Creating BERT Trainer")
trainer = BERTTrainer(bert, len(vocab), train_dataloader=train_data_loader, test_dataloader=test_data_loader,
lr=args.lr, betas=(
args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
with_cuda=args.with_cuda, cuda_devices=args.cuda_devices, log_freq=args.log_freq)
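
    # Train for the requested number of epochs; a checkpoint is written after
    # every epoch (the upstream save() appends the epoch index to output_path).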
print("Training Start")
for epoch in range(args.epochs):
trainer.train(epoch)
trainer.save(epoch, args.output_path)
if test_data_loader is not None:
trainer.test(epoch)