Skip to content

Commit ca97336

Browse files
Add base code for Fastformer
1 parent a54ab0e commit ca97336

File tree

1 file changed

+44
-0
lines changed

1 file changed

+44
-0
lines changed

Fastformer.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import einops
2+
import torch
3+
import torch.nn as nn
4+
5+
class Fastformer(nn.Module):
    """Fastformer additive-attention layer (Wu et al., 2021, arXiv:2108.09084).

    Projects the input to query/key/value, pools the queries into a single
    global query with additive attention, mixes it with the keys, and uses the
    resulting global key to gate the values before a final linear transform
    and a residual connection with the query.
    """

    def __init__(self, dim = 3, decode_dim = 16):
        super(Fastformer, self).__init__()
        # Generate weights for Wquery, Wkey and Wvalue.
        # NOTE(review): to_qkv is currently unused in forward(); kept so the
        # module's parameter/state_dict layout stays backward-compatible.
        self.to_qkv = nn.Linear(dim, decode_dim * 3, bias = False)
        self.weight_q = nn.Linear(dim, decode_dim, bias = False)
        self.weight_k = nn.Linear(dim, decode_dim, bias = False)
        self.weight_v = nn.Linear(dim, decode_dim, bias = False)
        self.weight_r = nn.Linear(decode_dim, decode_dim, bias = False)
        # 1/sqrt(decode_dim), the usual attention temperature.
        self.scale_factor = decode_dim ** -0.5

    def forward(self, x, mask = None):
        """Apply Fastformer attention.

        Args:
            x: input tensor of shape (batch, seq_len, dim).
            mask: accepted for API compatibility but not applied yet.

        Returns:
            Tensor of shape (batch, seq_len, decode_dim).
        """
        query = self.weight_q(x)
        key = self.weight_k(x)
        value = self.weight_v(x)
        b, n, d = query.shape

        # Calculate the global query.
        alpha_weight = torch.softmax(query * self.scale_factor, dim = -1)
        global_query = query * alpha_weight
        global_query = torch.einsum('b n d -> b d', global_query)

        # Model the interaction between the global query vector and each key
        # vector. Broadcasting over the sequence axis replaces the original
        # einops.repeat, dropping the third-party dependency for this op.
        p = global_query.unsqueeze(1) * key  # (b, n, d)

        # BUG FIX: the beta weights must be derived from the query-key
        # interaction `p` (previously `p` was computed but never used and the
        # weights came straight from `key`, ignoring the global query).
        beta_weight = torch.softmax(p * self.scale_factor, dim = -1)
        global_key = p * beta_weight
        global_key = torch.einsum('b n d -> b d', global_key)

        # Key-value interaction, linear transform, then query residual.
        key_value_interaction = torch.einsum('b j, b n j -> b n j', global_key, value)
        key_value_interaction_out = self.weight_r(key_value_interaction)
        result = key_value_interaction_out + query
        return result
39+
40+
if __name__ == '__main__':
    # Smoke test: push one random batch through the layer and show the
    # resulting tensor size.
    sample = torch.randn(4, 6, 3)
    net = Fastformer(dim = 3, decode_dim = 8)
    print(net(sample).size())

0 commit comments

Comments
 (0)