Commit efee809

give tiny attention to Local SGU too, appropriately also sliding window local attention

lucidrains committed May 25, 2021
1 parent da23bcd commit efee809
Showing 3 changed files with 51 additions and 7 deletions.
53 changes: 48 additions & 5 deletions g_mlp_gpt/g_mlp_gpt.py
@@ -5,7 +5,7 @@
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange
from einops import rearrange, repeat

from g_mlp_gpt.reversible import ReversibleSequence, SequentialSequence

@@ -99,6 +99,43 @@ def forward(self, x):
out = einsum('b i j, b j d -> b i d', attn, v)
return self.to_out(out)

class LocalAttention(nn.Module):
def __init__(self, dim_in, dim_inner, dim_out, window = 128):
super().__init__()
self.scale = dim_inner ** -0.5
self.window = window

self.to_qkv = nn.Linear(dim_in, dim_inner * 3, bias = False)
self.to_out = nn.Linear(dim_inner, dim_out)

def forward(self, x):
b, n, *_, device, w = *x.shape, x.device, self.window

x = pad_to_multiple(x, w, dim = -2, value = 0.)
q, k, v = self.to_qkv(x).chunk(3, dim = -1)

window_fn = lambda t: rearrange(t, 'b (w n) d -> b w n d', n = w)
q, k, v = map(window_fn, (q, k, v))

k, v = map(lambda t: F.pad(t, (0, 0, 0, 0, 1, 0)), (k, v))
k, v = map(lambda t: torch.cat((t[:, :-1], t[:, 1:]), dim = 2), (k, v))

sim = einsum('b w i d, b w j d -> b w i j', q, k) * self.scale
buckets, i, j = sim.shape[-3:]

mask_value = -torch.finfo(sim.dtype).max
mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
mask = repeat(mask, 'i j -> () u i j', u = buckets)

sim.masked_fill_(mask, mask_value)

attn = sim.softmax(dim = -1)

out = einsum('b w i j, b w j d -> b w i d', attn, v)
out = rearrange(out, 'b w n d -> b (w n) d')
out = self.to_out(out[:, :n])
return out

class CausalSGU(nn.Module):
def __init__(
self,
@@ -172,7 +209,7 @@ def __init__(
self.act = act
self.register_buffer('mask', ~torch.ones(window, window * 2).triu_(window + 1).bool())

def forward(self, x, **kwargs):
def forward(self, x, gate_res = None):
device, n, h, w = x.device, x.shape[1], self.heads, self.window

res, gate = x.chunk(2, dim = -1)
@@ -198,6 +235,9 @@ def forward(self, x):
gate = rearrange(gate, 'b w n d -> b (w n) d')
gate = gate[:, :n]

if exists(gate_res):
gate = gate + gate_res

return self.act(gate) * res

class AxiallyFold(nn.Module):
@@ -237,15 +277,18 @@ def __init__(
act = nn.Identity()
):
super().__init__()
SGU = partial(CausalLocalSGU, window = window) if exists(window) and window < seq_len else CausalSGU
is_windowed = exists(window) and window < seq_len

SGU_klass = partial(CausalLocalSGU, window = window) if is_windowed else CausalSGU
Attention_klass = partial(LocalAttention, window = window) if is_windowed else Attention

self.attn = Attention(dim_in = dim, dim_inner = attn_dim, dim_out = dim_ff // 2) if exists(attn_dim) else None
self.attn = Attention_klass(dim_in = dim, dim_inner = attn_dim, dim_out = dim_ff // 2) if exists(attn_dim) else None

self.proj_in = nn.Sequential(
nn.Linear(dim, dim_ff),
nn.GELU()
)
self.sgu = SGU(dim_ff, seq_len, causal, heads = heads, act = act)
self.sgu = SGU_klass(dim_ff, seq_len, causal, heads = heads, act = act)
self.proj_out = nn.Linear(dim_ff // 2, dim)

def forward(self, x):
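For context on the hunks above: both the new LocalAttention and the existing CausalLocalSGU rely on the same sliding-window trick, where the sequence is bucketed into windows of size w, each window's keys and values are extended with the previous window, and a (w, 2w) causal mask stops queries from attending to later positions. The standalone sketch below reproduces that pattern; the function name and tensor shapes are illustrative only and not part of the repository.

import torch
import torch.nn.functional as F
from einops import rearrange

def sliding_window_causal_attn(q, k, v, window):
    # q, k, v: (batch, seq, dim); seq is assumed to be a multiple of `window`
    w, scale = window, q.shape[-1] ** -0.5

    # bucket the sequence into non-overlapping windows of size w
    q, k, v = map(lambda t: rearrange(t, 'b (x w) d -> b x w d', w = w), (q, k, v))

    # prepend one zero window, then pair every window with its predecessor,
    # so each query window sees 2w keys / values (previous window + its own)
    k, v = map(lambda t: F.pad(t, (0, 0, 0, 0, 1, 0)), (k, v))
    k, v = map(lambda t: torch.cat((t[:, :-1], t[:, 1:]), dim = 2), (k, v))

    sim = torch.einsum('b x i d, b x j d -> b x i j', q, k) * scale

    # causal mask over the (w, 2w) span: query position i may attend to the
    # whole previous window and to current-window positions up to i
    i, j = sim.shape[-2:]
    mask = torch.ones(i, j, device = q.device).triu_(j - i + 1).bool()
    sim = sim.masked_fill(mask, -torch.finfo(sim.dtype).max)

    attn = sim.softmax(dim = -1)
    out = torch.einsum('b x i j, b x j d -> b x i d', attn, v)
    return rearrange(out, 'b x w d -> b (x w) d')

# illustrative shapes: sequence length 256, window of 128
q = k = v = torch.randn(1, 256, 64)
out = sliding_window_causal_attn(q, k, v, window = 128)   # -> (1, 256, 64)

In the gMLP block itself, the tiny attention is projected to dim_ff // 2, which matches the gate half of the SGU split; its output is presumably what gets passed into CausalSGU / CausalLocalSGU as gate_res (the block's forward is not shown in this hunk) and is added to the gate just before the act(gate) * res product.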
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'g-mlp-gpt',
packages = find_packages(),
version = '0.0.14',
version = '0.0.15',
license='MIT',
description = 'gMLP - GPT',
author = 'Phil Wang',
3 changes: 2 additions & 1 deletion train.py
@@ -41,7 +41,8 @@ def decode_tokens(tokens):
dim = 512,
seq_len = SEQ_LEN,
depth = 8,
window = (16, 32, 64, 128, 256, 512, 768, SEQ_LEN)
window = (16, 32, 64, 128, 256, 512, 768, SEQ_LEN),
attn_dim = 16
)

model = AutoregressiveWrapper(model)
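The train.py change above simply switches the per-layer tiny attention on via attn_dim = 16. As a hypothetical end-to-end sketch of how the updated model is driven, assuming the gMLPGPT class, the AutoregressiveWrapper import path, and the num_tokens / SEQ_LEN values (none of which appear in this hunk):

import torch
from g_mlp_gpt import gMLPGPT
from g_mlp_gpt.autoregressive_wrapper import AutoregressiveWrapper  # assumed path

SEQ_LEN = 1024  # illustrative; train.py defines its own SEQ_LEN constant

model = gMLPGPT(
    num_tokens = 256,   # byte-level vocabulary, assumed
    dim = 512,
    seq_len = SEQ_LEN,
    depth = 8,
    window = (16, 32, 64, 128, 256, 512, 768, SEQ_LEN),  # per-layer local windows
    attn_dim = 16       # enables the tiny (now also windowed) attention
)

model = AutoregressiveWrapper(model)

tokens = torch.randint(0, 256, (1, SEQ_LEN + 1))
loss = model(tokens)    # the wrapper is assumed to return the language-modeling loss
loss.backward()

Per the is_windowed check in the block constructor, layers whose window is smaller than seq_len get LocalAttention and CausalLocalSGU, while a layer whose window equals seq_len falls back to full Attention and CausalSGU.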
