import einops
import torch
import torch.nn as nn

class Fastformer(nn.Module):
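    """Single-head additive attention in the style of Fastformer
    (Wu et al., 2021, "Fastformer: Additive Attention Can Be All You Need").

    Note: this simplified variant takes the softmax over the feature
    dimension rather than over the sequence dimension used in the paper.
    """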
    def __init__(self, dim=3, decode_dim=16):
        super().__init__()
        # Projections for W_query, W_key and W_value
        self.weight_q = nn.Linear(dim, decode_dim, bias=False)
        self.weight_k = nn.Linear(dim, decode_dim, bias=False)
        self.weight_v = nn.Linear(dim, decode_dim, bias=False)
        self.weight_r = nn.Linear(decode_dim, decode_dim, bias=False)
        self.scale_factor = decode_dim ** -0.5

    def forward(self, x, mask=None):
        # x: (batch, seq_len, dim); mask is accepted for API compatibility but unused
        query = self.weight_q(x)
        key = self.weight_k(x)
        value = self.weight_v(x)
        b, n, d = query.shape

        # Calculate the global query: per-feature attention weights,
        # then a weighted sum over the sequence dimension
        alpha_weight = torch.softmax(query * self.scale_factor, dim=-1)
        global_query = query * alpha_weight
        global_query = torch.einsum('b n d -> b d', global_query)

        # Model the interaction between the global query vector and each key vector,
        # then pool the interaction vectors p into a global key
        repeat_global_query = einops.repeat(global_query, 'b d -> b copy d', copy=n)
        p = repeat_global_query * key
        beta_weight = torch.softmax(p * self.scale_factor, dim=-1)
        global_key = p * beta_weight
        global_key = torch.einsum('b n d -> b d', global_key)

        # Key-value interaction, then a linear transform plus a query residual
        key_value_interaction = torch.einsum('b j, b n j -> b n j', global_key, value)
        key_value_interaction_out = self.weight_r(key_value_interaction)
        result = key_value_interaction_out + query
        return result

if __name__ == '__main__':
    model = Fastformer(dim=3, decode_dim=8)
    x = torch.randn(4, 6, 3)
    result = model(x)
    print(result.size())
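    # Prints torch.Size([4, 6, 8]): (batch, seq_len, decode_dim)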