I want to provide some tips from my experience implementing a paper. I'm going to cover my tips so far from implementing a dramatically scaled-down version of Llama for training TinyShakespeare. This post is heavily inspired by Karpathy's Makemore series, which I highly recommend.
I'm only going to loosely follow the layout of the Llama paper; while the formatting and order of sections make sense for publication, we're here to implement it. I'll also be skipping over some of the more obvious steps, like setting up a virtual environment and installing dependencies.
My approach for implementing papers is:
- Make all of the helper functions required to test your model quantitatively (data splits, training, plotting the loss).
- Before you even look at the paper, pick a small, simple, and fast model that you've done in the past. Then make a helper function to evaluate the model qualitatively.
- Start by picking apart different components of the paper, and then implementing them one-by-one, training and evaluating as you go.
- Use `.shape` religiously. `assert` and `plt.imshow` are also your friends.
- Work out the results without matrix multiplication first, and then use the `torch` functions to make it efficient after.
- Have a test to see that your layer is right. For example, the RoPE embeddings have a specific property that you can test for. For the Transformer, you can test that the attention is working by looking at the attention map.
- Test your layers on various batch, sequence, and embedding sizes. Even if it works for one size, it might not work for others, which will cause problems at inference time.
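Here's a minimal sketch of what that last tip can look like in practice. The helper name `check_layer_shapes` and the stand-in `nn.Linear` layer are my own assumptions for illustration; the idea is just to loop over a few `(batch, seq_len, d_model)` combinations, including the single-token case you hit at inference time.

```python
import torch
from torch import nn


def check_layer_shapes(make_layer, sizes):
    """Run a layer over several (batch, seq_len, d_model) sizes.

    `make_layer` is a hypothetical factory taking d_model and returning
    an nn.Module that is expected to preserve the input shape.
    """
    checked = []
    for batch, seq_len, d_model in sizes:
        layer = make_layer(d_model)
        x = torch.randn(batch, seq_len, d_model)
        out = layer(x)
        # The layer under test is assumed to be shape-preserving.
        assert out.shape == x.shape, (out.shape, x.shape)
        checked.append(tuple(out.shape))
    return checked


# Stand-in layer for illustration: a square linear map over the last dimension.
sizes = [(1, 1, 8), (4, 16, 8), (32, 3, 128)]
checked = check_layer_shapes(lambda d: nn.Linear(d, d), sizes)
print(checked)
```

Swapping the lambda for a factory that builds your own layer gives you a quick regression check every time you change it.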
Llama is a transformer-based model for language modeling. Meta AI open-sourced Llama this summer, and it's gained a lot of attention (pun intended). When you're reading the introduction, they clearly indicate their goal: make a model that's cheaper for running inference, rather than optimizing training costs.
At this point, we'll just load our libraries and get started.
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from matplotlib import pyplot as plt
import time
import pandas as pd
While in Llama they train on 1.4T tokens, our dataset TinyShakespeare, a collection of Shakespeare's works, is about 1M characters.
lines = open('./input.txt', 'r').read()
vocab = sorted(list(set(lines)))
itos = {i:ch for i, ch in enumerate(vocab)}
stoi = {ch:i for i, ch in enumerate(vocab)}
print(lines[:30])
First Citizen:
Before we proce
They use the SentencePiece byte-pair encoding tokenizer, but we're going to just use a simple character-level tokenizer.
# simple tokenization by characters
def encode(s):
    return [stoi[ch] for ch in s]

def decode(l):
    return ''.join([itos[i] for i in l])

print('vocab size:', len(vocab))
decode(encode("hello"))
vocab size: 65
'hello'
Since our dataset is small enough, we don't need to worry about how we store it in memory etc.
First tip: I'm creating a `config` object that stores some basic model params. It makes our code way more readable and removes constants and magic numbers from the code. I'm not going to use types, as I want to keep things flexible for now and be able to add more parameters later on.
MASTER_CONFIG = {
"vocab_size": len(vocab),
}
dataset = torch.tensor(encode(lines), dtype=torch.int8)
dataset.shape
torch.Size([1115394])
Let's create a method to generate our training data and labels for batches. We'll use the same method for validation and test data. Note that I like to test my functions in the same block that I define them, just to make sure they work as expected before moving on.
def get_batches(data, split, batch_size, context_window, config=MASTER_CONFIG):
    # 80/10/10 split into train, validation, and test
    train = data[:int(.8 * len(data))]
    val = data[int(.8 * len(data)): int(.9 * len(data))]
    test = data[int(.9 * len(data)):]

    batch_data = train
    if split == 'val':
        batch_data = val
    if split == 'test':
        batch_data = test

    # pick random starting points
    ix = torch.randint(0, batch_data.size(0) - context_window - 1, (batch_size,))
    # targets are the inputs shifted one character to the right
    x = torch.stack([batch_data[i:i+context_window] for i in ix]).long()
    y = torch.stack([batch_data[i+1:i+context_window+1] for i in ix]).long()
    return x, y

MASTER_CONFIG.update({
'batch_size': 32,
'context_window': 16
})
xs, ys = get_batches(dataset, 'train', MASTER_CONFIG['batch_size'], MASTER_CONFIG['context_window'])
[(decode(xs[i].tolist()), decode(ys[i].tolist())) for i in range(len(xs))]
[('t us sup betimes', ' us sup betimes,'),
('IO:\nRight.\n\nROME', 'O:\nRight.\n\nROMEO'),
('Nurse:\nO, she sa', 'urse:\nO, she say'),
(" seem'd to know,", "seem'd to know,\n"),
(' let her brother', 'let her brother '),
('ood,\nEven with s', 'od,\nEven with su'),
('en\nWhisper the s', 'n\nWhisper the sp'),
(', fly! for all y', ' fly! for all yo'),
(" Saint Peter's C", "Saint Peter's Ch"),
(', but that\nWhich', ' but that\nWhich '),
('uld as willingly', 'ld as willingly '),
('y brother,\nIs pr', ' brother,\nIs pri'),
(' you ready your ', 'you ready your s'),
('rth the audience', 'th the audience '),
('\n\nQUEEN ELIZABET', '\nQUEEN ELIZABETH'),
('ection,\nwhich ca', 'ction,\nwhich can'),
('is wisdom hastes', 's wisdom hastes '),
(' and quinces in ', 'and quinces in t'),
('nt death.\n\nSICIN', 't death.\n\nSICINI'),
("y she's mad.\n\nBR", " she's mad.\n\nBRU"),
('eware of him;\nSi', 'ware of him;\nSin'),
('s\nAnd make pursu', '\nAnd make pursui'),
('r and be slain; ', ' and be slain; n'),
(' I, with grief a', 'I, with grief an'),
('?\n\nSecond Keeper', '\n\nSecond Keeper:'),
('\nNow, Thomas Mow', 'Now, Thomas Mowb'),
('or this once, ye', 'r this once, yea'),
("l 'tis just.\n\nLU", " 'tis just.\n\nLUC"),
('es like a lamb. ', 's like a lamb. Y'),
('t night, I warra', ' night, I warran'),
('y tears would wa', ' tears would was'),
('\n\nANGELO:\nWell, ', '\nANGELO:\nWell, l')]
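The pairs above show the key property of the batches: each target is the input shifted by one character. Here's a small plain-Python sketch of that relationship on the first few characters of the dataset; `make_pair` is a hypothetical helper that mirrors the slicing in `get_batches` without any tensors.

```python
text = "First Citizen:\nBefore we proce"
context_window = 8


def make_pair(data, start, window):
    """Return (input, target) where target is the input shifted by one."""
    x = data[start:start + window]
    y = data[start + 1:start + window + 1]
    return x, y


x, y = make_pair(text, 0, context_window)
print(repr(x), repr(y))

# Position t of the target is the character that follows position t of the input.
for t in range(context_window):
    assert y[t] == text[t + 1]
# Equivalently, the two windows overlap on all but one character.
assert x[1:] == y[:-1]
```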
What's interesting about implementing papers is that there are two aspects to the model working: compilation (do your tensors all match up from layer to layer), and training (does the loss go down). Figuring out how to ensure that each of your components is working is key to developing your model in a predictable, engineering-minded way.
That's why we're also going to define the method for how we're going to evaluate the model. We want to do this before we even define the model, because we want to be able to use it to evaluate the model as we're training it.
@torch.no_grad()  # don't compute gradients for this function
def evaluate_loss(model, config=MASTER_CONFIG):
    out = {}
    model.eval()
    for split in ["train", "val"]:
        losses = []
        # average over 10 random batches to smooth out the estimate
        for _ in range(10):
            xb, yb = get_batches(dataset, split, config['batch_size'], config['context_window'])
            _, loss = model(xb, yb)
            losses.append(loss.item())
        out[split] = np.mean(losses)
    model.train()
    return out

Here's a basic feed-forward neural network with embeddings. It's the base model we're going to start with, and then swap out parts of it as we go along until we eventually end up with the model as described in Llama.
class SimpleBrokenModel(nn.Module):
    def __init__(self, config=MASTER_CONFIG):
        super().__init__()
        self.config = config

        self.embedding = nn.Embedding(config['vocab_size'], config['d_model'])
        self.linear = nn.Sequential(
            nn.Linear(config['d_model'], config['d_model']),
            nn.ReLU(),
            nn.Linear(config['d_model'], config['vocab_size']),
        )

        print("model params:", sum([m.numel() for m in self.parameters()]))

    def forward(self, idx, targets=None):
        x = self.embedding(idx)
        a = self.linear(x)
        logits = F.softmax(a, dim=-1)

        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, self.config['vocab_size']), targets.view(-1))
            return logits, loss
        else:
            return logits

MASTER_CONFIG.update({
'd_model': 128,
})
model = SimpleBrokenModel(MASTER_CONFIG)
xs, ys = get_batches(dataset, 'train', MASTER_CONFIG['batch_size'], MASTER_CONFIG['context_window'])
logits, loss = model(xs, ys)
model params: 33217
It's at this point that we have to start worrying about the shape of our tensors and making indices match. Check out this line of our model definition:
    loss = F.cross_entropy(logits.view(-1, self.config['vocab_size']), targets.view(-1))
We have to reshape the `logits` and `targets` tensors so that their dimensions match when we compare. We do this with the `view` method. The `-1` argument means "infer this dimension from the others". So, in this case, we flatten `logits` from `(batch, seq_len, vocab_size)` to `(batch * seq_len, vocab_size)` and `targets` from `(batch, seq_len)` to `(batch * seq_len,)`, so that every row of logits lines up with exactly one target index. This is a common pattern when you're working with batches of data.
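To make the reshaping concrete, here's a small sketch with made-up sizes (the numbers are illustrative, not taken from the config). It checks which original position each flattened row corresponds to, which is the kind of thing I like to assert once and then stop worrying about.

```python
import torch
from torch.nn import functional as F

batch_size, seq_len, vocab_size = 4, 16, 65  # illustrative sizes

logits = torch.randn(batch_size, seq_len, vocab_size)
targets = torch.randint(0, vocab_size, (batch_size, seq_len))

flat_logits = logits.view(-1, vocab_size)  # (batch * seq_len, vocab_size)
flat_targets = targets.view(-1)            # (batch * seq_len,)
print(flat_logits.shape, flat_targets.shape)

loss = F.cross_entropy(flat_logits, flat_targets)

# Row i of the flattened logits is position (i // seq_len, i % seq_len).
i = 21
assert torch.equal(flat_logits[i], logits[i // seq_len, i % seq_len])
assert flat_targets[i] == targets[i // seq_len, i % seq_len]
```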
Alright, let's train our `SimpleBrokenModel` to make sure gradients flow. After we confirm that, we can swap out parts of it to match Llama, train again, and track our progress. It's at this point that I start keeping a log of my training runs, so that I can easily just go back to a previous run in the event that I mess something up.
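A run log doesn't need to be fancy. Here's one possible sketch using only the standard library: append one JSON line per run with the config and the final metrics. The helpers `log_run` and `load_runs` and the file layout are assumptions of mine, not something the paper or any library prescribes.

```python
import json
import time


def log_run(path, config, metrics, note=""):
    """Append one training run as a JSON line (hypothetical helper)."""
    record = {
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
        "config": config,
        "metrics": metrics,
        "note": note,
    }
    with open(path, "a") as f:
        f.write(json.dumps(record) + "\n")
    return record


def load_runs(path):
    """Read every logged run back as a list of dicts."""
    with open(path) as f:
        return [json.loads(line) for line in f if line.strip()]
```

After each call to `train`, something like `log_run("runs.jsonl", MASTER_CONFIG, {"val_loss": ...}, note="baseline")` is enough to find your way back to a configuration that worked.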
MASTER_CONFIG.update({
'epochs': 1000,
'log_interval': 10
})
model = SimpleBrokenModel(MASTER_CONFIG)
optimizer = torch.optim.Adam(
model.parameters(),
)
def train(model, optimizer, scheduler=None, config=MASTER_CONFIG, print_logs=False):
    losses = []
    start_time = time.time()
    for epoch in range(config['epochs']):
        optimizer.zero_grad()

        xs, ys = get_batches(dataset, 'train', config['batch_size'], config['context_window'])
        logits, loss = model(xs, targets=ys)
        loss.backward()
        optimizer.step()

        if scheduler:
            scheduler.step()

        if epoch % config['log_interval'] == 0:
            batch_time = time.time() - start_time
            x = evaluate_loss(model)
            losses += [x]
            if print_logs:
                print(f"Epoch {epoch} | val loss {x['val']:.3f} | Time {batch_time:.3f} | ETA in seconds {batch_time * (config['epochs'] - epoch)/config['log_interval'] :.3f}")
            start_time = time.time()

            if scheduler:
                # get_last_lr() returns the learning rate used in the latest step
                print("lr: ", scheduler.get_last_lr())

    print("validation loss: ", losses[-1]['val'])
    return pd.DataFrame(losses).plot()

train(model, optimizer)
model params: 33217
validation loss: 3.9457355260849
<Axes: >
Notice how we get a training curve that goes down, but barely by anything. How do we know it's barely training? We have to use first principles. The cross-entropy loss before training is 4.17, and after 1000 epochs is 3.93. How can we make sense of it intuitively?
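Cross-entropy in this context measures how much probability the model assigns to the correct next token. Averaged over $N$ predictions, $$ L = -\frac{1}{N} \sum_{i=1}^{N} \log q(x_i), $$ where $q(x_i)$ is the probability the model assigns to the correct token $x_i$. If $q(x_i)$ is close to 1, the loss is close to 0; if it is small, $-\log q(x_i)$ is a large positive number. A loss of 4.17 means $q = e^{-4.17} \approx 0.015$, or about $1/65$: with a vocabulary of 65 characters, the untrained model is as good as picking uniformly at random. After training, a loss of 3.93 gives $q \approx 0.020$, or about $1/51$, which is a very small improvement, so something is probably wrong.
To get an intuition for how the loss relates to the model's performance, think about the model choosing among $\tilde{V} = e^{L}$ equally likely tokens; the smaller $\tilde{V}$ is, the more likely the model is to guess right, and $\tilde{V}$ close to the vocabulary size means the model has learned almost nothing.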
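The loss values quoted above can be turned into an "effective number of choices" by exponentiating them. Here's a tiny sketch of that arithmetic; `effective_choices` is just a name I'm using for `exp(loss)`.

```python
import math


def effective_choices(loss):
    """Number of equally likely tokens that would give this cross-entropy."""
    return math.exp(loss)


vocab_size = 65
uniform_loss = math.log(vocab_size)  # loss of a uniform guess over the vocabulary

for loss in (uniform_loss, 4.17, 3.93):
    print(f"loss {loss:.2f} -> about {effective_choices(loss):.1f} choices")
```

Comparing a measured loss against `log(vocab_size)` is a cheap first check that a model is learning anything at all.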
Let's try to debug what's going on. Notice that in our model we're using a softmax layer on our logits, which is a function that takes a vector of numbers and squashes them into a probability distribution. But for using the built-in `F.cross_entropy` function, we need to pass in the unnormalized logits directly. So let's remove that from our model and try again.
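Before doing that, it's worth seeing why the extra softmax is so harmful. `F.cross_entropy` applies its own log-softmax, so feeding it probabilities means softmax gets applied twice. Since probabilities live in $[0, 1]$, the second softmax can never be confident. The sketch below reimplements both functions in plain Python (my own minimal versions, for a single example) and computes the best loss the broken setup could ever reach.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target):
    """Cross-entropy for one example: -log softmax(logits)[target]."""
    return -math.log(softmax(logits)[target])


vocab_size = 65
target = 0

# A very confident (and correct) set of raw logits.
raw = [20.0] + [0.0] * (vocab_size - 1)

good_loss = cross_entropy(raw, target)             # logits passed directly
broken_loss = cross_entropy(softmax(raw), target)  # probabilities passed as logits

# Probabilities live in [0, 1], so the best case is 1 for the target and 0 elsewhere.
floor = -math.log(math.e / (math.e + vocab_size - 1))

print(f"{good_loss:.4f} {broken_loss:.4f} {floor:.4f}")
```

With 65 characters the floor works out to roughly 3.2, so even a perfect broken model could not get far below the plateau we saw, which is consistent with the loss barely moving.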
class SimpleModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.embedding = nn.Embedding(config['vocab_size'], config['d_model'])
        self.linear = nn.Sequential(
            nn.Linear(config['d_model'], config['d_model']),
            nn.ReLU(),
            nn.Linear(config['d_model'], config['vocab_size']),
        )

        print("model params:", sum([m.numel() for m in self.parameters()]))

    def forward(self, idx, targets=None):
        x = self.embedding(idx)
        logits = self.linear(x)  # raw, unnormalized logits: no softmax here

        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, self.config['vocab_size']), targets.view(-1))
            return logits, loss
        else:
            return logits

model = SimpleModel(MASTER_CONFIG)
xs, ys = get_batches(dataset, 'train', MASTER_CONFIG['batch_size'], MASTER_CONFIG['context_window'])
logits, loss = model(xs, ys)
optimizer = torch.optim.Adam(model.parameters())
train(model, optimizer)
model params: 33217
validation loss: 2.5113263607025145
<Axes: >
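Great, now our loss is about 2.51, which corresponds to choosing among roughly $e^{2.51} \approx 12.3$ characters instead of 65. Let's write a function that samples text from the model so we can also evaluate it qualitatively.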
def generate(model, config=MASTER_CONFIG, max_new_tokens=30):
    idx = torch.zeros(5, 1).long()
    for _ in range(max_new_tokens):
        # call the model
        logits = model(idx[:, -config['context_window']:])
        last_time_step_logits = logits[
            :, -1, :
        ]  # all the batches (5), last time step, all the logits
        p = F.softmax(last_time_step_logits, dim=-1)  # softmax to get probabilities
        idx_next = torch.multinomial(
            p, num_samples=1
        )  # sample from the distribution to get the next token
        idx = torch.cat([idx, idx_next], dim=-1)  # append to the sequence
    return [decode(x) for x in idx.tolist()]

generate(model)
['\nWI in\nThed grtend\nA yod ys wit',
'\nY aroticunutser\nE oy mendomed ',
"\n\nRIf t fan f ses, k be wn'd mo",
'\nRu hiseedst den t wat onderyou',
"\nARaceps hond wr f\nI' fu kn be "]
It's not half bad, but also not half good. But now we have a working model that is training to a validation loss. So here we'll iterate on our model to make it closer to Llama.
Llama describes three architectural modifications to the original Transformer:
- RMSNorm for pre-normalization
- Rotary embeddings
- SwiGLU activation function
We're going to add each one, one at a time to our base model, and iterate.
In Vaswani 2017, the original transformer uses LayerNorm, which centers the activations and scales them by their standard deviation. In Llama, the authors use RMSNorm, which scales the vector by its root mean square without centering it. In addition, while Vaswani applies normalization to the output of the attention layer (post-normalization), Llama applies it to the inputs before (pre-normalization).
class RMSNorm(nn.Module):
    def __init__(self, layer_shape, eps=1e-8, bias=False):
        super().__init__()
        # learnable scale with one entry per (position, feature), initialized to 1
        self.register_parameter("scale", nn.Parameter(torch.ones(layer_shape)))

    def forward(self, x):
        """
        assumes shape is (batch, seq_len, d_model)
        """
        # frob norm is not the same as RMS. RMS = 1/sqrt(N) * frob norm
        ff_rms = torch.linalg.norm(x, dim=(1,2)) * x[0].numel() ** -.5
        raw = x / ff_rms.unsqueeze(-1).unsqueeze(-1)
        # slice the scale by the input's sequence length so shorter inputs work
        return self.scale[:x.shape[1], :].unsqueeze(0) * raw

config = {
'batch_size': 5,
'context_window': 11,
'd_model': 13,
}
batch = torch.randn((config['batch_size'], config['context_window'], config['d_model']))
m = RMSNorm((config['context_window'], config['d_model']))
g = m(batch)
print(g.shape)
torch.Size([5, 11, 13])
We want to test to ensure that the RMSNorm is doing what we think it should. We can do this the old-fashioned way: row-wise comparisons. The RMSNorm has the property where the norm of the layer will be the square root of the number of elements in the layer, so we can check that for every layer.
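The property follows directly from the definition of the root mean square. As a short derivation, writing $N$ for the number of elements in one example (here `context_window * d_model`):

```latex
\mathrm{RMS}(x) = \sqrt{\frac{1}{N}\sum_{i=1}^{N} x_i^2} = \frac{\lVert x \rVert_2}{\sqrt{N}}
\quad\Longrightarrow\quad
\left\lVert \frac{x}{\mathrm{RMS}(x)} \right\rVert_2
  = \frac{\lVert x \rVert_2}{\lVert x \rVert_2 / \sqrt{N}}
  = \sqrt{N}
```

For the test batch, $N = 11 \times 13 = 143$, so the squared norm of each normalized example should be 143.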
rms = torch.linalg.norm(batch, dim=(1,2)) * (batch[0].numel() ** -.5)
# scaled_batch.var(dim=(1,2))
assert torch.linalg.norm( torch.arange(5).float() ) == (torch.arange(5).float() ** 2 ).sum() ** .5
rms = torch.linalg.norm( torch.arange(5).float() ) * (torch.arange(5).numel() ** -.5)
assert torch.allclose(torch.linalg.norm(torch.arange(5).float() / rms), torch.tensor(5 ** .5))
ff_rms = torch.linalg.norm(batch, dim=(1,2)) * batch.shape[1:].numel() ** -.5
# RMS for sure
ffx = torch.zeros_like(batch)
for i in range(batch.shape[0]):
    ffx[i] = batch[i] / ff_rms[i]
assert torch.allclose(torch.linalg.norm(ffx, dim=(1,2)) ** 2, torch.tensor(143).float())
assert torch.allclose(ffx, g)
Alright, so that's RMSNorm, and it seems like it's working. Again, let's test it out.
class SimpleModel_RMS(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.embedding = nn.Embedding(config['vocab_size'], config['d_model'])
        self.rms = RMSNorm((config['context_window'], config['d_model']))
        self.linear = nn.Sequential(
            nn.Linear(config['d_model'], config['d_model']),
            nn.ReLU(),
            nn.Linear(config['d_model'], config['vocab_size']),
        )

        print("model params:", sum([m.numel() for m in self.parameters()]))

    def forward(self, idx, targets=None):
        x = self.embedding(idx)
        x = self.rms(x)  # rms pre-normalization
        logits = self.linear(x)

        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, self.config['vocab_size']), targets.view(-1))
            return logits, loss
        else:
            return logits

model = SimpleModel_RMS(MASTER_CONFIG)
xs, ys = get_batches(dataset, 'train', MASTER_CONFIG['batch_size'], MASTER_CONFIG['context_window'])
logits, loss = model(xs, ys)
optimizer = torch.optim.Adam(model.parameters())
train(model, optimizer)
model params: 35265
validation loss: 2.4792869329452514
<Axes: >
So RMSNorm works, and it got our loss down by a small amount.
RoPE is a kind of positional encoding for transformers. In Attention is All You Need, the authors propose two kinds of positional encodings, learned and fixed. In RoPE, the authors propose embedding the position of a token in a sequence by rotating the embedding, with a different rotation at each position.
def get_rotary_matrix(context_window, embedding_dim):
    R = torch.zeros((context_window, embedding_dim, embedding_dim), requires_grad=False)
    for position in range(context_window):
        for i in range(embedding_dim//2):
            # i is zero-based here, so this matches theta_i = 10000^(-2(i-1)/d) for i = 1..d/2
            theta = 10000. ** (-2. * i / embedding_dim)
            m_theta = position * theta
            # fill in the 2x2 rotation block for this pair of dimensions
            R[position, 2*i,2*i] = np.cos(m_theta)
            R[position, 2*i,2*i+1] = - np.sin(m_theta)
            R[position, 2*i+1,2*i] = np.sin(m_theta)
            R[position, 2*i+1,2*i+1] = np.cos(m_theta)
    return R

K = 3
config = {
'batch_size': 10,
'd_model': 32,
'n_heads': 8,
'context_window': K**2,
}
batch = torch.randn(1, config['context_window'], config['d_model'])
R = get_rotary_matrix(config['context_window'], config['d_model'])
fig, ax = plt.subplots(K, K, figsize=(K * 3, K * 4))
for i in range(K):
    for j in range(K):
        ax[i, j].imshow(R[i * K + j, :, :].detach().numpy())
        ax[i, j].set_title(f'rotation at {i * K + j}')
Let's make sure these work. They should satisfy the property that the attention score depends only on the relative offset $n - m$: $$ q_m^T k_n = (R^d_{\Theta, m}W_q x_m)^T (R^d_{\Theta, n} W_k x_n) = x_m^T W_q^T R^d_{\Theta, n-m} W_k x_n. $$
config = {
'd_model': 128,
'context_window': 16,
}
R = get_rotary_matrix(config['context_window'], config['d_model'])
x = torch.randn(config['d_model'])
y = torch.randn(config['d_model'])
m = 3
n = 13
x_m = R[m,:,:] @ x
x_n = R[n,:,:] @ y
assert torch.isclose(x_m @ x_n, x @ R[n-m,:,:] @ y)
So the RoPE rotations work as expected.
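The same property is easy to see in the smallest possible case: a single pair of dimensions, where each rotary block is an ordinary 2D rotation. This is a plain-Python sketch with arbitrary example vectors and an arbitrary frequency `theta`, not the full rotary matrix.

```python
import math


def rotate(vec, angle):
    """Rotate a 2D vector counter-clockwise by `angle` radians."""
    x, y = vec
    c, s = math.cos(angle), math.sin(angle)
    return (c * x - s * y, s * x + c * y)


def dot(a, b):
    return a[0] * b[0] + a[1] * b[1]


theta = 0.3          # illustrative frequency for this pair of dimensions
q, k = (1.0, 2.0), (-0.5, 0.7)

# Two (m, n) position pairs with the same offset n - m = 10.
score_a = dot(rotate(q, 3 * theta), rotate(k, 13 * theta))
score_b = dot(rotate(q, 0 * theta), rotate(k, 10 * theta))
# A pair with a different offset, for contrast.
score_c = dot(rotate(q, 3 * theta), rotate(k, 4 * theta))

print(score_a, score_b, score_c)
```

The first two scores agree because rotating both vectors by the same extra angle leaves their dot product unchanged; only the difference in angle matters.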
config = {
'batch_size': 10,
'd_model': 512,
'n_heads': 8,
'context_window': 16,
}
class RoPEAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.w_q = nn.Linear(config['d_model'], config['d_model'], bias=False)
        self.w_k = nn.Linear(config['d_model'], config['d_model'], bias=False)
        self.w_v = nn.Linear(config['d_model'], config['d_model'], bias=False)

        self.multihead = nn.MultiheadAttention(config['d_model'], config['n_heads'], dropout=0.1, batch_first=True)
        # rotary matrices come from the module-level get_rotary_matrix defined above
        self.R = get_rotary_matrix(config['context_window'], config['d_model'])

    def forward(self, x, return_attn_weights=False):
        b,m,d = x.shape

        q = self.w_q(x)
        k = self.w_k(x)
        v = self.w_v(x)

        # rotate every position at once: (seq, batch, d) @ (seq, d, d)
        q_out = (torch.bmm(q.transpose(0,1), self.R)).transpose(0,1)
        k_out = (torch.bmm(k.transpose(0,1), self.R)).transpose(0,1)
        v_out = (torch.bmm(v.transpose(0,1), self.R)).transpose(0,1)

        activations, attn_weights = self.multihead(
            q_out,k_out,v_out,
        )

        if return_attn_weights:
            return activations, attn_weights
        return activations

layer = RoPEAttention(config)
batch = torch.randn((config['batch_size'], config['context_window'], config['d_model']))
output, attn_weights = layer(batch, return_attn_weights=True)
Tip here: know the difference between tensor dimensions at train time vs tensor dimensions at inference time.
Although at train time you can expect your tensor dimensions to match your model parameters closely, e.g. `batch.shape = (config['batch_size'], config['context_window'], config['d_model'])`, at inference time you may have to deal with a single example, e.g. `batch.shape = (1, 1, config['d_model'])`. For this reason, you need to make sure that when you're indexing in the `forward` pass, you're indexing using shapes derived from the input, not necessarily the model parameters.
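As a sketch of this tip, suppose a layer keeps a table with one row per position of the full context window (the `position_table` and `add_positions` names here are hypothetical stand-ins for something like the rotary matrices). Slicing the table by the input's own sequence length lets the same code handle a full training batch and a single generated token.

```python
import torch

context_window, d_model = 16, 8

# Hypothetical per-position table built once for the full context window,
# standing in for a buffer such as the rotary matrices.
position_table = torch.randn(context_window, d_model)


def add_positions(x, table):
    """Add per-position values using the sequence length of the input itself."""
    _, seq_len, _ = x.shape
    # Slice by the input length instead of assuming seq_len == context_window.
    return x + table[:seq_len].unsqueeze(0)


train_batch = torch.randn(32, context_window, d_model)  # full window
first_token = torch.randn(1, 1, d_model)                 # start of generation

train_out = add_positions(train_batch, position_table)
infer_out = add_positions(first_token, position_table)
print(train_out.shape, infer_out.shape)
```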
Let's make sure it does what we think it does. For this layer, we're going to want to test two things:
- that it rotates embeddings the way we think it does
- that the attention mask used for causal attention is working properly.
x = torch.randn((config['batch_size'], config['context_window'], config['d_model']))
q = layer.w_q(x)
k = layer.w_k(x)
v = layer.w_v(x)
q_rotated = torch.zeros_like(x)
k_rotated = torch.zeros_like(x)
v_rotated = torch.zeros_like(x)
for position in range(config['context_window']):
    q_rotated[:,position,:] = torch.matmul(q[:,position,:], layer.R[position,:,:])
    k_rotated[:,position,:] = torch.matmul(k[:,position,:], layer.R[position,:,:])
    v_rotated[:,position,:] = torch.matmul(v[:,position,:], layer.R[position,:,:])

q_out = (torch.bmm(q.transpose(0,1), layer.R)).transpose(0,1)
k_out = (torch.bmm(k.transpose(0,1), layer.R)).transpose(0,1)
v_out = (torch.bmm(v.transpose(0,1), layer.R)).transpose(0,1)
assert torch.allclose(q.transpose(0,1)[0], q[:,0,:])
assert torch.allclose(q.transpose(0,1)[0] @ layer.R[0], q[:,0,:] @ layer.R[0])
assert torch.allclose(q_rotated, q_out)
config = {
'batch_size': 1,
'd_model': 2,
'n_heads': 2,
'context_window': 3,
}
layer = RoPEAttention(config)
batch = torch.ones((config['batch_size'], config['context_window'], config['d_model']))
output, attn_weights = layer(batch, return_attn_weights=True)
m = 0
x_q = batch[0, m]
q = layer.R[m,:,:] @ layer.w_q(x_q)
assert torch.allclose(layer.w_q(x_q), layer.w_q.weight @ x_q)
assert torch.allclose(q, layer.R[m, :, :] @ layer.w_q.weight @ x_q)
n = 2
x_k = batch[0, n]
k = layer.R[n,:,:] @ layer.w_k(x_k)
assert torch.allclose(layer.w_k(x_k), layer.w_k.weight @ x_k)
assert torch.allclose(k, layer.R[n, :, :] @ layer.w_k.weight @ x_k)
assert q.T @ k == q @ k # transpose is redundant
assert torch.allclose(q @ k, x_k.T @ layer.w_k.weight.T @ layer.R[n, :, :].T @ layer.R[m, :, :] @ layer.w_q.weight @ x_q)
assert torch.allclose(q @ k, x_k.T @ layer.w_k.weight.T @ layer.R[n-m, :, :].T @ layer.w_q.weight @ x_q)
Now let's inspect the attention weights. Since this is causal, we would expect that due to masking, the upper triangular of the attention should be 0.
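For reference, here's a minimal sketch of what a causal mask does to a single attention matrix, independent of `nn.MultiheadAttention`. The scores are random stand-ins; the mask puts `-inf` above the diagonal so that softmax assigns those positions exactly zero weight.

```python
import torch

seq_len = 4
scores = torch.randn(seq_len, seq_len)  # stand-in for q @ k.T / sqrt(d)

# -inf strictly above the diagonal: position i may not look at positions j > i.
mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
weights = torch.softmax(scores + mask, dim=-1)
print(weights)

# Every future position receives exactly zero attention ...
upper = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
assert torch.all(weights[upper] == 0)
# ... and each row is still a probability distribution over the visible positions.
assert torch.allclose(weights.sum(dim=-1), torch.ones(seq_len))
```

This lower-triangular pattern is what we're looking for in the plot of the attention weights.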
MASTER_CONFIG.update({
'n_heads': 8,
})
layer = RoPEAttention(MASTER_CONFIG)
batch = torch.ones((MASTER_CONFIG['batch_size'], MASTER_CONFIG['context_window'], MASTER_CONFIG['d_model']))
output, attn_weights = layer(batch, return_attn_weights=True)
plt.imshow(attn_weights[0].detach().numpy(), interpolation='nearest')
plt.colorbar()
<matplotlib.colorbar.Colorbar at 0x16c039090>
This is not good; it means that information is leaking across the attention. We need to ensure the causal mask is working.
class RoPEAttention_wMask(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.w_q = nn.Linear(config['d_model'], config['d_model'], bias=False)
        self.w_k = nn.Linear(config['d_model'], config['d_model'], bias=False)
        self.w_v = nn.Linear(config['d_model'], config['d_model'], bias=False)

        self.multihead = nn.MultiheadAttention(config['d_model'], config['n_heads'], dropout=0.1, batch_first=True)
        # rotary matrices come from the module-level get_rotary_matrix defined above
        self.R = get_rotary_matrix(config['context_window'], config['d_model'])

    def forward(self, x, return_attn_weights=False):
        b,m,d = x.shape

        q = self.w_q(x)
        k = self.w_k(x)
        v = self.w_v(x)

        # slice R by the input's sequence length so shorter inputs work at inference
        q_out = (torch.bmm(q.transpose(0,1), self.R[:m, ...])).transpose(0,1)
        k_out = (torch.bmm(k.transpose(0,1), self.R[:m, ...])).transpose(0,1)
        v_out = (torch.bmm(v.transpose(0,1), self.R[:m, ...])).transpose(0,1)

        # causal mask: each position may only attend to itself and earlier positions
        activations, attn_weights = self.multihead(
            q_out,k_out,v_out,
            attn_mask=nn.Transformer.generate_square_subsequent_mask(m),
            is_causal=True
        )

        if return_attn_weights:
            return activations, attn_weights
        return activations

layer = RoPEAttention(config)
batch = torch.randn((config['batch_size'], config['context_window'], config['d_model']))
output, attn_weights = layer(batch, return_attn_weights=True)
layer = RoPEAttention_wMask(MASTER_CONFIG)
batch = torch.ones((MASTER_CONFIG['batch_size'], MASTER_CONFIG['context_window'], MASTER_CONFIG['d_model']))
output, attn_weights = layer(batch, return_attn_weights=True)
plt.imshow(attn_weights[0].detach().numpy())
plt.colorbar()
<matplotlib.colorbar.Colorbar at 0x16c2c7b50>
Alright, let's run it and see what happens.
class RopeModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.embedding = nn.Embedding(config['vocab_size'], config['d_model'])
        self.rms = RMSNorm((config['context_window'], config['d_model']))
        self.rope_attention = RoPEAttention_wMask(config)

        self.linear = nn.Sequential(
            nn.Linear(config['d_model'], config['d_model']),
            nn.ReLU(),
        )

        self.last_linear = nn.Linear(config['d_model'], config['vocab_size'])

        print("model params:", sum([m.numel() for m in self.parameters()]))

    def forward(self, idx, targets=None):
        x = self.embedding(idx)

        # one block of attention
        x = self.rms(x)  # rms pre-normalization
        x = x + self.rope_attention(x)

        x = self.rms(x)  # rms pre-normalization
        x = x + self.linear(x)

        logits = self.last_linear(x)

        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, self.config['vocab_size']), targets.view(-1))
            return logits, loss
        else:
            return logits

model = RopeModel(MASTER_CONFIG)
xs, ys = get_batches(dataset, 'train', MASTER_CONFIG['batch_size'], MASTER_CONFIG['context_window'])
logits, loss = model(xs, ys)
optimizer = torch.optim.Adam(model.parameters())
train(model, optimizer)
model params: 150465
validation loss: 2.1157416343688964
<Axes: >
It looks like we can drive our loss down even lower. Let's do that by updating master config.
MASTER_CONFIG.update({
"epochs": 5000,
"log_interval": 10,
})
train(model, optimizer)
validation loss: 1.9027801871299743
<Axes: >
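As it says in the paper, "We replace the ReLU non-linearity by the SwiGLU activation function...we use a dimension of $\frac{2}{3}4d$ instead of $4d$ as in PaLM." SwiGLU is defined as $$ \text{SwiGLU}(x) = \text{Swish}_\beta(xW + b) \otimes (xV + c), $$ where $\otimes$ is a component-wise product, and the Swish function is $$ \text{Swish}_\beta(x) = x \, \sigma(\beta x), $$ where $\beta$ is a learnable parameter.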
class SwiGLU(nn.Module):
    """
    Swish-Gated Linear Unit
    https://arxiv.org/pdf/2002.05202v1.pdf
    """
    def __init__(self, size):
        super().__init__()
        self.linear_gate = nn.Linear(size, size)
        self.linear = nn.Linear(size, size)
        # learnable Swish parameter; assigning an nn.Parameter registers it on the module
        self.beta = nn.Parameter(torch.ones(1))

    def forward(self, x):
        gate = self.linear_gate(x)  # compute the gate projection once
        swish_gate = gate * torch.sigmoid(self.beta * gate)
        out = swish_gate * self.linear(x)
        return out

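As a quick sanity check on the activation itself, here's a small functional sketch, separate from the module above. It relies on the fact that Swish with $\beta = 1$ is the same function as PyTorch's `F.silu`; the weight matrices `W` and `V` are random placeholders and biases are left out for brevity.

```python
import torch
from torch.nn import functional as F


def swish(x, beta=1.0):
    """Swish_beta(x) = x * sigmoid(beta * x)."""
    return x * torch.sigmoid(beta * x)


x = torch.linspace(-4, 4, steps=9)

# With beta = 1, Swish coincides with PyTorch's built-in SiLU.
assert torch.allclose(swish(x), F.silu(x))

# SwiGLU gates one linear projection with the Swish of another (biases omitted here).
torch.manual_seed(0)
d = 6
W = torch.randn(d, d)  # gate projection (illustrative random weights)
V = torch.randn(d, d)  # value projection
h = torch.randn(2, d)
out = swish(h @ W) * (h @ V)
print(out.shape)
```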
class RopeModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.embedding = nn.Embedding(config['vocab_size'], config['d_model'])
        self.rms = RMSNorm((config['context_window'], config['d_model']))
        self.rope_attention = RoPEAttention_wMask(config)

        self.linear = nn.Sequential(
            nn.Linear(config['d_model'], config['d_model']),
            SwiGLU(config['d_model']),
        )

        self.last_linear = nn.Linear(config['d_model'], config['vocab_size'])

        print("model params:", sum([m.numel() for m in self.parameters()]))

    def forward(self, idx, targets=None):
        x = self.embedding(idx)

        # one block of attention
        x = self.rms(x)  # rms pre-normalization
        x = x + self.rope_attention(x)

        x = self.rms(x)  # rms pre-normalization
        x = x + self.linear(x)

        logits = self.last_linear(x)

        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, self.config['vocab_size']), targets.view(-1))
            return logits, loss
        else:
            return logits

model = RopeModel(MASTER_CONFIG)
xs, ys = get_batches(dataset, 'train', MASTER_CONFIG['batch_size'], MASTER_CONFIG['context_window'])
logits, loss = model(xs, ys)
optimizer = torch.optim.Adam(model.parameters())
train(model, optimizer)
model params: 183490
validation loss: 1.8666490197181702
<Axes: >
Now, let's add multiple layers of RopeAttention by creating blocks.
# add RMSNorm and residual connection
class LlamaBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.rms = RMSNorm((config['context_window'], config['d_model']))

        self.attention = RoPEAttention_wMask(config)
        self.feedforward = nn.Sequential(
            nn.Linear(config['d_model'], config['d_model']),
            SwiGLU(config['d_model']),
        )

    def forward(self, x):
        x = self.rms(x)  # rms pre-normalization
        x = x + self.attention(x)

        x = self.rms(x)  # rms pre-normalization
        x = x + self.feedforward(x)
        return x

block = LlamaBlock(MASTER_CONFIG)
block(torch.randn(MASTER_CONFIG['batch_size'], MASTER_CONFIG['context_window'], MASTER_CONFIG['d_model']));
from collections import OrderedDict
MASTER_CONFIG.update({
'n_layers': 4,
})
class Llama(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embeddings = nn.Embedding(config['vocab_size'], config['d_model'])
        # stack n_layers blocks, named llama_0, llama_1, ...
        self.llama_blocks = nn.Sequential(
            OrderedDict([(f"llama_{i}", LlamaBlock(config)) for i in range(config['n_layers'])])
        )

        self.ffn = nn.Sequential(
            nn.Linear(config['d_model'], config['d_model']),
            SwiGLU(config['d_model']),
            nn.Linear(config['d_model'], config['vocab_size']),
        )

        print("model params:", sum([m.numel() for m in self.parameters()]))

    def forward(self, idx, targets=None):
        x = self.embeddings(idx)
        x = self.llama_blocks(x)
        logits = self.ffn(x)

        if targets is None:
            return logits
        else:
            loss = F.cross_entropy(logits.view(-1, self.config['vocab_size']), targets.view(-1))
            return logits, loss

llama = Llama(MASTER_CONFIG)
optimizer = torch.optim.Adam(llama.parameters())
train(llama, optimizer, config=MASTER_CONFIG)
model params: 733382
validation loss: 1.616115140914917
<Axes: >
It looks like we can drive the loss down even more, and although we're overfitting a little, I think we can still do better. Let's train longer.
MASTER_CONFIG.update({
'epochs': 10000,
})
train(llama, optimizer, scheduler=None, config=MASTER_CONFIG)
validation loss: 0.9024032115936279
<Axes: >
It seems we can go even lower, still without serious overfitting. Either there is a leak, or it's actually doing well. The validation loss here is about 0.90, which is equivalent to choosing among roughly $e^{0.90} \approx 2.5$ tokens at random.
train(llama, optimizer, config=MASTER_CONFIG)
validation loss: 0.746810007095337
<Axes: >
print(generate(llama, MASTER_CONFIG, 500)[0])
Evend her break of thou thire xoing dieble had side, did foesors exenatedH in siffied up,
No, none,
And you ling as thought depond.
MENENIUS:
Tell officien:
To pesiding be
Best wanty and to spiege,
To uncine shee patss again,
I will hen: then they
Moieth:
I my cast in letch:
For bereful, give toan I may
LINT OF AUMERLE:
Out, or me but thee here sir,
Why first with canse pring;
Now!
Gide me couuse
The haster:
And suilt harming,
Then as pereise with and go.
FROMNIUS:
I well? speak and wieke ac
At this point, we've hit the bottom with our training. Let's test on the test set.
xs, ys = get_batches(dataset, 'test', MASTER_CONFIG['batch_size'], MASTER_CONFIG['context_window'])
logits, loss = llama(xs, ys)
print(loss)
tensor(0.8304, grad_fn=<NllLossBackward0>)
Let's inspect the gradients to see how they're flowing. If too many gradient entries are close to 0, that's a problem.
# print the percentage of gradient entries that are near 0
def show_grads(model, tol=1e-2):
    stats = []
    for name, param in model.named_parameters():
        # look at param.grad, not the parameter values themselves
        if param.requires_grad and param.grad is not None:
            near_zero = torch.sum(torch.abs(param.grad) <= tol).item()
            stats.append((name, 100.0 * near_zero / param.grad.nelement()))
    return sorted(stats, key=lambda t: t[1], reverse=True)

show_grads(llama)
[('llama_blocks.llama_0.attention.multihead.in_proj_bias', 36.71875),
('llama_blocks.llama_3.attention.multihead.in_proj_bias', 35.9375),
('llama_blocks.llama_1.attention.multihead.in_proj_bias', 33.59375),
('llama_blocks.llama_2.attention.multihead.in_proj_bias', 33.333333333333336),
('llama_blocks.llama_0.attention.multihead.in_proj_weight',
21.840413411458332),
('llama_blocks.llama_0.attention.w_q.weight', 14.892578125),
('llama_blocks.llama_0.attention.multihead.out_proj.weight', 13.4765625),
('llama_blocks.llama_0.attention.w_k.weight', 12.9638671875),
('llama_blocks.llama_0.attention.w_v.weight', 11.8896484375),
('llama_blocks.llama_1.attention.w_v.weight', 11.285400390625),
('llama_blocks.llama_3.attention.multihead.out_proj.weight', 11.12060546875),
('llama_blocks.llama_2.attention.w_v.weight', 10.68115234375),
('llama_blocks.llama_0.feedforward.0.weight', 10.4248046875),
('llama_blocks.llama_2.attention.multihead.out_proj.weight', 10.36376953125),
('llama_blocks.llama_1.attention.multihead.out_proj.weight', 10.2783203125),
('llama_blocks.llama_3.attention.multihead.out_proj.bias', 10.15625),
('llama_blocks.llama_3.attention.w_v.weight', 9.991455078125),
('llama_blocks.llama_0.feedforward.1.linear.weight', 9.9609375),
('llama_blocks.llama_1.attention.multihead.in_proj_weight',
9.908040364583334),
('llama_blocks.llama_2.attention.multihead.in_proj_weight', 9.75341796875),
('llama_blocks.llama_0.feedforward.1.linear_gate.weight', 9.66796875),
('llama_blocks.llama_1.attention.multihead.out_proj.bias', 9.375),
('llama_blocks.llama_3.attention.multihead.in_proj_weight', 9.16748046875),
('llama_blocks.llama_1.feedforward.0.weight', 8.77685546875),
('llama_blocks.llama_0.feedforward.1.linear_gate.bias', 8.59375),
('llama_blocks.llama_2.feedforward.1.linear.bias', 8.59375),
('llama_blocks.llama_2.feedforward.0.weight', 8.4228515625),
('llama_blocks.llama_1.feedforward.1.linear.weight', 7.720947265625),
('llama_blocks.llama_1.feedforward.1.linear_gate.weight', 7.501220703125),
('llama_blocks.llama_3.feedforward.0.weight', 7.440185546875),
('llama_blocks.llama_2.feedforward.1.linear.weight', 7.31201171875),
('llama_blocks.llama_2.feedforward.1.linear_gate.weight', 7.196044921875),
('llama_blocks.llama_1.feedforward.1.linear_gate.bias', 7.03125),
('llama_blocks.llama_2.attention.multihead.out_proj.bias', 7.03125),
('llama_blocks.llama_2.feedforward.1.linear_gate.bias', 7.03125),
('llama_blocks.llama_2.attention.w_k.weight', 6.94580078125),
('llama_blocks.llama_3.feedforward.1.linear.weight', 6.927490234375),
('llama_blocks.llama_1.attention.w_k.weight', 6.82373046875),
('llama_blocks.llama_2.attention.w_q.weight', 6.82373046875),
('llama_blocks.llama_3.attention.w_k.weight', 6.585693359375),
('llama_blocks.llama_3.feedforward.1.linear_gate.weight', 6.494140625),
('llama_blocks.llama_1.attention.w_q.weight', 6.396484375),
('llama_blocks.llama_1.feedforward.1.linear.bias', 6.25),
('llama_blocks.llama_3.feedforward.1.linear.bias', 6.25),
('llama_blocks.llama_3.attention.w_q.weight', 6.2255859375),
('ffn.0.weight', 5.633544921875),
('llama_blocks.llama_0.feedforward.1.linear.bias', 5.46875),
('ffn.1.linear_gate.bias', 5.46875),
('ffn.1.linear_gate.weight', 5.2978515625),
('ffn.1.linear.weight', 5.26123046875),
('ffn.2.weight', 4.447115384615385),
('ffn.1.linear.bias', 3.90625),
('llama_blocks.llama_0.feedforward.0.bias', 3.125),
('ffn.2.bias', 3.076923076923077),
('llama_blocks.llama_3.feedforward.0.bias', 2.34375),
('ffn.0.bias', 2.34375),
('llama_blocks.llama_0.attention.multihead.out_proj.bias', 1.5625),
('llama_blocks.llama_2.feedforward.0.bias', 1.5625),
('llama_blocks.llama_3.feedforward.1.linear_gate.bias', 1.5625),
('llama_blocks.llama_1.feedforward.0.bias', 0.78125),
('embeddings.weight', 0.7572115384615384),
('llama_blocks.llama_0.rms.scale', 0.146484375),
('llama_blocks.llama_0.feedforward.1.beta', 0.0),
('llama_blocks.llama_1.rms.scale', 0.0),
('llama_blocks.llama_1.feedforward.1.beta', 0.0),
('llama_blocks.llama_2.rms.scale', 0.0),
('llama_blocks.llama_2.feedforward.1.beta', 0.0),
('llama_blocks.llama_3.rms.scale', 0.0),
('llama_blocks.llama_3.feedforward.1.beta', 0.0),
('ffn.1.beta', 0.0)]
Here, for all of our parameter gradients, the vast majority are not near zero, which is great. If we started to see these percentages climb much higher, that would be a sign that our gradients are not flowing.
In the original Llama paper, the authors use a cosine annealing learning rate schedule. We didn't do that here because, when I experimented with it, the results came out worse.
MASTER_CONFIG.update({
    "epochs": 1000
})
llama_with_cosine = Llama(MASTER_CONFIG)
llama_optimizer = torch.optim.Adam(
    llama_with_cosine.parameters(),  # optimize the model being trained, not the earlier llama
    betas=(.9, .95),
    weight_decay=.1,
    eps=1e-9,
    lr=1e-3
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(llama_optimizer, 300, eta_min=1e-5)
train(llama_with_cosine, llama_optimizer, scheduler=scheduler)
model params: 733382
/Users/bkitano/Desktop/projects/llama/.llama/lib/python3.11/site-packages/torch/optim/lr_scheduler.py:814: UserWarning: To get the last learning rate computed by the scheduler, please use `get_last_lr()`.
warnings.warn("To get the last learning rate computed by the scheduler, "
lr: [0.0009999457184159408]
lr: [0.0009961510274583004]
lr: [0.0009869757772816292]
lr: [0.0009725204933511963]
lr: [0.0009529435502760634]
lr: [0.0009284594366176498]
lr: [0.0008993364049014041]
lr: [0.0008658935325782156]
lr: [0.0008284972261358946]
lr: [0.0007875572066618724]
lr: [0.0007435220208394261]
lr: [0.0006968741265588948]
lr: [0.0006481246069855613]
lr: [0.0005978075709959357]
lr: [0.0005464743013294258]
lr: [0.0004946872145648923]
lr: [0.00044301369909075325]
lr: [0.00039201989856974814]
lr: [0.0003422645089892466]
lr: [0.0002942926572257539]
lr: [0.0002486299281339155]
lr: [0.0002057766054928797]
lr: [0.00016620218969383484]
lr: [0.00013034025177986015]
lr: [9.858367916643242e-05]
lr: [7.128036241775617e-05]
lr: [4.872936226262451e-05]
lr: [3.117756953567661e-05]
lr: [1.8816750064937722e-05]
lr: [1.1779739336260682e-05]
lr: [1.0054282328142978e-05]
lr: [1.396862020283371e-05]
lr: [2.3138198807489664e-05]
lr: [3.7591595433673565e-05]
lr: [5.716757986058726e-05]
lr: [8.165110310035915e-05]
lr: [0.00011077372717831143]
lr: [0.00014421629539013414]
lr: [0.00018161236159547207]
lr: [0.00022255218257041353]
lr: [0.000266587198202317]
lr: [0.0003132349418762423]
lr: [0.0003619843243986437]
lr: [0.00041230123246673267]
lr: [0.0004636343798727907]
lr: [0.0005154213471150504]
lr: [0.0005670947431326733]
lr: [0.0006180884215964123]
lr: [0.0006678436836148733]
lr: [0.0007158153988789833]
lr: [0.000761477978168448]
lr: [0.0008043311317769176]
lr: [0.0008439053507606739]
lr: [0.0008797670509540533]
lr: [0.0009115233233906334]
lr: [0.0009388262390821259]
lr: [0.0009613766609899272]
lr: [0.0009789275214238276]
lr: [0.000991286528959492]
lr: [0.0009983182752167242]
lr: [0.0009999457184159228]
lr: [0.0009961510274582826]
lr: [0.0009869757772816119]
lr: [0.0009725204933511793]
lr: [0.0009529435502760463]
lr: [0.0009284594366176337]
lr: [0.0008993364049013882]
lr: [0.0008658935325782011]
lr: [0.0008284972261358803]
lr: [0.000787557206661859]
lr: [0.0007435220208394123]
lr: [0.0006968741265588827]
lr: [0.0006481246069855505]
lr: [0.0005978075709959258]
lr: [0.0005464743013294155]
lr: [0.0004946872145648846]
lr: [0.0004430136990907463]
lr: [0.00039201989856974174]
lr: [0.0003422645089892402]
lr: [0.00029429265722574944]
lr: [0.0002486299281339118]
lr: [0.00020577660549287596]
lr: [0.00016620218969383194]
lr: [0.00013034025177985776]
lr: [9.85836791664316e-05]
lr: [7.128036241775529e-05]
lr: [4.872936226262421e-05]
lr: [3.117756953567644e-05]
lr: [1.8816750064937634e-05]
lr: [1.1779739336260654e-05]
lr: [1.0054282328142978e-05]
lr: [1.396862020283372e-05]
lr: [2.3138198807489698e-05]
lr: [3.759159543367397e-05]
lr: [5.716757986058744e-05]
lr: [8.165110310035937e-05]
lr: [0.00011077372717831269]
lr: [0.00014421629539013508]
lr: [0.0001816123615954725]
lr: [0.00022255218257041364]
validation loss: 4.179551410675049
show_grads(llama_with_cosine, 1e-5)
[('llama_blocks.llama_0.attention.multihead.in_proj_bias', 100.0),
('llama_blocks.llama_0.attention.multihead.out_proj.bias', 100.0),
('llama_blocks.llama_1.attention.multihead.in_proj_bias', 100.0),
('llama_blocks.llama_1.attention.multihead.out_proj.bias', 100.0),
('llama_blocks.llama_2.attention.multihead.in_proj_bias', 100.0),
('llama_blocks.llama_2.attention.multihead.out_proj.bias', 100.0),
('llama_blocks.llama_3.attention.multihead.in_proj_bias', 100.0),
('llama_blocks.llama_3.attention.multihead.out_proj.bias', 100.0),
('llama_blocks.llama_1.feedforward.1.linear.bias', 0.78125),
('llama_blocks.llama_2.feedforward.1.linear_gate.weight', 0.030517578125),
('llama_blocks.llama_3.attention.w_q.weight', 0.030517578125),
('llama_blocks.llama_0.attention.multihead.out_proj.weight', 0.0244140625),
('llama_blocks.llama_0.feedforward.1.linear_gate.weight', 0.0244140625),
('llama_blocks.llama_0.feedforward.1.linear.weight', 0.0244140625),
('llama_blocks.llama_1.attention.w_v.weight', 0.0244140625),
('llama_blocks.llama_1.feedforward.1.linear.weight', 0.0244140625),
('llama_blocks.llama_2.attention.w_q.weight', 0.0244140625),
('llama_blocks.llama_3.attention.w_v.weight', 0.0244140625),
('llama_blocks.llama_3.attention.multihead.out_proj.weight', 0.0244140625),
('ffn.2.weight', 0.02403846153846154),
('llama_blocks.llama_0.attention.w_v.weight', 0.018310546875),
('llama_blocks.llama_0.feedforward.0.weight', 0.018310546875),
('llama_blocks.llama_1.feedforward.0.weight', 0.018310546875),
('llama_blocks.llama_3.attention.w_k.weight', 0.018310546875),
('llama_blocks.llama_3.feedforward.0.weight', 0.018310546875),
('llama_blocks.llama_0.attention.multihead.in_proj_weight',
0.016276041666666668),
('llama_blocks.llama_1.attention.multihead.out_proj.weight', 0.01220703125),
('llama_blocks.llama_1.feedforward.1.linear_gate.weight', 0.01220703125),
('llama_blocks.llama_3.feedforward.1.linear_gate.weight', 0.01220703125),
('llama_blocks.llama_3.feedforward.1.linear.weight', 0.01220703125),
('llama_blocks.llama_0.attention.w_q.weight', 0.006103515625),
('llama_blocks.llama_0.attention.w_k.weight', 0.006103515625),
('llama_blocks.llama_1.attention.w_q.weight', 0.006103515625),
('llama_blocks.llama_1.attention.multihead.in_proj_weight', 0.006103515625),
('llama_blocks.llama_2.attention.multihead.in_proj_weight', 0.006103515625),
('llama_blocks.llama_2.attention.multihead.out_proj.weight', 0.006103515625),
('llama_blocks.llama_2.feedforward.1.linear.weight', 0.006103515625),
('llama_blocks.llama_3.attention.multihead.in_proj_weight',
0.004069010416666667),
('embeddings.weight', 0.0),
('llama_blocks.llama_0.rms.scale', 0.0),
('llama_blocks.llama_0.feedforward.0.bias', 0.0),
('llama_blocks.llama_0.feedforward.1.beta', 0.0),
('llama_blocks.llama_0.feedforward.1.linear_gate.bias', 0.0),
('llama_blocks.llama_0.feedforward.1.linear.bias', 0.0),
('llama_blocks.llama_1.rms.scale', 0.0),
('llama_blocks.llama_1.attention.w_k.weight', 0.0),
('llama_blocks.llama_1.feedforward.0.bias', 0.0),
('llama_blocks.llama_1.feedforward.1.beta', 0.0),
('llama_blocks.llama_1.feedforward.1.linear_gate.bias', 0.0),
('llama_blocks.llama_2.rms.scale', 0.0),
('llama_blocks.llama_2.attention.w_k.weight', 0.0),
('llama_blocks.llama_2.attention.w_v.weight', 0.0),
('llama_blocks.llama_2.feedforward.0.weight', 0.0),
('llama_blocks.llama_2.feedforward.0.bias', 0.0),
('llama_blocks.llama_2.feedforward.1.beta', 0.0),
('llama_blocks.llama_2.feedforward.1.linear_gate.bias', 0.0),
('llama_blocks.llama_2.feedforward.1.linear.bias', 0.0),
('llama_blocks.llama_3.rms.scale', 0.0),
('llama_blocks.llama_3.feedforward.0.bias', 0.0),
('llama_blocks.llama_3.feedforward.1.beta', 0.0),
('llama_blocks.llama_3.feedforward.1.linear_gate.bias', 0.0),
('llama_blocks.llama_3.feedforward.1.linear.bias', 0.0),
('ffn.0.weight', 0.0),
('ffn.0.bias', 0.0),
('ffn.1.beta', 0.0),
('ffn.1.linear_gate.weight', 0.0),
('ffn.1.linear_gate.bias', 0.0),
('ffn.1.linear.weight', 0.0),
('ffn.1.linear.bias', 0.0),
('ffn.2.bias', 0.0)]
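Even at an extremely low tolerance, the attention biases sit at 100%, so they are not getting any signal. The validation loss tells the same story: 4.18 is almost exactly ln(65) ≈ 4.17, which is what you get by guessing uniformly over TinyShakespeare's 65 characters, so this model was effectively never updated. I wouldn't read that as evidence against the learning schedule from the paper. Before blaming the schedule, check that the optimizer is built from the parameters of the model you are training (llama_with_cosine.parameters(), not llama.parameters()) and that the scheduler's T_max covers the whole run; with T_max=300 and 1000 epochs, the learning rate in the log above drops to about 1e-5 and then climbs back up. The lesson here is simple: start simple.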
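The learning-rate log above falls to about 1e-5 and then rises again, because a cosine annealing schedule with T_max=300 keeps oscillating when the run is longer than 300 steps. Here is a minimal sketch of that behaviour in plain Python, using the closed-form expression given in the PyTorch documentation for CosineAnnealingLR. The function name is hypothetical, the constants are taken from the optimizer and scheduler calls above, and the values a real scheduler reports may differ slightly from this formula.

```python
import math


def cosine_annealing_lr(step, t_max, lr_max=1e-3, lr_min=1e-5):
    """Closed form: lr_min + (lr_max - lr_min) * (1 + cos(pi * step / t_max)) / 2."""
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * step / t_max))


total_steps = 1000  # matches "epochs": 1000 in the config above

# T_max shorter than the run: the rate bottoms out at step 300, then climbs back.
for step in (0, 300, 600, 900):
    print(step, f"{cosine_annealing_lr(step, t_max=300):.2e}")

# T_max equal to the run length: a single decay from lr_max to lr_min.
for step in (0, 500, 1000):
    print(step, f"{cosine_annealing_lr(step, t_max=total_steps):.2e}")
```

Setting T_max to the total number of steps gives one smooth decay, which is closer to the single-decay schedule described in the paper (where the final learning rate is 10% of the peak).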