-
Notifications
You must be signed in to change notification settings - Fork 68
/
model.py
263 lines (223 loc) · 10.3 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
import os
import requests
import math
import tiktoken
import torch
import torch.nn as nn
from torch.nn import functional as F
# Hyperparameters
# Module-level configuration read by the model classes, the batching helper and the training loop.
batch_size = 4 # Number of token sequences sampled per batch
context_length = 16 # Number of tokens in each sampled sequence; also the size of the causal mask
d_model = 64 # Dimension of the token embeddings / residual stream
num_blocks = 8 # Number of transformer blocks
num_heads = 4 # Number of heads in Multi-head attention
learning_rate = 1e-3 # 0.001, passed to the optimizer
dropout = 0.1 # Dropout rate
max_iters = 5000 # Total number of training iterations <- Change this to a smaller number for testing
eval_interval = 50 # Intended number of training steps between evaluations
eval_iters = 20 # Number of batches averaged per split when estimating the loss
device = 'cuda' if torch.cuda.is_available() else 'cpu' # Use GPU if it's available.
TORCH_SEED = 1337
# Seed the global torch RNG so runs are repeatable.
torch.manual_seed(TORCH_SEED)
# Load training data, downloading the corpus on first use
data_path = 'data/sales_textbook.txt'
if not os.path.exists(data_path):
    url = 'https://huggingface.co/datasets/goendalf666/sales-textbook_for_convincing_and_selling/raw/main/sales_textbook.txt'
    # Bound the request and fail loudly on HTTP errors so an error page is
    # never cached on disk as if it were the corpus.
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    # The target directory may not exist yet on a fresh checkout.
    os.makedirs(os.path.dirname(data_path), exist_ok=True)
    # Write with the same encoding that is used for reading below.
    with open(data_path, 'w', encoding='utf-8') as f:
        f.write(response.text)
with open(data_path, 'r', encoding='utf-8') as f:
    text = f.read()

# Tokenize the source text with tiktoken's cl100k_base encoding
encoding = tiktoken.get_encoding("cl100k_base")
tokenized_text = encoding.encode(text)
# One more than the largest token id present in the corpus; the model uses it as its vocabulary size
max_token_value = max(tokenized_text) + 1
tokenized_text = torch.tensor(tokenized_text, dtype=torch.long, device=device)  # put tokenized text into tensor

# Split train and validation: first 90% of the tokens for training, the rest for validation
split_idx = int(len(tokenized_text) * 0.9)
train_data = tokenized_text[:split_idx]
val_data = tokenized_text[split_idx:]
# Define Feed Forward Network
class FeedForward(nn.Module):
    """Two-layer MLP applied to the last dimension of its input.

    Expands from ``d_model`` to ``4 * d_model`` features, applies ReLU,
    projects back to ``d_model`` and finishes with dropout. Sizes and the
    dropout rate come from the module-level hyperparameters.
    """

    def __init__(self):
        super().__init__()
        self.d_model = d_model
        self.dropout = dropout
        hidden_width = 4 * self.d_model
        # Layers are created in execution order so seeded initialisation stays reproducible.
        layers = [
            nn.Linear(self.d_model, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, self.d_model),
            nn.Dropout(dropout),
        ]
        self.ffn = nn.Sequential(*layers)

    def forward(self, x):
        """Return the MLP output for ``x``; the last dimension must be ``d_model``."""
        transformed = self.ffn(x)
        return transformed
