diff --git a/nn/attention.py b/nn/attention.py
index f3fb557b..1e323701 100644
--- a/nn/attention.py
+++ b/nn/attention.py
@@ -183,7 +183,7 @@ def __call__(self, source: nn.Tensor, *, axis: nn.Dim, **_kwargs) -> nn.Tensor:
     pos_emb, pos_emb_spatial_dim = relative_positional_encoding(axis, self.in_dim)
     pos_emb = self.linear_pos(pos_emb)
     pos_emb = nn.split_dims(pos_emb, axis=self.key_dim_total, dims=(self.num_heads, self.key_dim_per_head))
-    # pos_emb: (batch, head, 2*time1-1, d_k)
+    # pos_emb: (head, 2*time1-1, d_k)
 
     q, k, v = self.forward_qkv(source)
     hist_dim = nn.SpatialDim(f"{axis.description}:kv")
@@ -201,7 +201,7 @@ def __call__(self, source: nn.Tensor, *, axis: nn.Dim, **_kwargs) -> nn.Tensor:
     # compute matrix b and matrix d
     # (batch, head, time1, 2*time1-1)
     matrix_bd = nn.dot(q_with_bias_v, pos_emb, reduce=self.key_dim_per_head)
-    matrix_bd = self.rel_shift(matrix_bd)  # TODO
+    matrix_bd = self._rel_shift(matrix_bd, axis, pos_emb_spatial_dim, hist_dim)
 
     scores = matrix_ac + matrix_bd  # (batch, head, time1, time2)
     scores *= self.key_dim_per_head.dimension ** -0.5
@@ -214,6 +214,26 @@ def __call__(self, source: nn.Tensor, *, axis: nn.Dim, **_kwargs) -> nn.Tensor:
     output = self.proj(output)
     return output
 
+  @classmethod
+  def _rel_shift(cls, x: nn.Tensor, axis: nn.Dim, pos_emb_spatial_dim: nn.Dim, hist_dim: nn.Dim) -> nn.Tensor:
+    """
+    :param x: [B,H,T,T*2-1]
+    :param axis: T
+    :param pos_emb_spatial_dim: T*2-1
+    :param hist_dim: T' (equal to T but separate dim)
+    :return: [B,H,T,T']
+    """
+    batch_dims = x.batch_dims_ordered((axis, pos_emb_spatial_dim))
+    x_padded = nn.pad(x, axes=pos_emb_spatial_dim, padding=(1, 0), value=0.)  # [B,H,T,T*2]
+    pos_emb_spatial_dim_ = 1 + pos_emb_spatial_dim
+
+    x_padded = nn.reshape(x_padded, (axis, pos_emb_spatial_dim_), (pos_emb_spatial_dim_, axis))  # [B,H,T*2,T]
+    x_padded, pos_emb_spatial_dim_ = nn.slice(x_padded, axis=pos_emb_spatial_dim_, slice_start=1)  # [B,H,T*2-1,T]
+    x_padded = nn.reshape(x_padded, (pos_emb_spatial_dim_, axis), (axis, pos_emb_spatial_dim_))  # [B,H,T,T*2-1]
+    x_padded, _ = nn.slice_nd(x_padded, axis=pos_emb_spatial_dim_, size=hist_dim)  # [B,H,T,T']
+    x_padded.verify_out_shape(set(batch_dims) | {axis, hist_dim})
+    return x_padded
+
 
 _relative_positional_encoding_cache = weakref.WeakKeyDictionary()  # root name ctx -> (spatial_dim, feat_dim) -> enc
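
For context: matrix_ac and matrix_bd follow the Transformer-XL score decomposition, where matrix_ac covers the content terms (q + u)^T k and matrix_bd the position terms (q + v)^T r, and _rel_shift is the usual pad/reshape/slice relative shift that turns per-relative-position scores into per-key-position scores. Below is a minimal NumPy sketch of the same trick, checked against a direct gather. It is an illustration, not the returnn_common implementation: it assumes a single [T, 2*T-1] score matrix (batch and head dims dropped), hist_dim equal to T (no extra history), and that index j on the last axis encodes relative position j - (T - 1), matching what relative_positional_encoding appears to produce.

import numpy as np


def rel_shift_ref(x: np.ndarray) -> np.ndarray:
    # Direct gather: out[t, s] is the score of query t with key s,
    # i.e. relative position (s - t), stored at index (s - t) + (T - 1).
    t_dim = x.shape[0]
    out = np.empty((t_dim, t_dim), dtype=x.dtype)
    for t in range(t_dim):
        for s in range(t_dim):
            out[t, s] = x[t, (s - t) + (t_dim - 1)]
    return out


def rel_shift_trick(x: np.ndarray) -> np.ndarray:
    # Same pad/reshape/slice sequence as _rel_shift in the diff above.
    t_dim, rel_dim = x.shape          # rel_dim == 2*t_dim - 1
    x = np.pad(x, ((0, 0), (1, 0)))   # [T, 2*T], zero column in front
    x = x.reshape(2 * t_dim, t_dim)   # [2*T, T], reinterpret flat layout
    x = x[1:]                         # [2*T-1, T], drop the first row
    x = x.reshape(t_dim, rel_dim)     # [T, 2*T-1], each row now shifted
    return x[:, :t_dim]               # [T, T'], keep the key positions


x = np.random.default_rng(0).normal(size=(5, 2 * 5 - 1))
assert np.allclose(rel_shift_ref(x), rel_shift_trick(x))

The point of the reshape trick is that the per-row shift happens purely by reinterpreting the flat memory layout, so no explicit [T, T] gather index has to be materialized.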