Implement rotary transformer
yzhangcs committed May 29, 2023
1 parent 308a7bb commit 6f51b79
Showing 1 changed file with 162 additions and 32 deletions.
194 changes: 162 additions & 32 deletions supar/modules/transformer.py
@@ -183,7 +183,7 @@ def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
return x


class RelativePositionTransformerEncoderLayer(nn.Module):
class RelativePositionTransformerEncoderLayer(TransformerEncoderLayer):

def __init__(
self,
@@ -212,16 +212,35 @@ def __init__(

self.pre_norm = pre_norm

def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
if self.pre_norm:
n = self.attn_norm(x)
x = x + self.dropout(self.attn(n, n, n, mask))
n = self.ffn_norm(x)
x = x + self.dropout(self.ffn(n))
else:
x = self.attn_norm(x + self.dropout(self.attn(x, x, x, mask)))
x = self.ffn_norm(x + self.dropout(self.ffn(x)))
return x

class RotaryPositionTransformerEncoderLayer(TransformerEncoderLayer):

def __init__(
self,
n_heads: int = 8,
n_model: int = 1024,
n_inner: int = 2048,
activation: str = 'relu',
pre_norm: bool = False,
attn_dropout: float = 0.1,
ffn_dropout: float = 0.1,
dropout: float = 0.1
) -> RotaryPositionTransformerEncoderLayer:
super(RotaryPositionTransformerEncoderLayer, self).__init__()

self.attn = RotaryPositionMultiHeadAttention(n_heads=n_heads,
n_model=n_model,
n_embed=n_model//n_heads,
dropout=attn_dropout)
self.attn_norm = nn.LayerNorm(n_model)
self.ffn = PositionwiseFeedForward(n_model=n_model,
n_inner=n_inner,
activation=activation,
dropout=ffn_dropout)
self.ffn_norm = nn.LayerNorm(n_model)
self.dropout = nn.Dropout(dropout)

self.pre_norm = pre_norm
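
A minimal usage sketch for this new encoder layer (not part of the commit). It assumes the module can be imported as supar.modules.transformer, that the base TransformerEncoderLayer is constructible with its defaults, and that its inherited forward(x, mask) matches the signature in the hunk header above; input shapes follow the attention code further down.

# Hypothetical sketch: construct the rotary encoder layer and run a forward pass.
# The import path and the inherited forward(x, mask) are assumptions, not shown in this diff.
import torch
from supar.modules.transformer import RotaryPositionTransformerEncoderLayer

layer = RotaryPositionTransformerEncoderLayer(n_heads=8, n_model=1024)
x = torch.randn(20, 2, 1024)                 # [seq_len, batch_size, n_model]
mask = torch.ones(2, 20, dtype=torch.bool)   # [batch_size, seq_len], True = valid token
y = layer(x, mask)                           # -> [seq_len, batch_size, n_model]
print(y.shape)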


class TransformerDecoderLayer(nn.Module):
@@ -283,7 +302,7 @@ def forward(
return x_tgt


class RelativePositionTransformerDecoderLayer(nn.Module):
class RelativePositionTransformerDecoderLayer(TransformerDecoderLayer):

def __init__(
self,
@@ -317,26 +336,40 @@ def __init__(

self.pre_norm = pre_norm

def forward(
self,
x_tgt: torch.Tensor,
x_src: torch.Tensor,
tgt_mask: torch.BoolTensor,
src_mask: torch.BoolTensor,
attn_mask: Optional[torch.BoolTensor] = None
) -> torch.Tensor:
if self.pre_norm:
n_tgt = self.self_attn_norm(x_tgt)
x_tgt = x_tgt + self.dropout(self.self_attn(n_tgt, n_tgt, n_tgt, tgt_mask, attn_mask))
n_tgt = self.mha_attn_norm(x_tgt)
x_tgt = x_tgt + self.dropout(self.mha_attn(n_tgt, x_src, x_src, src_mask))
n_tgt = self.ffn_norm(x_tgt)
x_tgt = x_tgt + self.dropout(self.ffn(x_tgt))
else:
x_tgt = self.self_attn_norm(x_tgt + self.dropout(self.self_attn(x_tgt, x_tgt, x_tgt, tgt_mask, attn_mask)))
x_tgt = self.mha_attn_norm(x_tgt + self.dropout(self.mha_attn(x_tgt, x_src, x_src, src_mask)))
x_tgt = self.ffn_norm(x_tgt + self.dropout(self.ffn(x_tgt)))
return x_tgt


class RotaryPositionTransformerDecoderLayer(TransformerDecoderLayer):

def __init__(
self,
n_heads: int = 8,
n_model: int = 1024,
n_inner: int = 2048,
activation: str = 'relu',
pre_norm: bool = False,
attn_dropout: float = 0.1,
ffn_dropout: float = 0.1,
dropout: float = 0.1
) -> RotaryPositionTransformerDecoderLayer:
super(RotaryPositionTransformerDecoderLayer, self).__init__()

self.self_attn = RotaryPositionMultiHeadAttention(n_heads=n_heads,
n_model=n_model,
n_embed=n_model//n_heads,
dropout=attn_dropout)
self.self_attn_norm = nn.LayerNorm(n_model)
self.mha_attn = RotaryPositionMultiHeadAttention(n_heads=n_heads,
n_model=n_model,
n_embed=n_model//n_heads,
dropout=attn_dropout)
self.mha_attn_norm = nn.LayerNorm(n_model)
self.ffn = PositionwiseFeedForward(n_model=n_model,
n_inner=n_inner,
activation=activation,
dropout=ffn_dropout)
self.ffn_norm = nn.LayerNorm(n_model)
self.dropout = nn.Dropout(dropout)

self.pre_norm = pre_norm

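A similar sketch for the decoder layer (not part of the commit), assuming the forward inherited from TransformerDecoderLayer has the same signature as the RelativePosition forward removed above: (x_tgt, x_src, tgt_mask, src_mask, attn_mask).

# Hypothetical sketch; import path and inherited forward signature are assumptions.
import torch
from supar.modules.transformer import RotaryPositionTransformerDecoderLayer

layer = RotaryPositionTransformerDecoderLayer(n_heads=8, n_model=1024)
x_tgt = torch.randn(5, 2, 1024)                     # [tgt_len, batch_size, n_model]
x_src = torch.randn(7, 2, 1024)                     # [src_len, batch_size, n_model]
tgt_mask = torch.ones(2, 5, dtype=torch.bool)       # [batch_size, tgt_len]
src_mask = torch.ones(2, 7, dtype=torch.bool)       # [batch_size, src_len]
attn_mask = torch.ones(5, 5).tril().bool()          # causal mask for target self-attention
y = layer(x_tgt, x_src, tgt_mask, src_mask, attn_mask)  # -> [tgt_len, batch_size, n_model]
print(y.shape)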

class MultiHeadAttention(nn.Module):
@@ -386,7 +419,6 @@ def forward(
batch_size, _ = mask.shape
# [seq_len, batch_size * n_heads, n_embed]
q = self.wq(q).view(-1, batch_size * self.n_heads, self.n_embed)
# [src_len, batch_size * n_heads, n_embed]
k = self.wk(k).view(-1, batch_size * self.n_heads, self.n_embed)
v = self.wv(v).view(-1, batch_size * self.n_heads, self.n_embed)

@@ -478,6 +510,72 @@ def forward(
return (x, attn.view(batch_size, self.n_heads, *attn.shape[1:])) if self.attn else x


class RotaryPositionMultiHeadAttention(nn.Module):

def __init__(
self,
n_heads: int = 8,
n_model: int = 1024,
n_embed: int = 128,
dropout: float = 0.1,
bias: bool = True,
attn: bool = False
) -> RotaryPositionMultiHeadAttention:
super(RotaryPositionMultiHeadAttention, self).__init__()

self.n_heads = n_heads
self.n_model = n_model
self.n_embed = n_embed
self.scale = n_embed**0.5

self.pos_embed = RotaryPositionalEmbedding(n_model=n_embed)
self.wq = nn.Linear(n_model, n_heads * n_embed, bias=bias)
self.wk = nn.Linear(n_model, n_heads * n_embed, bias=bias)
self.wv = nn.Linear(n_model, n_heads * n_embed, bias=bias)
self.wo = nn.Linear(n_heads * n_embed, n_model, bias=bias)
self.dropout = nn.Dropout(dropout)

self.attn = attn

self.reset_parameters()

def reset_parameters(self):
# borrowed from https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/multihead_attention.py
nn.init.xavier_uniform_(self.wq.weight, 2 ** -0.5)
nn.init.xavier_uniform_(self.wk.weight, 2 ** -0.5)
nn.init.xavier_uniform_(self.wv.weight, 2 ** -0.5)
nn.init.xavier_uniform_(self.wo.weight)

def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
mask: torch.BoolTensor,
attn_mask: Optional[torch.BoolTensor] = None
) -> torch.Tensor:
batch_size, _ = mask.shape
# [seq_len, batch_size * n_heads, n_embed]
q = self.pos_embed(self.wq(q).view(-1, batch_size * self.n_heads, self.n_embed))
k = self.pos_embed(self.wk(k).view(-1, batch_size * self.n_heads, self.n_embed))
v = self.wv(v).view(-1, batch_size * self.n_heads, self.n_embed)

mask = mask.unsqueeze(1).repeat(1, self.n_heads, 1).view(-1, 1, *mask.shape[1:])
# [batch_size * n_heads, seq_len, src_len]
if attn_mask is not None:
mask = mask & attn_mask
# [batch_size * n_heads, seq_len, src_len]
attn = torch.bmm(q.transpose(0, 1) / self.scale, k.movedim((0, 1), (2, 0)))
attn = torch.softmax(attn + torch.where(mask, 0., float('-inf')), -1)
attn = self.dropout(attn)
# [seq_len, batch_size * n_heads, n_embed]
x = torch.bmm(attn, v.transpose(0, 1)).transpose(0, 1)
# [seq_len, batch_size, n_model]
x = self.wo(x.reshape(-1, batch_size, self.n_heads * self.n_embed))

return (x, attn.view(batch_size, self.n_heads, *attn.shape[1:])) if self.attn else x
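
A quick shape check of the new attention module on its own (not part of the commit). As the code above shows, only q and k are rotated while v is left untouched, and attn=True additionally returns the attention weights; the import path is an assumption.

# Hypothetical smoke test; the import path is an assumption.
import torch
from supar.modules.transformer import RotaryPositionMultiHeadAttention

mha = RotaryPositionMultiHeadAttention(n_heads=8, n_model=512, n_embed=64, attn=True)
x = torch.randn(10, 2, 512)                  # [seq_len, batch_size, n_model]
mask = torch.ones(2, 10, dtype=torch.bool)   # [batch_size, seq_len]
out, weights = mha(x, x, x, mask)
print(out.shape)       # torch.Size([10, 2, 512])
print(weights.shape)   # torch.Size([2, 8, 10, 10])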


class PositionwiseFeedForward(nn.Module):

def __init__(
@@ -583,3 +681,35 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
pos = pos / 10000 ** (x.new_tensor(range(n_model)).div(2, rounding_mode='floor') * 2 / n_model)
pos[..., 0::2], pos[..., 1::2] = pos[..., 0::2].sin(), pos[..., 1::2].cos()
return pos


class RotaryPositionalEmbedding(nn.Module):

def __init__(
self,
n_model: int = 1024,
max_len: int = 1024
) -> RotaryPositionalEmbedding:
super().__init__()

self.embed = nn.Embedding(max_len, n_model)

self.reset_parameters()

@torch.no_grad()
def reset_parameters(self):
w = self.embed.weight
max_len, n_model = w.shape
pos = w.new_tensor(range(max_len)).unsqueeze(-1)
w = pos / 10000 ** (w.new_tensor(range(n_model)).div(2, rounding_mode='floor') * 2 / n_model)
sin, cos = w[:, 0::2].sin(), w[:, 1::2].cos()
w[:, :sin.shape[1]], w[:, sin.shape[1]:] = sin, cos
self.embed.weight.copy_(w)

def forward(self, x: torch.Tensor) -> torch.Tensor:
pos = self.embed(x.new_tensor(range(x.shape[0]), dtype=torch.long)).unsqueeze(1)
sin, cos = pos.chunk(2, -1)
sin = torch.stack((sin, sin), -1).view_as(pos)
cos = torch.stack((cos, cos), -1).view_as(pos)
x = x * cos + torch.stack((-x[..., 1::2], x[..., ::2]), -1).view_as(x) * sin
return x
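
A small numerical check (not part of the commit) of the property rotary embeddings are designed for: with the same query/key vector repeated at every position, the dot product between the rotated vectors depends only on the offset between positions, not on their absolute values. Only the import path is assumed; the class itself is defined above.

# Hypothetical check of the relative-position property of RotaryPositionalEmbedding.
import torch
from supar.modules.transformer import RotaryPositionalEmbedding

torch.manual_seed(0)
rope = RotaryPositionalEmbedding(n_model=64, max_len=128)
q = torch.randn(1, 1, 64).expand(16, 1, 64)    # same query vector at all 16 positions
k = torch.randn(1, 1, 64).expand(16, 1, 64)    # same key vector at all 16 positions
qr, kr = rope(q), rope(k)                      # rotated per position
scores = torch.einsum('ibd,jbd->ij', qr, kr)   # scores[m, n] = <R_m q, R_n k>
# equal offsets m - n give (numerically) equal scores
print(torch.allclose(scores[2, 5], scores[7, 10], atol=1e-4))   # True (offset -3)
print(torch.allclose(scores[6, 2], scores[13, 9], atol=1e-4))   # True (offset +4)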
