
Commit 6c791b1
RelPosSelfAttention rel shift
albertz committed Oct 18, 2022
1 parent 5f935b2 commit 6c791b1
Showing 1 changed file with 22 additions and 2 deletions.
nn/attention.py: 22 additions & 2 deletions
@@ -183,7 +183,7 @@ def __call__(self, source: nn.Tensor, *, axis: nn.Dim, **_kwargs) -> nn.Tensor:
         pos_emb, pos_emb_spatial_dim = relative_positional_encoding(axis, self.in_dim)
         pos_emb = self.linear_pos(pos_emb)
         pos_emb = nn.split_dims(pos_emb, axis=self.key_dim_total, dims=(self.num_heads, self.key_dim_per_head))
-        # pos_emb: (batch, head, 2*time1-1, d_k)
+        # pos_emb: (head, 2*time1-1, d_k)
 
         q, k, v = self.forward_qkv(source)
         hist_dim = nn.SpatialDim(f"{axis.description}:kv")
@@ -201,7 +201,7 @@ def __call__(self, source: nn.Tensor, *, axis: nn.Dim, **_kwargs) -> nn.Tensor:
         # compute matrix b and matrix d
         # (batch, head, time1, 2*time1-1)
         matrix_bd = nn.dot(q_with_bias_v, pos_emb, reduce=self.key_dim_per_head)
-        matrix_bd = self.rel_shift(matrix_bd)  # TODO
+        matrix_bd = self._rel_shift(matrix_bd, axis, pos_emb_spatial_dim, hist_dim)
 
         scores = matrix_ac + matrix_bd  # (batch, head, time1, time2)
         scores *= self.key_dim_per_head.dimension ** -0.5
@@ -214,6 +214,26 @@ def __call__(self, source: nn.Tensor, *, axis: nn.Dim, **_kwargs) -> nn.Tensor:
         output = self.proj(output)
         return output
 
+    @classmethod
+    def _rel_shift(cls, x: nn.Tensor, axis: nn.Dim, pos_emb_spatial_dim: nn.Dim, hist_dim: nn.Dim) -> nn.Tensor:
+        """
+        :param x: [B,H,T,T*2-1]
+        :param axis: T
+        :param pos_emb_spatial_dim: T*2-1
+        :param hist_dim: T' (equal to T but separate dim)
+        :return: [B,H,T,T']
+        """
+        batch_dims = x.batch_dims_ordered((axis, pos_emb_spatial_dim))
+        x_padded = nn.pad(x, axes=pos_emb_spatial_dim, padding=(1, 0), value=0.)  # [B,H,T,T*2]
+        pos_emb_spatial_dim_ = 1 + pos_emb_spatial_dim
+
+        x_padded = nn.reshape(x_padded, (axis, pos_emb_spatial_dim_), (pos_emb_spatial_dim_, axis))  # [B,H,T*2,T]
+        x_padded, pos_emb_spatial_dim_ = nn.slice(x_padded, axis=pos_emb_spatial_dim_, slice_start=1)  # [B,H,T*2-1,T]
+        x_padded = nn.reshape(x_padded, (pos_emb_spatial_dim_, axis), (axis, pos_emb_spatial_dim_))  # [B,H,T,T*2-1]
+        x_padded, _ = nn.slice_nd(x_padded, axis=pos_emb_spatial_dim_, size=hist_dim)  # [B,H,T,T']
+        x_padded.verify_out_shape(set(batch_dims) | {axis, hist_dim})
+        return x_padded
+
+
 _relative_positional_encoding_cache = weakref.WeakKeyDictionary()  # root name ctx -> (spatial_dim, feat_dim) -> enc

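For intuition, here is a minimal NumPy sketch of the same pad/reshape/slice trick that _rel_shift performs, with the batch and head dims dropped and T' = T as in the self-attention case above (T and the values are illustrative, not from the commit):

    import numpy as np

    T = 4  # illustrative sequence length
    # x[i, j] holds the score of query i against relative offset j - (T - 1),
    # i.e. the columns cover offsets -(T-1) .. T-1; shape [T, 2*T-1].
    x = np.arange(T * (2 * T - 1), dtype=float).reshape(T, 2 * T - 1)

    x_padded = np.pad(x, ((0, 0), (1, 0)))         # [T, 2*T]: one zero column on the left
    x_padded = x_padded.reshape(2 * T, T)          # reinterpret the flat buffer as [2*T, T]
    x_padded = x_padded[1:].reshape(T, 2 * T - 1)  # drop the first row, back to [T, 2*T-1]
    out = x_padded[:, :T]                          # keep the first T columns -> [T, T']

    # Row i is now shifted left by i: out[i, k] is the score of query i
    # against key k, i.e. relative offset k - i.
    for i in range(T):
        for k in range(T):
            assert out[i, k] == x[i, k - i + T - 1]

The pad-and-reshape turns the per-row shift into two cheap view operations; _rel_shift does the same thing with named dims, where pos_emb_spatial_dim is the 2*T-1 axis and hist_dim (T') takes the role of the key axis.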
