
Commit 5276712

committed 20200824
1 parent 1d95ba2 commit 5276712

File tree

7 files changed, +1828 -0 lines changed


mini_GPT/mingpt/__init__.py

Whitespace-only changes.

mini_GPT/mingpt/model.py

+197
@@ -0,0 +1,197 @@
"""
GPT model:
- the initial stem consists of a combination of token encoding and a positional encoding
- the meat of it is a uniform sequence of Transformer blocks
    - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block
    - all blocks feed into a central residual pathway similar to resnets
- the final decoder is a linear projection into a vanilla Softmax classifier
"""

import math
import logging

import torch
import torch.nn as nn
from torch.nn import functional as F

logger = logging.getLogger(__name__)

class GPTConfig:
    """ base GPT config, params common to all GPT versions """
    embd_pdrop = 0.1
    resid_pdrop = 0.1
    attn_pdrop = 0.1

    def __init__(self, vocab_size, block_size, **kwargs):
        self.vocab_size = vocab_size
        self.block_size = block_size
        for k, v in kwargs.items():
            setattr(self, k, v)

class GPT1Config(GPTConfig):
    """ GPT-1 like network roughly 125M params """
    n_layer = 12
    n_head = 12
    n_embd = 768

class CausalSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    It is possible to use torch.nn.MultiheadAttention here but I am including an
    explicit implementation here to show that there is nothing too scary here.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        # output projection
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size))
                                     .view(1, 1, config.block_size, config.block_size))
        self.n_head = config.n_head

    def forward(self, x, layer_past=None):
        B, T, C = x.size()

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_drop(self.proj(y))
        return y

class Block(nn.Module):
    """ an unassuming Transformer block """

    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.resid_pdrop),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

class GPT(nn.Module):
    """ the full GPT language model, with a context size of block_size """

    def __init__(self, config):
        super().__init__()

        # input embedding stem
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(config.embd_pdrop)
        # transformer
        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        # decoder head
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.block_size = config.block_size
        self.apply(self._init_weights)

        logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters()))

    def get_block_size(self):
        return self.block_size

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def configure_optimizers(self, train_config):
        """
        This long function is unfortunately doing something very simple and is being very defensive:
        We are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        We are then returning the PyTorch optimizer object.
        """

        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, )
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name

                if pn.endswith('bias'):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # special case the position embedding parameter in the root GPT module as not decayed
        no_decay.add('pos_emb')

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                           % (str(param_dict.keys() - union_params), )

        # create the pytorch optimizer object
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
        return optimizer

    def forward(self, idx, targets=None):
        b, t = idx.size()
        assert t <= self.block_size, "Cannot forward, model block size is exhausted."

        # forward the GPT model
        token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
        position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
        x = self.drop(token_embeddings + position_embeddings)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)

        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss
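
As a quick sanity check on the classes above, the following sketch (not part of the commit) builds a small GPT and runs one forward pass. The import path mingpt.model and the tiny hyperparameters are assumptions chosen purely for illustration.

import torch
from mingpt.model import GPT, GPTConfig

# a deliberately tiny configuration so the sketch runs quickly on CPU
config = GPTConfig(vocab_size=100, block_size=32, n_layer=2, n_head=2, n_embd=64)
model = GPT(config)

idx = torch.randint(0, config.vocab_size, (4, 32))      # batch of 4 sequences of length block_size
targets = torch.randint(0, config.vocab_size, (4, 32))  # next-token targets of the same shape
logits, loss = model(idx, targets)                      # logits: (4, 32, 100); loss: scalar cross-entropy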

mini_GPT/mingpt/trainer.py

+128
@@ -0,0 +1,128 @@
"""
Simple training loop; Boilerplate that could apply to any arbitrary neural network,
so nothing in this file really has anything to do with GPT specifically.
"""

import math
import logging

from tqdm import tqdm
import numpy as np

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data.dataloader import DataLoader

logger = logging.getLogger(__name__)

class TrainerConfig:
    # optimization parameters
    max_epochs = 10
    batch_size = 64
    learning_rate = 3e-4
    betas = (0.9, 0.95)
    grad_norm_clip = 1.0
    weight_decay = 0.1 # only applied on matmul weights
    # learning rate decay params: linear warmup followed by cosine decay to 10% of original
    lr_decay = False
    warmup_tokens = 375e6 # these two numbers come from the GPT-3 paper, but may not be good defaults elsewhere
    final_tokens = 260e9 # (at what point we reach 10% of original LR)
    # checkpoint settings
    ckpt_path = None
    num_workers = 0 # for DataLoader

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)

class Trainer:

    def __init__(self, model, train_dataset, test_dataset, config):
        self.model = model
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.config = config

        # take over whatever gpus are on the system
        self.device = 'cpu'
        if torch.cuda.is_available():
            self.device = torch.cuda.current_device()
            self.model = torch.nn.DataParallel(self.model).to(self.device)

    def save_checkpoint(self):
        # DataParallel wrappers keep raw model object in .module attribute
        raw_model = self.model.module if hasattr(self.model, "module") else self.model
        logger.info("saving %s", self.config.ckpt_path)
        torch.save(raw_model.state_dict(), self.config.ckpt_path)

    def train(self):
        model, config = self.model, self.config
        raw_model = model.module if hasattr(self.model, "module") else model
        optimizer = raw_model.configure_optimizers(config)

        def run_epoch(split):
            is_train = split == 'train'
            model.train(is_train)
            data = self.train_dataset if is_train else self.test_dataset
            loader = DataLoader(data, batch_size=config.batch_size, num_workers=config.num_workers)

            losses = []
            pbar = tqdm(enumerate(loader), total=len(loader)) if is_train else enumerate(loader)
            for it, (x, y) in pbar:

                # place data on the correct device
                x = x.to(self.device)
                y = y.to(self.device)

                # forward the model
                with torch.set_grad_enabled(is_train):
                    logits, loss = model(x, y)
                    loss = loss.mean() # collapse all losses if they are scattered on multiple gpus
                    losses.append(loss.item())

                if is_train:

                    # backprop and update the parameters
                    model.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
                    optimizer.step()

                    # decay the learning rate based on our progress
                    if config.lr_decay:
                        self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100)
                        if self.tokens < config.warmup_tokens:
                            # linear warmup
                            lr_mult = float(self.tokens) / float(max(1, config.warmup_tokens))
                        else:
                            # cosine learning rate decay
                            progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
                            lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
                        lr = config.learning_rate * lr_mult
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr
                    else:
                        lr = config.learning_rate

                    # report progress
                    pbar.set_description(f"epoch {epoch+1} iter {it}: train loss {loss.item():.5f}. lr {lr:e}")

            if not is_train:
                test_loss = float(np.mean(losses))
                logger.info("test loss: %f", test_loss)
                return test_loss

        best_loss = float('inf')
        self.tokens = 0 # counter used for learning rate decay
        for epoch in range(config.max_epochs):

            run_epoch('train')
            if self.test_dataset is not None:
                test_loss = run_epoch('test')

            # supports early stopping based on the test loss, or just save always if no test set is provided
            good_model = self.test_dataset is None or test_loss < best_loss
            if self.config.ckpt_path is not None and good_model:
                best_loss = test_loss
                self.save_checkpoint()
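
To show how Trainer and TrainerConfig plug into the GPT model from model.py, here is a minimal sketch with a made-up dataset of random tokens. RandomTokenDataset is a hypothetical stand-in, not part of the commit; any Dataset yielding (x, y) index tensors of length block_size would work the same way.

import torch
from torch.utils.data import Dataset
from mingpt.model import GPT, GPTConfig
from mingpt.trainer import Trainer, TrainerConfig

class RandomTokenDataset(Dataset):
    """Hypothetical dataset of random token ids, purely to exercise the training loop."""
    def __init__(self, vocab_size=100, block_size=32, length=512):
        self.data = torch.randint(0, vocab_size, (length, block_size + 1))
    def __len__(self):
        return self.data.size(0)
    def __getitem__(self, i):
        chunk = self.data[i]
        return chunk[:-1], chunk[1:]  # (x, y): predict the next token at every position

mconf = GPTConfig(vocab_size=100, block_size=32, n_layer=2, n_head=2, n_embd=64)
model = GPT(mconf)
tconf = TrainerConfig(max_epochs=1, batch_size=16, learning_rate=3e-4, lr_decay=False)
trainer = Trainer(model, RandomTokenDataset(), None, tconf)  # no test set, no ckpt_path: just run the loop
trainer.train()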

mini_GPT/mingpt/utils.py

+47
@@ -0,0 +1,47 @@
import random
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def top_k_logits(logits, k):
    v, ix = torch.topk(logits, k)
    out = logits.clone()
    out[out < v[:, [-1]]] = -float('Inf')
    return out

@torch.no_grad()
def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
    """
    take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
    the sequence, feeding the predictions back into the model each time. Clearly the sampling
    has quadratic complexity unlike an RNN that is only linear, and has a finite context window
    of block_size, unlike an RNN that has an infinite context window.
    """
    block_size = model.get_block_size()
    model.eval()
    for k in range(steps):
        x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
        logits, _ = model(x_cond)
        # pluck the logits at the final step and scale by temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop probabilities to only the top k options
        if top_k is not None:
            logits = top_k_logits(logits, top_k)
        # apply softmax to convert to probabilities
        probs = F.softmax(logits, dim=-1)
        # sample from the distribution or take the most likely
        if sample:
            ix = torch.multinomial(probs, num_samples=1)
        else:
            _, ix = torch.topk(probs, k=1, dim=-1)
        # append to the sequence and continue
        x = torch.cat((x, ix), dim=1)

    return x
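
Finally, a short sketch of how sample() from utils.py would be used for generation, continuing the toy configuration from the earlier sketches. The vocabulary mapping is omitted, so the output is just a tensor of token ids; the specific sizes are assumptions for illustration.

import torch
from mingpt.model import GPT, GPTConfig
from mingpt.utils import sample, set_seed

set_seed(42)
model = GPT(GPTConfig(vocab_size=100, block_size=32, n_layer=2, n_head=2, n_embd=64))

context = torch.randint(0, 100, (1, 8))  # (b, t) conditioning token ids
out = sample(model, context, steps=24, temperature=1.0, sample=True, top_k=10)
print(out.shape)                         # torch.Size([1, 32]): 8 context tokens + 24 generated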
