allow for customizable gating activation in SGU, fix bug with padding
lucidrains committed May 23, 2021
1 parent bd9b68a commit 18efacd
Showing 3 changed files with 16 additions and 10 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -38,7 +38,8 @@ model = gMLPGPT(
     num_tokens = 20000,
     dim = 512,
     seq_len = 16384,
-    reversible = True,
+    reversible = True,  # reversible networks
+    act = nn.Tanh(),    # tanh activation for spatial gating
     depth = 12,
     window = (
         128,
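The new act keyword is applied to the gate branch of the spatial gating unit before it multiplies the residual branch (see the return self.act(gate) * res change below), so the default nn.Identity() reproduces the previous behavior, while nn.Tanh() bounds the gate to (-1, 1). A minimal illustration of the gating formula, with assumed shapes:

import torch
from torch import nn

gate = torch.randn(2, 8, 256)          # gate half of the hidden dimension
res  = torch.randn(2, 8, 256)          # residual half

out_old = nn.Identity()(gate) * res    # previous behavior (the default)
out_new = nn.Tanh()(gate) * res        # bounded gating, as in the README example above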
21 changes: 13 additions & 8 deletions g_mlp_gpt/g_mlp_gpt.py
@@ -86,7 +86,8 @@ def __init__(
         dim,
         dim_seq,
         init_eps = 1e-3,
-        heads = 4
+        heads = 4,
+        act = nn.Identity()
     ):
         super().__init__()
         dim_out = dim // 2
@@ -101,6 +102,7 @@ def __init__(
         nn.init.uniform_(self.weight, -init_eps, init_eps)
         nn.init.constant_(self.bias, 1.)

+        self.act = act
         self.register_buffer('mask', ~torch.ones(dim_seq, dim_seq).triu_(1).bool())

     def forward(self, x):
@@ -119,7 +121,7 @@ def forward(self, x):
         gate = gate + rearrange(bias, 'h n -> () h n ()')
         gate = rearrange(gate, 'b h n d -> b n (h d)')

-        return gate * res
+        return self.act(gate) * res

 class CausalLocalSGU(nn.Module):
     def __init__(
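Pieced together from the hunks above, the gating path now applies the activation to the gate branch only. A simplified, single-head sketch of the unit follows (the repository's CausalSGU is multi-headed and shapes its weight and bias per head, so this illustrates the mechanism rather than the exact implementation):

import torch
from torch import nn
from einops import rearrange

class SimpleCausalSGU(nn.Module):
    # single-head sketch of a spatial gating unit with a pluggable activation
    def __init__(self, dim, dim_seq, init_eps = 1e-3, act = nn.Identity()):
        super().__init__()
        dim_out = dim // 2
        self.norm = nn.LayerNorm(dim_out)
        # near-zero init keeps the gate close to its bias of 1 early in training
        self.weight = nn.Parameter(torch.empty(dim_seq, dim_seq).uniform_(-init_eps, init_eps))
        self.bias = nn.Parameter(torch.ones(dim_seq))
        self.act = act
        self.register_buffer('mask', ~torch.ones(dim_seq, dim_seq).triu_(1).bool())

    def forward(self, x):
        res, gate = x.chunk(2, dim = -1)              # split hidden in half
        gate = self.norm(gate)
        w = self.weight.masked_fill(~self.mask, 0.)   # causal: no mixing from future tokens
        gate = torch.einsum('mn,bnd->bmd', w, gate)
        gate = gate + rearrange(self.bias, 'n -> () n ()')
        return self.act(gate) * res                   # the new, customizable activation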
@@ -128,7 +130,8 @@ def __init__(
         dim_seq,
         init_eps = 1e-3,
         heads = 4,
-        window = 128
+        window = 128,
+        act = nn.Identity()
     ):
         super().__init__()
         dim_out = dim // 2
@@ -185,7 +188,7 @@ def forward(self, x):
             return self.fn(x)

         n = x.shape[1]
-        x = pad_to_multiple(x, self.every, dim = 1)
+        x = pad_to_multiple(x, self.every, dim = -2)
         x = rearrange(x, 'b (n e) d -> (b e) n d', e = every)
         x = self.fn(x)

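The dim = -2 change is the padding fix named in the commit message. The helper itself is outside this diff, but assuming an implementation along these lines, the offset arithmetic only works for negative dim indices, because F.pad consumes (left, right) pairs starting from the last dimension:

import math
import torch
import torch.nn.functional as F

def pad_to_multiple(tensor, multiple, dim = -1, value = 0.):
    # assumed reconstruction of the helper, which is not shown in this diff
    seq_len = tensor.shape[dim]
    remainder = math.ceil(seq_len / multiple) * multiple - seq_len
    if remainder == 0:
        return tensor
    # (0,) * (-1 - dim) pairs skip the dimensions after `dim`:
    # dim = -2 gives (0, 0) and pads the sequence axis of a (b, n, d) tensor,
    # while dim = 1 gives an empty offset and F.pad silently pads the feature axis
    pad_offset = (0,) * (-1 - dim) * 2
    return F.pad(tensor, (*pad_offset, 0, remainder), value = value)

Under that assumption, the old dim = 1 call zero-padded the feature dimension instead of the sequence, so inputs were never actually padded up to a multiple of the fold length before the axial rearrange.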
@@ -202,14 +205,15 @@ def gMLPBlock(
    dim_ff,
    heads = 4,
    causal = False,
-    window = None
+    window = None,
+    act = nn.Identity()
):
    SGU = partial(CausalLocalSGU, window = window) if exists(window) and window < seq_len else CausalSGU

    return nn.Sequential(
        nn.Linear(dim, dim_ff),
        nn.GELU(),
-        SGU(dim_ff, seq_len, causal, heads = heads),
+        SGU(dim_ff, seq_len, causal, heads = heads, act = act),
        nn.Linear(dim_ff // 2, dim)
    )

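Composed exactly as in gMLPBlock, but using the SimpleCausalSGU sketch from above so it runs standalone (dimensions are arbitrary; note the SGU halves dim_ff, hence the dim_ff // 2 projection back down):

block = nn.Sequential(
    nn.Linear(256, 1024),                                    # dim -> dim_ff
    nn.GELU(),
    SimpleCausalSGU(1024, dim_seq = 64, act = nn.Tanh()),    # 1024 -> 512
    nn.Linear(512, 256)                                      # dim_ff // 2 -> dim
)
out = block(torch.randn(1, 64, 256))                         # (1, 64, 256)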
@@ -227,7 +231,8 @@ def __init__(
        ff_mult = 4,
        prob_survival = 1.,
        reversible = False,
-        window = None
+        window = None,
+        act = nn.Identity()
    ):
        super().__init__()
        dim_ff = dim * ff_mult
@@ -243,7 +248,7 @@ def __init__(
        layers = nn.ModuleList([])

        for ind, (w, ax) in zip(range(depth), window):
-            get_gmlp = lambda: PreNorm(dim, AxiallyFold(dim, ax, gMLPBlock(dim = dim, dim_ff = dim_ff, seq_len = seq_len, heads = heads, window = w)))
+            get_gmlp = lambda: PreNorm(dim, AxiallyFold(dim, ax, gMLPBlock(dim = dim, dim_ff = dim_ff, seq_len = seq_len, heads = heads, window = w, act = act)))

            layer_blocks = nn.ModuleList([
                get_gmlp()
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'g-mlp-gpt',
   packages = find_packages(),
-  version = '0.0.9',
+  version = '0.0.11',
   license='MIT',
   description = 'gMLP - GPT',
   author = 'Phil Wang',
