From 035f1255d5fda72a2ae1f8329a78a86e8a29d49c Mon Sep 17 00:00:00 2001
From: w4123 <1840686745@qq.com>
Date: Sat, 27 Aug 2022 13:31:40 +0800
Subject: [PATCH] Fix

---
 attentions.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/attentions.py b/attentions.py
index 0eeea464..38599385 100644
--- a/attentions.py
+++ b/attentions.py
@@ -147,11 +147,14 @@ def forward(self, x, c, attn_mask=None):
 
   def attention(self, query, key, value, mask=None):
     # reshape [b, d, t] -> [b, n_h, t, d_k]
-    b, d, t_s, t_t = (*key.size(), query.size(2))
+    b = torch.tensor(key.size(0))
+    d = torch.tensor(key.size(1))
+    t_s = torch.tensor(key.size(2))
+    t_t = torch.tensor(query.size(2))
     query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
     key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
     value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
-
+    
     scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
     if self.window_size is not None:
       assert t_s == t_t, "Relative attention is only available for self-attention."