Skip to content

Commit

Permalink
Fix
Browse files: browse the repository at this point in the history
  • Loading branch information
w4123 committed Aug 27, 2022
1 parent 79c3fb2 commit 035f125
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,14 @@ def forward(self, x, c, attn_mask=None):

def attention(self, query, key, value, mask=None):
# reshape [b, d, t] -> [b, n_h, t, d_k]
- b, d, t_s, t_t = (*key.size(), query.size(2))
+ b = torch.tensor(key.size(0))
+ d = torch.tensor(key.size(1))
+ t_s = torch.tensor(key.size(2))
+ t_t = torch.tensor(query.size(2))
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
if self.window_size is not None:
assert t_s == t_t, "Relative attention is only available for self-attention."
Expand Down

0 comments on commit 035f125

Please sign in to comment.