diff --git a/g_mlp_gpt/g_mlp_gpt.py b/g_mlp_gpt/g_mlp_gpt.py
index d984fd1..3813719 100644
--- a/g_mlp_gpt/g_mlp_gpt.py
+++ b/g_mlp_gpt/g_mlp_gpt.py
@@ -5,7 +5,7 @@
 import torch.nn.functional as F
 from torch import nn, einsum
 
-from einops import rearrange
+from einops import rearrange, repeat
 
 from g_mlp_gpt.reversible import ReversibleSequence, SequentialSequence
 
@@ -99,6 +99,43 @@ def forward(self, x):
         out = einsum('b i j, b j d -> b i d', attn, v)
         return self.to_out(out)
 
+class LocalAttention(nn.Module):
+    def __init__(self, dim_in, dim_inner, dim_out, window = 128):
+        super().__init__()
+        self.scale = dim_inner ** -0.5
+        self.window = window
+
+        self.to_qkv = nn.Linear(dim_in, dim_inner * 3, bias = False)
+        self.to_out = nn.Linear(dim_inner, dim_out)
+
+    def forward(self, x):
+        b, n, *_, device, w = *x.shape, x.device, self.window
+
+        x = pad_to_multiple(x, w, dim = -2, value = 0.)
+        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
+
+        window_fn = lambda t: rearrange(t, 'b (w n) d -> b w n d', n = w)
+        q, k, v = map(window_fn, (q, k, v))
+
+        k, v = map(lambda t: F.pad(t, (0, 0, 0, 0, 1, 0)), (k, v))
+        k, v = map(lambda t: torch.cat((t[:, :-1], t[:, 1:]), dim = 2), (k, v))
+
+        sim = einsum('b w i d, b w j d -> b w i j', q, k) * self.scale
+        buckets, i, j = sim.shape[-3:]
+
+        mask_value = -torch.finfo(sim.dtype).max
+        mask = torch.ones(i, j, device = device).triu_(j - i + 1).bool()
+        mask = repeat(mask, 'i j -> () u i j', u = buckets)
+
+        sim.masked_fill_(mask, mask_value)
+
+        attn = sim.softmax(dim = -1)
+
+        out = einsum('b w i j, b w j d -> b w i d', attn, v)
+        out = rearrange(out, 'b w n d -> b (w n) d')
+        out = self.to_out(out[:, :n])
+        return out
+
 class CausalSGU(nn.Module):
     def __init__(
         self,
@@ -172,7 +209,7 @@ def __init__(
         self.act = act
         self.register_buffer('mask', ~torch.ones(window, window * 2).triu_(window + 1).bool())
 
-    def forward(self, x, **kwargs):
+    def forward(self, x, gate_res = None):
         device, n, h, w = x.device, x.shape[1], self.heads, self.window
 
         res, gate = x.chunk(2, dim = -1)
@@ -198,6 +235,9 @@ def forward(self, x):
         gate = rearrange(gate, 'b w n d -> b (w n) d')
         gate = gate[:, :n]
 
+        if exists(gate_res):
+            gate = gate + gate_res
+
         return self.act(gate) * res
 
 class AxiallyFold(nn.Module):
@@ -237,15 +277,18 @@ def __init__(
         act = nn.Identity()
     ):
         super().__init__()
-        SGU = partial(CausalLocalSGU, window = window) if exists(window) and window < seq_len else CausalSGU
+        is_windowed = exists(window) and window < seq_len
+
+        SGU_klass = partial(CausalLocalSGU, window = window) if is_windowed else CausalSGU
+        Attention_klass = partial(LocalAttention, window = window) if is_windowed else Attention
 
-        self.attn = Attention(dim_in = dim, dim_inner = attn_dim, dim_out = dim_ff // 2) if exists(attn_dim) else None
+        self.attn = Attention_klass(dim_in = dim, dim_inner = attn_dim, dim_out = dim_ff // 2) if exists(attn_dim) else None
 
         self.proj_in = nn.Sequential(
            nn.Linear(dim, dim_ff),
            nn.GELU()
        )
 
-        self.sgu = SGU(dim_ff, seq_len, causal, heads = heads, act = act)
+        self.sgu = SGU_klass(dim_ff, seq_len, causal, heads = heads, act = act)
         self.proj_out = nn.Linear(dim_ff // 2, dim)
 
     def forward(self, x):
diff --git a/setup.py b/setup.py
index f95b719..ac5bdb9 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'g-mlp-gpt',
   packages = find_packages(),
-  version = '0.0.14',
+  version = '0.0.15',
   license='MIT',
   description = 'gMLP - GPT',
   author = 'Phil Wang',
diff --git a/train.py b/train.py
index be6ec24..7d77bfd 100644
--- a/train.py
+++ b/train.py
@@ -41,7 +41,8 @@ def decode_tokens(tokens):
     dim = 512,
     seq_len = SEQ_LEN,
     depth = 8,
-    window = (16, 32, 64, 128, 256, 512, 768, SEQ_LEN)
+    window = (16, 32, 64, 128, 256, 512, 768, SEQ_LEN),
+    attn_dim = 16
 )
 
 model = AutoregressiveWrapper(model)
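Usage sketch: with attn_dim set, every gMLP block gains a tiny attention branch whose output is added to the spatial gate (gate_res), and blocks whose window is smaller than seq_len now route that branch through the new LocalAttention instead of full attention. The snippet below mirrors the updated train.py configuration; num_tokens = 256 and seq_len = 1024 are illustrative assumptions, not values taken from this diff.

import torch
from g_mlp_gpt import gMLPGPT

# mirrors the train.py change in this diff; seq_len and num_tokens are assumed here
model = gMLPGPT(
    num_tokens = 256,
    dim = 512,
    depth = 8,
    seq_len = 1024,
    window = (16, 32, 64, 128, 256, 512, 768, 1024),  # one local window per layer
    attn_dim = 16                                      # enables the tiny (local) attention branch
)

x = torch.randint(0, 256, (1, 1024))
logits = model(x)  # (1, 1024, 256)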