-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmegabyte.py
More file actions
468 lines (349 loc) · 15.1 KB
/
megabyte.py
File metadata and controls
468 lines (349 loc) · 15.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
import math
import functools
from itertools import zip_longest
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, reduce, repeat, pack, unpack
from einops.layers.torch import Rearrange
from beartype import beartype
from beartype.typing import Tuple, Union
from attend import Attend
from tqdm import tqdm
# helpers
def exists(val):
    """Return True when *val* is anything other than None (0, '' and [] count as existing)."""
    return not (val is None)
def default(val, d):
    """Return *val* unless it is None, in which case fall back to *d*."""
    if exists(val):
        return val
    return d
def pack_one(t, pattern):
    """einops.pack a single tensor; returns the (packed_tensor, packed_shapes) pair."""
    packed, packed_shapes = pack([t], pattern)
    return packed, packed_shapes
def unpack_one(t, ps, pattern):
    """Inverse of `pack_one`: einops.unpack and return the single restored tensor."""
    unpacked = unpack(t, ps, pattern)
    return unpacked[0]
def remainder_to_mult(num, mult):
    """Return how much must be added to the integer *num* to reach the next multiple of *mult*.

    Returns 0 when *num* is already a multiple.
    """
    return (-num) % mult
def cast_tuple(t, length = 1):
    """Return *t* unchanged if it is a tuple, otherwise a tuple of *length* copies of it."""
    if isinstance(t, tuple):
        return t
    return (t,) * length
def reduce_mult(nums):
    """Return the product of all items in *nums* (1 for an empty iterable)."""
    product = 1
    for num in nums:
        product = product * num
    return product
# tensor helpers
def log(t, eps = 1e-20):
    """Natural log of *t* with inputs clamped to at least *eps*, so zeros give a finite value."""
    return t.clamp(min = eps).log()
def gumbel_noise(t):
    """Sample standard Gumbel noise with the same shape, dtype and device as *t*."""
    uniform = torch.empty_like(t)
    uniform.uniform_(0, 1)
    return -log(-log(uniform))
def gumbel_sample(t, temperature = 1., dim = -1):
    """Draw a categorical sample from logits *t* via the Gumbel-max trick.

    Returns the argmax index along *dim* of temperature-scaled logits plus Gumbel noise.
    """
    scaled = t / temperature
    perturbed = scaled + gumbel_noise(t)
    return perturbed.argmax(dim = dim)
def top_k(logits, thres = 0.5):
    """Keep only the largest ``(1 - thres)`` fraction of logits along the last axis.

    At least one logit is always kept. Every other position is set to ``-inf`` so
    the result can be passed directly to a softmax / Gumbel sampler.

    Works for logits of any rank: the filtering is applied independently to each
    vector along the last dimension.

    Args:
        logits: float tensor whose last axis indexes the vocabulary.
        thres: fraction of logits to discard (0.9 keeps the top 10%).

    Returns:
        A tensor shaped like ``logits`` with non-top-k entries replaced by ``-inf``.
    """
    num_logits = logits.shape[-1]
    k = max(int((1 - thres) * num_logits), 1)
    val, ind = torch.topk(logits, k, dim = -1)
    probs = torch.full_like(logits, float('-inf'))
    # scatter along the same axis topk selected from; a hard-coded dim 1 only
    # matched for 2-D (batch, vocab) input
    probs.scatter_(-1, ind, val)
    return probs
# token shift, from Peng et al of RWKV
def token_shift(t):
    """Shift the second half of the feature channels one step forward along the sequence axis.

    The first half of the last dimension is kept as-is. The second half is delayed
    by one position along the second-to-last axis: position 0 becomes zeros and the
    last position's values are dropped.
    """
    keep, shift = t.chunk(2, dim = -1)
    delayed = F.pad(shift, (0, 0, 1, -1))
    return torch.cat((keep, delayed), dim = -1)
# rotary positional embedding
class RotaryEmbedding(nn.Module):
    """Produces the rotary position-embedding angle table.

    ``forward(seq_len)`` returns a ``(seq_len, dim)`` tensor whose row ``p`` holds
    ``p * inv_freq`` duplicated across both halves of the feature axis, ready for
    ``apply_rotary_pos_emb``.
    """

    def __init__(self, dim, theta = 10000):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (theta ** exponents))

    @property
    def device(self):
        # inv_freq is the module's only buffer, so it tracks .to()/.cuda() moves
        return self.inv_freq.device

    def forward(self, seq_len):
        positions = torch.arange(seq_len, device = self.device).type_as(self.inv_freq)
        angles = positions.unsqueeze(-1) * self.inv_freq
        return torch.cat((angles, angles), dim = -1)