# Define Scaled Dot Product Attention
class Attention(nn.Module):
    """A single causal self-attention head.

    Args:
        head_size: Output width of the key, query and value projections.
    """

    def __init__(self, head_size: int):
        super().__init__()
        self.d_model = d_model
        self.head_size = head_size
        self.context_length = context_length
        self.dropout = dropout

        def make_projection():
            # Bias-free linear map from the model width to this head's width.
            return nn.Linear(self.d_model, self.head_size, bias=False)

        # Creation order (key, query, value) matters for seeded initialisation.
        self.key_layer = make_projection()
        self.query_layer = make_projection()
        self.value_layer = make_projection()

        # Lower-triangular matrix of ones: position i may attend to positions <= i.
        causal_mask = torch.tril(torch.ones((self.context_length, self.context_length)))
        self.register_buffer('tril', causal_mask)
        self.dropout_layer = nn.Dropout(self.dropout)

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, d_model) and return (B, T, head_size)."""
        _, T, C = x.shape  # batch, time steps (current context length), channels
        assert T <= self.context_length
        assert C == self.d_model

        q = self.query_layer(x)
        k = self.key_layer(x)
        v = self.value_layer(x)

        # Scaled dot product: Q @ K^T / sqrt(d_k)
        scale = 1.0 / math.sqrt(k.size(-1))
        scores = torch.matmul(q, k.transpose(-2, -1)) * scale

        # Hide future positions so softmax gives them zero weight.
        is_future = self.tril[:T, :T] == 0
        scores = scores.masked_fill(is_future, float('-inf'))

        attention = self.dropout_layer(F.softmax(scores, dim=-1))
        return attention @ v
class MultiHeadAttention(nn.Module):
    """Run several ``Attention`` heads side by side and mix their outputs.

    The number of heads is the module-level ``num_heads``. The output
    projection takes ``d_model`` input features, so ``num_heads * head_size``
    is expected to equal ``d_model``.

    Args:
        head_size: Width of each individual attention head.
    """

    def __init__(self, head_size: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.d_model = d_model
        self.context_length = context_length
        self.dropout = dropout

        self.heads = nn.ModuleList()
        for _ in range(self.num_heads):
            self.heads.append(Attention(head_size=self.head_size))
        self.projection_layer = nn.Linear(self.d_model, self.d_model)
        self.dropout_layer = nn.Dropout(dropout)

    def forward(self, x):
        """Concatenate every head's output on the last axis, then project and apply dropout."""
        per_head = [head(x) for head in self.heads]
        merged = torch.cat(per_head, dim=-1)
        return self.dropout_layer(self.projection_layer(merged))
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each with a residual connection.

    Args:
        num_heads: Used to derive the per-head width as ``d_model // num_heads``.
    """

    def __init__(self, num_heads: int):
        super().__init__()
        self.d_model = d_model
        self.context_length = context_length
        self.head_size = d_model // num_heads  # d_model should be divisible by num_heads
        self.num_heads = num_heads
        self.dropout = dropout

        self.multi_head_attention_layer = MultiHeadAttention(head_size=self.head_size)
        self.feed_forward_layer = FeedForward()
        self.layer_norm_1 = nn.LayerNorm(self.d_model)
        self.layer_norm_2 = nn.LayerNorm(self.d_model)

    def forward(self, x):
        """Apply both sub-layers to ``x`` and return a tensor of the same shape."""
        # Note: unlike the original Transformer paper, LayerNorm is applied
        # BEFORE each sub-layer:
        # LayerNorm -> Multi-head attention -> LayerNorm -> Feed forward
        attended = self.multi_head_attention_layer(self.layer_norm_1(x))
        x = x + attended  # Residual connection
        transformed = self.feed_forward_layer(self.layer_norm_2(x))
        return x + transformed  # Residual connection
class TransformerLanguageModel(nn.Module):
    """Decoder-only transformer language model over token ids.

    Token embeddings plus fixed sinusoidal position encodings are passed
    through ``num_blocks`` transformer blocks and a final LayerNorm, then
    projected to ``max_token_value`` logits per position. All sizes come
    from the module-level hyperparameters.
    """

    def __init__(self):
        super().__init__()
        self.d_model = d_model
        self.context_length = context_length
        self.num_heads = num_heads
        self.num_blocks = num_blocks
        self.dropout = dropout
        self.max_token_value = max_token_value
        # Set up token embedding look-up table.
        # NOTE: the table has one more row than the output layer has classes;
        # the size is kept as-is so existing checkpoints still load.
        self.token_embedding_lookup_table = nn.Embedding(num_embeddings=self.max_token_value + 1, embedding_dim=self.d_model)
        # Run all the transformer blocks
        # Different from original paper, here we add a final layer norm after all the blocks
        self.transformer_blocks = nn.Sequential(*(
            [TransformerBlock(num_heads=self.num_heads) for _ in range(self.num_blocks)] +
            [nn.LayerNorm(self.d_model)]
        ))
        self.language_model_out_linear_layer = nn.Linear(in_features=self.d_model, out_features=self.max_token_value)
        # The position encodings are constant, so build them once instead of
        # on every forward pass. As a buffer the table follows the module
        # across ``.to(...)`` calls; ``persistent=False`` keeps it out of the
        # state dict so the checkpoint format is unchanged.
        self.register_buffer(
            'position_encoding_lookup_table',
            self._build_position_encoding(self.context_length, self.d_model),
            persistent=False,
        )

    @staticmethod
    def _build_position_encoding(length, width):
        """Return the (length, width) sine/cosine table from the original Transformer paper."""
        table = torch.zeros(length, width)
        position = torch.arange(0, length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, width, 2).float() * (-math.log(10000.0) / width))
        table[:, 0::2] = torch.sin(position * div_term)  # even channels
        table[:, 1::2] = torch.cos(position * div_term)  # odd channels
        return table

    def forward(self, idx, targets=None):
        """Compute logits, and the loss when targets are given.

        Args:
            idx: (B, T) tensor of token ids with T <= context_length.
            targets: Optional (B, T) tensor of next-token ids.

        Returns:
            ``(logits, loss)``. ``logits`` has shape (B, T, max_token_value);
            ``loss`` is the mean cross-entropy over all B * T positions, or
            ``None`` when ``targets`` is ``None``.
        """
        _, T = idx.shape
        # Slice the precomputed table from (context_length, d_model) to (T, d_model)
        position_embedding = self.position_encoding_lookup_table[:T, :]
        x = self.token_embedding_lookup_table(idx) + position_embedding
        x = self.transformer_blocks(x)
        # The "logits" are the output values of our model before applying softmax
        logits = self.language_model_out_linear_layer(x)

        if targets is None:
            return logits, None

        # cross_entropy expects (N, C) inputs and (N,) targets, so flatten batch and time
        B, T, C = logits.shape
        logits_reshaped = logits.view(B * T, C)
        targets_reshaped = targets.view(B * T)
        loss = F.cross_entropy(input=logits_reshaped, target=targets_reshaped)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        """Sample ``max_new_tokens`` tokens and append them to ``idx``.

        Args:
            idx: (B, T) tensor of token ids forming the prompt.
            max_new_tokens: Number of tokens to sample per sequence.

        Returns:
            (B, T + max_new_tokens) tensor of token ids.
        """
        # Sampling needs no gradients; skip building an autograd graph per step.
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Crop idx to the max size of our positional embeddings table
                idx_crop = idx[:, -self.context_length:]
                # Get predictions
                logits, _ = self(idx_crop)
                # Get the last time step from logits where the dimensions of the logits are (B,T,C)
                logits_last_timestep = logits[:, -1, :]
                # Apply softmax to get probabilities
                probs = F.softmax(input=logits_last_timestep, dim=-1)
                # Sample from the probabilities' distribution.
                idx_next = torch.multinomial(input=probs, num_samples=1)
                # Append the sampled indexes idx_next to idx
                idx = torch.cat((idx, idx_next), dim=1)
        return idx
# Initialize the model and place its parameters on the selected device
model = TransformerLanguageModel().to(device)
# Sample a random batch of input/target token sequences
def get_batch(split: str):
    """Return ``(x, y)``, each of shape (batch_size, context_length).

    ``y`` is ``x`` shifted one token to the right, so ``y[:, t]`` is the
    token that follows ``x[:, t]``. ``'train'`` samples from the training
    data; any other value samples from the validation data.
    """
    if split == 'train':
        source = train_data
    else:
        source = val_data

    # Random start offsets; the upper bound leaves room for the shifted target window.
    starts = torch.randint(low=0, high=len(source) - context_length, size=(batch_size,))

    inputs = []
    targets = []
    for start in starts:
        stop = start + context_length
        inputs.append(source[start:stop])
        targets.append(source[start + 1:stop + 1])

    x = torch.stack(inputs).to(device)
    y = torch.stack(targets).to(device)
    return x, y
# Estimate the average loss on each split
@torch.no_grad()
def estimate_loss():
    """Return ``{'train': mean_loss, 'valid': mean_loss}`` as 0-dim tensors.

    Each mean is taken over ``eval_iters`` freshly sampled batches. The
    model is put in eval mode while measuring and back in train mode
    before returning.
    """
    model.eval()
    averages = {}
    for split in ('train', 'valid'):
        samples = []
        for _ in range(eval_iters):
            x_batch, y_batch = get_batch(split)
            _, batch_loss = model(x_batch, y_batch)
            samples.append(batch_loss.item())
        averages[split] = torch.tensor(samples).mean()
    model.train()
    return averages
# Use AdamW optimizer
optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)
tracked_losses = []  # one {'train': ..., 'valid': ...} entry per evaluation
for step in range(max_iters):
    # Evaluate every `eval_interval` steps and on the final step.
    # (`eval_iters` is the number of batches averaged per evaluation, not the cadence.)
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        tracked_losses.append(losses)
        print('Step:', step, 'Training Loss:', round(losses['train'].item(), 3), 'Validation Loss:',
              round(losses['valid'].item(), 3))

    # One optimisation step on a fresh training batch
    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
# Save the model state dictionary
torch.save(model.state_dict(), 'model-ckpt.pt')

# Generate a sample continuation from a fixed prompt
model.eval()
start = 'The salesperson'
start_ids = encoding.encode(start)
# Add a leading batch dimension: (T,) -> (1, T)
x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
# Inference only: disable autograd so no graph is kept for the 100 forward passes.
with torch.no_grad():
    y = model.generate(x, max_new_tokens=100)
print('---------------')
print(encoding.decode(y[0].tolist()))
print('---------------')