
Commit

allow the input to go under self attention as well with cheap linear attention
lucidrains committed Mar 22, 2021
1 parent 83fd55d commit 10a83b0
Showing 2 changed files with 44 additions and 3 deletions.
45 changes: 43 additions & 2 deletions perceiver_pytorch/experimental.py
@@ -9,6 +9,44 @@

from perceiver_pytorch.perceiver_pytorch import exists, default, cache_fn, fourier_encode, RMSNorm, PreNorm, FeedForward, Attention

# linear attention

class LinearAttention(nn.Module):
    def __init__(
        self,
        dim,
        *,
        heads = 4,
        dim_head = 64,
        dropout = 0.
    ):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, mask = None):
        h = self.heads
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))

        q *= self.scale
        q, k = q.softmax(dim = -1), k.softmax(dim = -2)

        if exists(mask):
            k.masked_fill_(mask, 0.)

        context = einsum('b n d, b n e -> b d e', q, k)
        out = einsum('b d e, b n d -> b n e', context, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h = h)
        return self.to_out(out)

# main class

class Perceiver(nn.Module):
@@ -49,14 +87,15 @@ def __init__(
        get_cross_attn = lambda: PreNorm(latent_dim, Attention(latent_dim, input_dim, heads = cross_heads, dim_head = cross_dim_head, dropout = attn_dropout), context_dim = input_dim)
        get_cross_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))

        get_input_attn = lambda: PreNorm(input_dim, LinearAttention(input_dim, dropout = attn_dropout))
        get_rev_cross_attn = lambda: PreNorm(input_dim, Attention(input_dim, latent_dim, heads = cross_heads, dim_head = cross_dim_head, dropout = attn_dropout), context_dim = latent_dim)
        get_rev_cross_ff = lambda: PreNorm(input_dim, FeedForward(input_dim, dropout = ff_dropout))

        get_latent_attn = lambda: PreNorm(latent_dim, Attention(latent_dim, heads = latent_heads, dim_head = latent_dim_head, dropout = attn_dropout))
        get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim, dropout = ff_dropout))

        if weight_tie_layers:
            get_cross_attn, get_cross_ff, get_rev_cross_attn, get_rev_cross_ff, get_latent_attn, get_latent_ff = map(cache_fn, (get_cross_attn, get_cross_ff, get_rev_cross_attn, get_rev_cross_ff, get_latent_attn, get_latent_ff))
            get_cross_attn, get_cross_ff, get_rev_cross_attn, get_rev_cross_ff, get_input_attn, get_latent_attn, get_latent_ff = map(cache_fn, (get_cross_attn, get_cross_ff, get_rev_cross_attn, get_rev_cross_ff, get_input_attn, get_latent_attn, get_latent_ff))

        self.layers = nn.ModuleList([])
        for _ in range(depth):
@@ -65,6 +104,7 @@ def __init__(
                get_cross_ff(),
                get_rev_cross_attn(),
                get_rev_cross_ff(),
                get_input_attn(),
                get_latent_attn(),
                get_latent_ff()
            ]))
@@ -96,13 +136,14 @@ def forward(self, data, mask = None):
        x = self.latents + self.pos_emb
        x = repeat(x, 'n d -> b n d', b = b)

        for i, (cross_attn, cross_ff, rev_cross_attn, rev_cross_ff, latent_attn, latent_ff) in enumerate(self.layers):
        for i, (cross_attn, cross_ff, rev_cross_attn, rev_cross_ff, input_attn, latent_attn, latent_ff) in enumerate(self.layers):
            is_last = i == (len(self.layers) - 1)

            x = cross_attn(x, context = data, mask = mask) + x
            x = cross_ff(x) + x

            if not is_last:
                data = input_attn(data, mask = mask) + data
                data = rev_cross_attn(data, context = x) + data
                data = rev_cross_ff(data) + data
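For context, the module above normalizes q along the feature dimension and k along the sequence dimension, contracts the two into a small per-head d x d context matrix, and applies that matrix to v, so cost and memory grow linearly with sequence length. A minimal usage sketch (hypothetical, not part of the commit; it assumes torch and einops are installed and that LinearAttention is importable from perceiver_pytorch.experimental as added above):

import torch
from perceiver_pytorch.experimental import LinearAttention

# hypothetical sizes, chosen only for illustration
attn = LinearAttention(dim = 256, heads = 4, dim_head = 64)

x = torch.randn(1, 16384, 256)   # (batch, sequence length, dim); the long axis only enters linearly
out = attn(x)                    # (1, 16384, 256)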
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
  name = 'perceiver-pytorch',
  packages = find_packages(),
  version = '0.1.9',
  version = '0.1.10',
  license='MIT',
  description = 'Perceiver - Pytorch',
  author = 'Phil Wang',
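Taken together, the change means that on every layer except the last, the raw input now attends over itself with the new LinearAttention before the latents are written back into it through the reverse cross attention. The reason this is affordable on the full-resolution input is that linear attention never materializes an n x n score matrix; a back-of-the-envelope sketch with hypothetical sizes (not taken from the commit):

# hypothetical sizes: a 224 x 224 image flattened to n tokens, per-head dim d
n, d, heads = 224 * 224, 64, 4

full_attn_entries = heads * n * n      # ~1.0e10 attention scores per layer for regular self attention
linear_attn_entries = heads * d * d    # 16384 context entries per layer for the linear variant

print(full_attn_entries / linear_attn_entries)   # ~6.1e5x fewer entries to materialize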
