Skip to content

Commit

Permalink
Merge pull request #2544 from Zth9730/fix_attention
Browse files Browse the repository at this point in the history
[s2t] fix attention eval bug: do not compose the k/v projection weights during inference
  • Loading branch information
zh794390558 authored Oct 18, 2022
2 parents 899236b + 1ea828c commit eac545e
Showing 1 changed file with 2 additions and 18 deletions.
20 changes: 2 additions & 18 deletions paddlespeech/s2t/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

import paddle
from paddle import nn
from paddle.nn import functional as F
from paddle.nn import initializer as I

from paddlespeech.s2t.modules.align import Linear
Expand Down Expand Up @@ -56,16 +55,6 @@ def __init__(self, n_head: int, n_feat: int, dropout_rate: float):
self.linear_out = Linear(n_feat, n_feat)
self.dropout = nn.Dropout(p=dropout_rate)

def _build_once(self, *args, **kwargs):
super()._build_once(*args, **kwargs)
# if self.self_att:
# self.linear_kv = Linear(self.n_feat, self.n_feat*2)
if not self.training:
self.weight = paddle.concat(
[self.linear_k.weight, self.linear_v.weight], axis=-1)
self.bias = paddle.concat([self.linear_k.bias, self.linear_v.bias])
self._built = True

def forward_qkv(self,
query: paddle.Tensor,
key: paddle.Tensor,
Expand All @@ -87,13 +76,8 @@ def forward_qkv(self,
n_batch = query.shape[0]

q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
if self.training:
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
else:
k, v = F.linear(key, self.weight, self.bias).view(
n_batch, -1, 2 * self.h, self.d_k).split(
2, axis=2)
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)

q = q.transpose([0, 2, 1, 3]) # (batch, head, time1, d_k)
k = k.transpose([0, 2, 1, 3]) # (batch, head, time2, d_k)
Expand Down

0 comments on commit eac545e

Please sign in to comment.