def rotate_half(x):
    """Split the last axis into halves (a, b) and return (-b, a) concatenated."""
    first, second = x.chunk(2, dim = -1)
    return torch.cat((second.neg(), first), dim = -1)
def apply_rotary_pos_emb(pos, t):
    """Rotate features of *t* by the angles in *pos* (as produced by RotaryEmbedding)."""
    cos_part = t * pos.cos()
    sin_part = rotate_half(t) * pos.sin()
    return cos_part + sin_part
# norm
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned per-feature gain.

    The L2 norm times ``dim ** -0.5`` is the RMS of the features. It is clamped to at
    least ``eps`` before dividing, so an all-zero input stays zero.
    """

    def __init__(self, dim, eps = 1e-8):
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        rms = x.norm(dim = -1, keepdim = True) * self.scale
        safe_rms = rms.clamp(min = self.eps)
        return x / safe_rms * self.g
# helper classes
def FeedForward(*, dim, mult = 4, dropout = 0.):
    """Build a pre-norm MLP block: RMSNorm -> Linear(dim, dim*mult) -> GELU -> Dropout -> Linear(dim*mult, dim)."""
    hidden_dim = dim * mult
    layers = [
        RMSNorm(dim),
        nn.Linear(dim, hidden_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, dim),
    ]
    return nn.Sequential(*layers)
class Attention(nn.Module):
    """Pre-norm causal self-attention with many query heads and one shared key/value head.

    Queries are projected to ``heads * dim_head`` features and split into heads.
    ``to_kv`` produces a single ``dim_head``-wide key and value that every query head
    attends against. The attention computation itself is delegated to ``Attend``,
    constructed with ``causal = True`` and the given ``flash`` / ``dropout`` settings.
    """

    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        dropout = 0.,
        flash = False
    ):
        super().__init__()
        # NOTE(review): `scale` and `dropout` are never read in `forward`; scaling and
        # attention dropout presumably happen inside `Attend` -- confirm before removing.
        self.scale = dim_head ** -0.5
        self.heads = heads
        inner_dim = dim_head * heads

        self.attend = Attend(
            causal = True,
            flash = flash,
            dropout = dropout
        )

        self.dropout = nn.Dropout(dropout)

        self.norm = RMSNorm(dim)
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        # a single key head and a single value head, shared by all query heads
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x, rotary_emb = None):
        """Attend causally over the sequence axis of ``x``.

        Args:
            x: input features; assumes shape (batch, seq, dim) -- TODO confirm for
                callers other than ``Transformer``.
            rotary_emb: optional angle table from ``RotaryEmbedding``; when given it is
                applied to both queries and keys.

        Returns:
            Tensor of shape (batch, seq, dim).
        """
        x = self.norm(x)

        q = self.to_q(x)
        k, v = self.to_kv(x).chunk(2, dim = -1)

        # only the queries are split into heads; k and v stay single-headed
        q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)

        if exists(rotary_emb):
            q, k = map(lambda t: apply_rotary_pos_emb(rotary_emb, t), (q, k))

        out = self.attend(q, k, v)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
class Transformer(nn.Module):
    """A stack of pre-norm residual (attention, feed-forward) layers with token shift.

    Before each sub-layer, ``token_shift`` is applied to the residual stream. One
    rotary angle table (when ``rel_pos`` is True) is shared by every attention layer.
    The output is passed through a final RMSNorm.
    """

    def __init__(
        self,
        *,
        dim,
        layers,
        dim_head = 64,
        heads = 8,
        attn_dropout = 0.,
        ff_dropout = 0.,
        ff_mult = 4,
        rel_pos = True,
        flash_attn = False
    ):
        super().__init__()
        self.rotary_emb = RotaryEmbedding(dim_head) if rel_pos else None

        blocks = []
        for _ in range(layers):
            attn_layer = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = flash_attn)
            ff_layer = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
            blocks.append(nn.ModuleList([attn_layer, ff_layer]))
        self.layers = nn.ModuleList(blocks)

        self.norm = RMSNorm(dim)

    def forward(self, x):
        seq_len = x.shape[-2]

        rotary_emb = None
        if exists(self.rotary_emb):
            rotary_emb = self.rotary_emb(seq_len)

        for attn_layer, ff_layer in self.layers:
            attn_out = attn_layer(token_shift(x), rotary_emb = rotary_emb)
            x = attn_out + x
            ff_out = ff_layer(token_shift(x))
            x = ff_out + x

        return self.norm(x)
# main class
class MEGABYTE(nn.Module):
    """Hierarchical (multiscale) autoregressive transformer over integer token ids.

    The model is a stack of ``len(depth)`` transformer stages. The per-stage settings
    ``dim``, ``depth`` and ``max_seq_len`` are tuples ordered from the first (coarsest)
    stage to the last (finest) stage. Each stage except the last projects every one of
    its positions into ``max_seq_len[i + 1]`` vectors of width ``dim[i + 1]``, which
    are added to the next stage's inputs. The last stage produces per-token logits.
    """

    @beartype
    def __init__(
        self,
        *,
        num_tokens,
        dim: Union[Tuple, int],
        depth: Tuple,
        max_seq_len: Tuple,
        dim_head = 64,
        heads = 8,
        attn_dropout = 0.,
        ff_mult = 4,
        ff_dropout = 0.,
        pad_id = 0,
        rel_pos = False,
        pos_emb = False,
        flash_attn = False
    ):
        super().__init__()

        # simplified configuration for each stage of the hierarchy
        # depth = (2, 2, 4) would translate to depth 2 at first stage, depth 2 second stage, depth 4 third
        # max_seq_len = (16, 8, 4) would translate to max sequence length of 16 at first stage, length of 8 at second stage, length of 4 for last

        assert isinstance(depth, tuple) and isinstance(max_seq_len, tuple)
        assert len(depth) == len(max_seq_len)

        self.stages = len(depth)
        dim = cast_tuple(dim, self.stages)

        assert len(dim) == self.stages

        # this star-unpacking requires at least two stages (ValueError otherwise)
        coarsest_dim, *_, fine_dim = dim

        self.max_seq_len = max_seq_len

        # one learned start token per stage, ordered coarsest -> finest
        self.start_tokens = nn.ParameterList([nn.Parameter(torch.randn(h_dim)) for h_dim in dim])
        # absolute positional embeddings, one table per stage, ordered coarsest -> finest
        self.pos_embs = nn.ModuleList([nn.Embedding(seq_len, h_dim) for h_dim, seq_len in zip(dim, max_seq_len)]) if pos_emb else None

        # token embeddings ordered finest -> coarsest:
        # token_embs[0] embeds single tokens for the last stage, and token_embs[i]
        # (i >= 1) embeds a flattened patch of prod(max_seq_len[-i:]) tokens into
        # the width of stage -(i + 1)
        self.token_embs = nn.ModuleList([])

        patch_size = 1
        self.token_embs.append(nn.Embedding(num_tokens, fine_dim))

        for dim_out, seq_len in zip(reversed(dim[:-1]), reversed(max_seq_len[1:])):
            patch_size *= seq_len

            self.token_embs.append(nn.Sequential(
                nn.Embedding(num_tokens, fine_dim),
                Rearrange('... r d -> ... (r d)'),
                nn.LayerNorm(patch_size * fine_dim),
                nn.Linear(patch_size * fine_dim, dim_out),
                nn.LayerNorm(dim_out)
            ))

        self.transformers = nn.ModuleList([])
        self.to_next_transformer_projections = nn.ModuleList([])

        for h_dim, next_h_dim, stage_depth, next_seq_len in zip_longest(dim, dim[1:], depth, max_seq_len[1:]):
            self.transformers.append(Transformer(
                dim = h_dim,
                layers = stage_depth,
                dim_head = dim_head,
                heads = heads,
                attn_dropout = attn_dropout,
                ff_dropout = ff_dropout,
                ff_mult = ff_mult,
                rel_pos = rel_pos,
                flash_attn = flash_attn
            ))

            proj = nn.Identity()

            # every stage except the last must expand each of its positions into
            # `next_seq_len` vectors for the next stage. This is needed even when
            # next_h_dim == h_dim, so only the existence of a next stage matters.
            if exists(next_h_dim):
                proj = nn.Sequential(
                    Rearrange('b ... d -> b (...) d'),
                    nn.Linear(h_dim, next_h_dim * next_seq_len),
                    Rearrange('b m (n d) -> (b m) n d', n = next_seq_len)
                )

            self.to_next_transformer_projections.append(proj)

        self.to_logits = nn.Linear(fine_dim, num_tokens)
        self.pad_id = pad_id

        # report number of parameters
        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

        self.apply(self._init_weights)

    def get_num_params(self, non_embedding = True):
        """
        Return the number of parameters in the model.

        When ``non_embedding`` is True (default), the absolute positional embedding
        tables (``self.pos_embs``, present only when built with ``pos_emb = True``)
        are excluded from the count. Token embeddings are always counted.
        """
        n_params = sum(p.numel() for p in self.parameters())

        if non_embedding and exists(self.pos_embs):
            n_params -= sum(p.numel() for p in self.pos_embs.parameters())

        return n_params

    def _init_weights(self, module):
        """Initialise ``nn.Linear`` layers per the MEGABYTE paper, A.1 Training Details.

        Weights are drawn from N(0, 0.006) and clipped to [-0.012, 0.012] (two
        standard deviations); biases are zeroed. Other module types are untouched.
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean = 0.0, std = 0.006)
            # clip (rather than resample) values beyond two standard deviations
            with torch.no_grad():
                module.weight.clamp_(min = -0.012, max = 0.012)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    @torch.no_grad()
    def generate(self, prime = None, filter_thres = 0.9, temperature = 1., default_batch_size = 1):
        """Autoregressively sample tokens until the full ``prod(max_seq_len)`` length.

        Args:
            prime: optional ``(batch, n)`` long tensor of prompt ids; when None, an
                empty prompt of ``default_batch_size`` rows is used.
            filter_thres: fraction of logits discarded by ``top_k`` before sampling.
            temperature: Gumbel sampling temperature.

        Returns:
            Tensor of ids reshaped to ``(batch, *max_seq_len)``.

        Runs without autograd, since sampled ids carry no gradient.
        NOTE(review): the module's train/eval mode is not changed here, so dropout
        stays active if the model is in training mode.
        """
        total_seq_len = reduce_mult(self.max_seq_len)
        device = next(self.parameters()).device

        if not exists(prime):
            prime = torch.empty((default_batch_size, 0), dtype = torch.long, device = device)

        seq = prime
        batch = seq.shape[0]

        for _ in tqdm(range(total_seq_len - seq.shape[-1])):
            logits = self.forward(seq)[:, -1]
            logits = top_k(logits, thres = filter_thres)
            sampled = gumbel_sample(logits, dim = -1, temperature = temperature)
            seq = torch.cat((seq, rearrange(sampled, 'b -> b 1')), dim = -1)

        return seq.reshape(batch, *self.max_seq_len)

    def forward_empty(self, batch_size):
        """Return logits of shape ``(batch_size, 1, num_tokens)`` for an empty input.

        Every stage runs on just its learned start token, plus the previous stage's
        projected representation. This is the special case of sampling the very
        first token.
        """
        prev_stage_tokens_repr = None

        for stage_start_tokens, transformer, proj in zip(self.start_tokens, self.transformers, self.to_next_transformer_projections):
            tokens = repeat(stage_start_tokens, 'd -> b 1 d', b = batch_size)

            if exists(prev_stage_tokens_repr):
                tokens = tokens + prev_stage_tokens_repr[..., :tokens.shape[-2], :]

            tokens = transformer(tokens)
            prev_stage_tokens_repr = proj(tokens)

        return self.to_logits(tokens)

    def forward(self, ids, return_loss = False):
        """Run the full hierarchy over ``ids``.

        Args:
            ids: integer token ids, either flattened ``(batch, seq)`` (right-padded
                with ``pad_id`` to a multiple of ``prod(max_seq_len[1:])``) or already
                shaped ``(batch, *max_seq_len)``, where the first axis may be shorter
                than ``max_seq_len[0]``.
            return_loss: when True, return the next-token cross-entropy with
                ``pad_id`` positions ignored.

        Returns:
            For flattened input: logits ``(batch, seq, num_tokens)`` trimmed to the
            unpadded length. For hierarchical input: logits shaped like ``ids`` plus
            a trailing class axis. With ``return_loss``: a scalar loss. An empty
            ``ids`` always returns ``forward_empty`` logits.
        """
        batch = ids.shape[0]

        assert ids.ndim in {2, self.stages + 1}
        flattened_dims = ids.ndim == 2

        if ids.numel() == 0:
            return self.forward_empty(ids.shape[0])

        if flattened_dims:
            # allow for ids to be given in the shape of (batch, seq)
            # in which case it will be auto-padded to the next nearest multiple of depth seq len
            seq_len = ids.shape[-1]
            multiple_of = reduce_mult(self.max_seq_len[1:])
            padding = remainder_to_mult(seq_len, multiple_of)
            ids = F.pad(ids, (0, padding), value = self.pad_id)
            ids = ids.reshape(batch, -1, *self.max_seq_len[1:])

        b, *prec_dims, device = *ids.shape, ids.device

        # check some dimensions

        assert prec_dims[0] <= self.max_seq_len[0], 'the first dimension of your axial autoregressive transformer must be less than or equal to the first tuple element of max_seq_len (like any autoregressive transformer)'
        assert tuple(prec_dims[1:]) == tuple(self.max_seq_len[1:]), 'all subsequent dimensions must match exactly'

        # get tokens for all hierarchical stages, reducing by appropriate dimensions
        # and adding the absolute positional embeddings.
        # token_embs run finest -> coarsest while pos_embs are stored coarsest -> finest,
        # so pos_embs are reversed to pair each stage with its own table

        tokens_at_stages = []
        pos_embs = tuple(self.pos_embs)[::-1] if exists(self.pos_embs) else (None,)

        for ind, pos_emb, token_emb in zip_longest(range(len(prec_dims)), pos_embs, self.token_embs):
            is_first = ind == 0

            tokens = token_emb(ids)

            if exists(pos_emb):
                positions = pos_emb(torch.arange(tokens.shape[-2], device = device))
                tokens = tokens + positions

            # prepend, so tokens_at_stages ends up ordered coarsest -> finest
            tokens_at_stages.insert(0, tokens)

            if is_first:
                continue

            # merge the two innermost id axes so the next (coarser) embedding sees a bigger patch
            ids = rearrange(ids, '... m n -> ... (m n)')

        # the un-pixelshuffled representations of the previous hierarchy, starts with None

        prev_stage_tokens_repr = None

        # spatial tokens is tokens with depth pos reduced along depth dimension + spatial positions

        for stage_start_tokens, stage_tokens, transformer, proj in zip(self.start_tokens, tokens_at_stages, self.transformers, self.to_next_transformer_projections):
            stage_tokens, ps = pack_one(stage_tokens, '* n d')
            stage_start_tokens = repeat(stage_start_tokens, 'f -> b 1 f', b = stage_tokens.shape[0])

            # concat start token

            stage_tokens = torch.cat((
                stage_start_tokens,
                stage_tokens,
            ), dim = -2)

            # sum the previous hierarchy's representation

            if exists(prev_stage_tokens_repr):
                prev_stage_tokens_repr = F.pad(prev_stage_tokens_repr, (0, 0, 1, 0), value = 0.)
                stage_tokens = stage_tokens + prev_stage_tokens_repr

            attended = transformer(stage_tokens)

            attended = unpack_one(attended, ps, '* n d')

            # project for next stage in the hierarchy (the last position predicts nothing new)

            prev_stage_tokens_repr = proj(attended[..., :-1, :])

        # project to logits

        logits = self.to_logits(attended)

        start_tokens = logits[(slice(None), *((0,) * (logits.ndim - 2)), slice(None))]
        start_tokens = rearrange(start_tokens, 'b d -> b 1 d')

        logits = logits[..., 1:, :]

        if not return_loss:

            if flattened_dims:
                logits = rearrange(logits, 'b ... c -> b (...) c')
                logits = logits[:, :seq_len]

            return logits

        logits = rearrange(logits, 'b ... c -> b (...) c')
        logits = torch.cat((start_tokens, logits), dim = -2)

        preds = rearrange(logits, 'b n c -> b c n')
        labels = rearrange(ids, 'b ... -> b (...)')

        loss = F.cross_entropy(
            preds[..., :-1],
            labels,
            ignore_index = self.pad_id
        )

        return loss