Skip to content

Commit c575b80

Browse files
Fix the error about the global weight
1 parent 863bdd1 commit c575b80

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

Fastformer.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,8 @@ def __init__(self, dim = 3, decode_dim = 16):
1111
self.weight_k = nn.Linear(dim, decode_dim, bias = False)
1212
self.weight_v = nn.Linear(dim, decode_dim, bias = False)
1313
self.weight_r = nn.Linear(decode_dim, decode_dim, bias = False)
14+
self.weight_alpha = nn.Parameter(torch.randn(decode_dim))
15+
self.weight_beta = nn.Parameter(torch.randn(decode_dim))
1416
self.scale_factor = decode_dim ** -0.5
1517

1618
def forward(self, x, mask = None):
@@ -20,14 +22,14 @@ def forward(self, x, mask = None):
2022
b, n, d = query.shape
2123

2224
# Caculate the global query
23-
alpha_weight = torch.softmax(query * self.scale_factor, dim = -1)
25+
alpha_weight = torch.softmax(torch.mul(query, self.weight_alpha) * self.scale_factor, dim = -1)
2426
global_query = query * alpha_weight
2527
global_query = torch.einsum('b n d -> b d', global_query)
2628

2729
# Model the interaction between global query vector and the key vector
2830
repeat_global_query = einops.repeat(global_query, 'b d -> b copy d', copy = n)
2931
p = repeat_global_query * key
30-
beta_weight = torch.softmax(key * self.scale_factor, dim = -1)
32+
beta_weight = torch.softmax(torch.mul(key, self.weight_beta) * self.scale_factor, dim = -1)
3133
global_key = key * beta_weight
3234
global_key = torch.einsum('b n d -> b d', global_key)
3335

0 commit comments

Comments
 (0)