@@ -11,6 +11,8 @@ def __init__(self, dim = 3, decode_dim = 16):
1111 self .weight_k = nn .Linear (dim , decode_dim , bias = False )
1212 self .weight_v = nn .Linear (dim , decode_dim , bias = False )
1313 self .weight_r = nn .Linear (decode_dim , decode_dim , bias = False )
14+ self .weight_alpha = nn .Parameter (torch .randn (decode_dim ))
15+ self .weight_beta = nn .Parameter (torch .randn (decode_dim ))
1416 self .scale_factor = decode_dim ** - 0.5
1517
1618 def forward (self , x , mask = None ):
@@ -20,14 +22,14 @@ def forward(self, x, mask = None):
2022 b , n , d = query .shape
2123
2224 # Calculate the global query
23- alpha_weight = torch .softmax (query * self .scale_factor , dim = - 1 )
25+ alpha_weight = torch .softmax (torch . mul ( query , self . weight_alpha ) * self .scale_factor , dim = - 1 )
2426 global_query = query * alpha_weight
2527 global_query = torch .einsum ('b n d -> b d' , global_query )
2628
2729 # Model the interaction between global query vector and the key vector
2830 repeat_global_query = einops .repeat (global_query , 'b d -> b copy d' , copy = n )
2931 p = repeat_global_query * key
30- beta_weight = torch .softmax (key * self .scale_factor , dim = - 1 )
32+ beta_weight = torch .softmax (torch . mul ( key , self . weight_beta ) * self .scale_factor , dim = - 1 )
3133 global_key = key * beta_weight
3234 global_key = torch .einsum ('b n d -> b d' , global_key )
3335
0 commit comments