train_ar.py
import contextlib
import fire
import functools
import mup
import numpy as np
import lib.ddp
import lib.ema
import lib.datasets
import lib.models
import lib.ops
import lib.utils
import os
import time
import torch
import torch.nn.functional as F
import tqdm
from torch import nn, optim, autograd
from torch.nn.parallel import DistributedDataParallel as DDP

def main(**args):
    # Allow TF32 matmuls/convolutions for faster training on Ampere+ GPUs.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # Hyperparameter defaults; any of these can be overridden from the command line via Fire.
    args = lib.utils.AttributeDict(args)
    args.setdefault('batch_size', 64)
    args.setdefault('dataset', 'openwebtext2')
    args.setdefault('grad_accum_steps', 1)
    args.setdefault('hook_freq', 10000)
    args.setdefault('lr', 8e-3)
    args.setdefault('lr_warmup_steps', 1000)
    args.setdefault('lr_decay', True)
    args.setdefault('print_freq', 1000)
    args.setdefault('save_weights', False)
    args.setdefault('steps', 104000)
    args.setdefault('weights_path', None)
    args.setdefault('dim', 768)
    args.setdefault('n_blocks', 12)
    args.setdefault('n_heads', 12)
    args.setdefault('seq_len', 256)
    args.setdefault('val_steps', 1000)
    args.setdefault('val_batch_size', 64)
    args.setdefault('weight_decay', 4e-5)
    args.setdefault('ema', 0.)
    args.setdefault('tie_embeddings', False)
    lib.utils.print_args(args)
    dataset = lib.datasets.REGISTRY[args.dataset](
        args.batch_size, args.val_batch_size, args.seq_len
    )
    (train_iterator, val_iterator, test_iterator), (word2idx, idx2word) = dataset

    seq_len = args.seq_len
    vocab_size = len(word2idx)
    print(f'seq_len: {seq_len}, vocab_size: {vocab_size}')

    # base_model and delta_model are smaller-width reference models used only by muP
    # (mup.set_base_shapes) to infer which dimensions scale with width; they are not trained.
    model = lib.models.AutoregressiveModel(args.dim, args.n_blocks, args.n_heads, vocab_size, args.tie_embeddings)
    base_model = lib.models.AutoregressiveModel(256, args.n_blocks, 4, vocab_size, args.tie_embeddings)
    delta_model = lib.models.AutoregressiveModel(128, args.n_blocks, 2, vocab_size, args.tie_embeddings)
    mup.set_base_shapes(model, base_model, delta=delta_model)
    model = model.cuda()
    lib.utils.print_model(model)
    # Optionally initialize from previously saved weights.
    if args.weights_path is not None:
        model.load_state_dict(torch.load(
            os.path.join(args.weights_path, 'model.pt')
        ))

    ddp_model = DDP(model)
    ema = lib.ema.EMA(model, args.ema)
    def forward(*_):
        # One training step's loss: cross-entropy on a batch of token ids.
        X = next(train_iterator).cuda().long()
        logits = ddp_model(X)
        loss = lib.ops.cross_entropy(logits, X).mean()
        return loss
    def compute_nll(data_iterator, steps, eval_seq_len=seq_len):
        # Token-averaged NLL over `steps` batches, evaluated without gradients
        # and under the EMA parameter context.
        with torch.no_grad():
            with ema.enabled():
                total_nll = 0.
                total_tokens = 0
                for i, X in enumerate(data_iterator):
                    X = X.cuda()[:, :eval_seq_len]
                    logits = ddp_model(X)
                    loss = lib.ops.cross_entropy(logits, X).mean()
                    total_nll += loss.item() * X.numel()
                    total_tokens += X.numel()
                    if i == steps:
                        break
        return total_nll / total_tokens
    all_val_nlls = []
    def hook(step):
        # Called every step: update the EMA, and periodically run validation and save weights.
        ema.step()
        if step % args.hook_freq == args.hook_freq - 1:
            for eval_seq_len in [256, 1024]:
                val_nll = compute_nll(val_iterator, args.val_steps, eval_seq_len)
                print(f'NLL (val, seq len {eval_seq_len}): {val_nll}')
                if eval_seq_len == seq_len:
                    all_val_nlls.append(val_nll)
            if (lib.ddp.rank() == 0) and args.save_weights:
                torch.save(model.state_dict(), 'model.pt')
    def impl(param_groups, **kwargs):
        # AdamW decays parameters by lr * weight_decay each step, so dividing by each
        # group's lr keeps the effective per-step decay close to args.weight_decay
        # regardless of the (muP-scaled) learning rate.
        assert 'weight_decay' not in kwargs
        for param_group in param_groups:
            param_group['weight_decay'] = (
                args.weight_decay / (param_group['lr'] + 1e-16)
            )
        return optim.AdamW(param_groups, **kwargs)

    opt = mup.MuAdam(
        model.parameters(),
        impl=impl,
        lr=args.lr,
        betas=(0.9, 0.99)
    )
    lib.utils.train_loop(
        forward,
        opt,
        args.steps,
        hook=hook,
        print_freq=args.print_freq,
        lr_warmup_steps=args.lr_warmup_steps,
        lr_decay=args.lr_decay,
        amp_grad_scaler=False,
        grad_accum_steps=args.grad_accum_steps,
        ddp_models=[ddp_model],
        clip_params=[
            param for param in model.parameters()
        ]
    )
    # Final, longer validation pass.
    final_val_nll = compute_nll(val_iterator, 3000)
    print('Final val NLL:', final_val_nll)
    return all_val_nlls, final_val_nll

if __name__ == '__main__':
    fire.Fire(lib.ddp.wrap_main(main))
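
# Usage sketch (not from the repo's own docs): Fire exposes main's keyword arguments as
# command-line flags, assuming lib.ddp.wrap_main forwards keyword arguments unchanged.
# Any key passed to args.setdefault() above can be overridden this way, e.g.:
#
#   python train_ar.py --steps=104000 --lr=8e-3 --save_weights=True
#
# Multi-GPU launching depends on what lib.ddp.wrap_main expects (e.g. a torchrun-style launcher).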