trainer.py
import argparse
import math
import os
import pickle
import sys
from datetime import datetime

import datasets
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

import utils


def train(model_name, task, dataset_name, num_epochs, column_name, log_dir):
    # Resolve the device; Trainer handles the actual model placement itself.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # dataset_name may hold several comma-separated values, e.g.
    # 'wikitext,wikitext-2-raw-v1' -> load_dataset('wikitext', 'wikitext-2-raw-v1')
    dataset = load_dataset(*dataset_name.split(','))

    # Make sure a validation split exists: reuse the test split if there is one,
    # otherwise carve 10% out of the train split.
    if 'validation' not in dataset:
        if 'test' not in dataset:
            train_test_split = dataset['train'].train_test_split(test_size=0.1)
            dataset = datasets.DatasetDict({
                'train': train_test_split['train'],
                'validation': train_test_split['test']})
        else:
            dataset = datasets.DatasetDict({
                'train': dataset['train'],
                'validation': dataset['test']})

    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    tokenized_datasets = dataset.map(lambda examples: tokenizer(examples[column_name]),
                                     batched=True, num_proc=2, remove_columns=[column_name])
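
    # utils.group_texts is not defined in this file. It is presumably the usual
    # HF language-modelling helper that concatenates the tokenized texts of a
    # batch and re-splits them into fixed-size blocks, copying input_ids to
    # labels. A minimal sketch of such a helper (the name, signature and block
    # size are assumptions, not this repo's code):
    #
    #   def group_texts(examples, block_size=128):
    #       concatenated = {k: sum(examples[k], []) for k in examples.keys()}
    #       total_length = (len(concatenated['input_ids']) // block_size) * block_size
    #       result = {k: [v[i:i + block_size] for i in range(0, total_length, block_size)]
    #                 for k, v in concatenated.items()}
    #       result['labels'] = result['input_ids'].copy()
    #       return result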
    lm_datasets = tokenized_datasets.map(
        utils.group_texts,
        batched=True,
        batch_size=1000,
        num_proc=2
    )
    training_args = TrainingArguments(
        "test-clm",
        evaluation_strategy="epoch",
        learning_rate=2e-5,
        weight_decay=0.01,
        num_train_epochs=num_epochs,
        logging_dir=log_dir,
        logging_strategy='epoch'
    )

    if task == 'causal':
        model = AutoModelForCausalLM.from_pretrained(model_name)
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=lm_datasets["train"],
            eval_dataset=lm_datasets["validation"]
        )
    elif task == 'MLM':
        model = AutoModelForMaskedLM.from_pretrained(model_name)
        # Dynamically mask 15% of the tokens for masked language modelling.
        data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=lm_datasets["train"],
            eval_dataset=lm_datasets["validation"],
            data_collator=data_collator
        )
    else:
        raise ValueError("task must be either 'causal' or 'MLM', got {!r}".format(task))

    metrics = trainer.train()
    eval_results = trainer.evaluate()
    print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")
    return model, tokenizer, metrics, eval_results
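
# A usage sketch: calling train() directly from Python instead of via the CLI.
# The model name, dataset and log directory below are illustrative assumptions,
# not values taken from this repository.
#
#   model, tokenizer, metrics, eval_results = train(
#       'distilbert-base-uncased', 'MLM', 'wikitext,wikitext-2-raw-v1',
#       num_epochs=1, column_name='text', log_dir='./logs')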


def main(raw_args):
    parser = argparse.ArgumentParser(description='Train a model and pickle it!\n'
                                                 'List of MLM models: https://huggingface.co/models?filter=masked-lm\n'
                                                 'List of causal models: https://huggingface.co/models?filter=causal-lm\n'
                                                 'List of datasets: https://huggingface.co/datasets')
    parser.add_argument(
        '--model_name', '-m', type=str, nargs='?', help='Name of the pre-trained model', default='bert-base-uncased')
    parser.add_argument(
        '--task', '-t', type=str, nargs='?', help='Type of task: either MLM or causal', default='MLM')
    parser.add_argument(
        '--dataset', '-d', type=str, nargs='?', help='Dataset to use. Can be more than one word for *args;\n'
                                                     'for example \'wikitext,wikitext-2-raw-v1\' will be parsed as\n'
                                                     '[\'wikitext\', \'wikitext-2-raw-v1\']')
    parser.add_argument('--epochs', '-e', type=int, nargs='?', help='Number of training epochs', default=3)
    parser.add_argument('--save_dir', '-s', type=str, nargs='?', help='Path of dir to save the model in')
    parser.add_argument('--log_dir', '-l', type=str, nargs='?', help='Path of dir to save train logs in')
    parser.add_argument('--column_name', '-cn', type=str, nargs='?', help='The name of the text column in the dataset',
                        default='text')
    args = parser.parse_args(raw_args)
    assert args.save_dir and os.path.isdir(args.save_dir), '--save_dir must be an existing directory'
    assert args.log_dir and os.path.isdir(args.log_dir), '--log_dir must be an existing directory'

    timestamp = datetime.now().strftime('%y%m%d%H%M')
    model_save_path = os.path.join(args.save_dir, '{0}_{1}_{2}.pkl'.format(args.model_name.replace('/', '-'),
                                                                           args.dataset.split(',')[0], timestamp))
    tokenizer_save_path = os.path.join(args.save_dir, '{0}_{1}_{2}.pkl'.format(args.model_name.replace('/', '-'),
                                                                               'tokenizer', timestamp))
    eval_save_path = os.path.join(args.log_dir, '{0}_{1}_{2}.pkl'.format(args.model_name.replace('/', '-'),
                                                                         'eval', timestamp))
    metrics_save_path = os.path.join(args.log_dir, '{0}_{1}_{2}.pkl'.format(args.model_name.replace('/', '-'),
                                                                            'metrics', timestamp))

    model, tokenizer, metrics, eval_results = train(args.model_name, args.task, args.dataset,
                                                    args.epochs, args.column_name, args.log_dir)

    torch.save(model, model_save_path)
    print('Saved model to {}'.format(model_save_path))
    with open(tokenizer_save_path, 'wb') as f:
        pickle.dump(tokenizer, f)
    print('Saved tokenizer to {}'.format(tokenizer_save_path))
    with open(eval_save_path, 'wb') as f:
        pickle.dump(eval_results, f)
    print('Saved eval to {}'.format(eval_save_path))
    with open(metrics_save_path, 'wb') as f:
        pickle.dump(metrics, f)
    print('Saved metrics to {}'.format(metrics_save_path))


if __name__ == '__main__':
    sys.exit(main(sys.argv[1:]))
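
# Example invocation (the dataset and directories here are illustrative
# assumptions; any masked-LM checkpoint and HF dataset with a text column
# should work):
#
#   python trainer.py -m bert-base-uncased -t MLM \
#       -d wikitext,wikitext-2-raw-v1 -e 3 -s ./saved_models -l ./logs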