Position embeddings inherit nn.Module
yzhangcs committed May 30, 2023
1 parent 6f51b79 commit f0c52d1
Showing 1 changed file with 16 additions and 21 deletions.
37 changes: 16 additions & 21 deletions supar/modules/transformer.py
@@ -609,30 +609,28 @@ def forward(self, x):
         return x


-class PositionalEmbedding(nn.Module):
+class PositionalEmbedding(nn.Embedding):

     def __init__(
         self,
         n_model: int = 1024,
         max_len: int = 1024
     ) -> PositionalEmbedding:
-        super().__init__()
-
-        self.embed = nn.Embedding(max_len, n_model)
+        super().__init__(max_len, n_model)

         self.reset_parameters()

     @torch.no_grad()
     def reset_parameters(self):
-        w = self.embed.weight
+        w = self.weight
         max_len, n_model = w.shape
         w = w.new_tensor(range(max_len)).unsqueeze(-1)
         w = w / 10000 ** (w.new_tensor(range(n_model)).div(2, rounding_mode='floor') * 2 / n_model)
         w[:, 0::2], w[:, 1::2] = w[:, 0::2].sin(), w[:, 1::2].cos()
-        self.embed.weight.copy_(w)
+        self.weight.copy_(w)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.embed(x.new_tensor(range(x.shape[1])).long())
+        return torch.embedding(self.weight, x.new_tensor(range(x.shape[1]), dtype=torch.long))


 class RelativePositionalEmbedding(nn.Module):
@@ -642,24 +640,23 @@ def __init__(
         n_model: int = 1024,
         max_len: int = 1024
     ) -> RelativePositionalEmbedding:
-        super().__init__()
-
-        self.embed = nn.Embedding(max_len, n_model)
+        super().__init__(max_len, n_model)

         self.reset_parameters()

     @torch.no_grad()
     def reset_parameters(self):
-        w = self.embed.weight
+        w = self.weight
         max_len, n_model = w.shape
         pos = torch.cat((w.new_tensor(range(-max_len//2, 0)), w.new_tensor(range(max_len//2))))
         w = pos.unsqueeze(-1) / 10000 ** (w.new_tensor(range(n_model)).div(2, rounding_mode='floor') * 2 / n_model)
         w[:, 0::2], w[:, 1::2] = w[:, 0::2].sin(), w[:, 1::2].cos()
-        self.embed.weight.copy_(w)
+        self.weight.copy_(w)

     def forward(self, q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
-        offset = sum(divmod(self.embed.weight.shape[0], 2))
-        return self.embed((k.new_tensor(range(k.shape[0])) - q.new_tensor(range(q.shape[0])).unsqueeze(-1)).long() + offset)
+        indices = sum(divmod(self.weight.shape[0], 2))
+        indices = (k.new_tensor(range(k.shape[0])) - q.new_tensor(range(q.shape[0])).unsqueeze(-1)).long() + indices
+        return torch.embedding(self.weight, indices)


 class SinusoidPositionalEmbedding(nn.Module):
@@ -683,31 +680,29 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return pos


-class RotaryPositionalEmbedding(nn.Module):
+class RotaryPositionalEmbedding(nn.Embedding):

     def __init__(
         self,
         n_model: int = 1024,
         max_len: int = 1024
     ) -> RotaryPositionalEmbedding:
-        super().__init__()
-
-        self.embed = nn.Embedding(max_len, n_model)
+        super().__init__(max_len, n_model)

         self.reset_parameters()

     @torch.no_grad()
     def reset_parameters(self):
-        w = self.embed.weight
+        w = self.weight
         max_len, n_model = w.shape
         pos = w.new_tensor(range(max_len)).unsqueeze(-1)
         w = pos / 10000 ** (w.new_tensor(range(n_model)).div(2, rounding_mode='floor') * 2 / n_model)
         sin, cos = w[:, 0::2].sin(), w[:, 1::2].cos()
         w[:, :sin.shape[1]], w[:, sin.shape[1]:] = sin, cos
-        self.embed.weight.copy_(w)
+        self.weight.copy_(w)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        pos = self.embed(x.new_tensor(range(x.shape[0]), dtype=torch.long)).unsqueeze(1)
+        pos = torch.embedding(self.weight, x.new_tensor(range(x.shape[0]), dtype=torch.long)).unsqueeze(1)
         sin, cos = pos.chunk(2, -1)
         sin = torch.stack((sin, sin), -1).view_as(pos)
         cos = torch.stack((cos, cos), -1).view_as(pos)
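The commit replaces the wrapped submodule (self.embed = nn.Embedding(...)) with direct inheritance from nn.Embedding, so lookups go through torch.embedding(self.weight, indices) instead of calling the submodule. Below is a minimal sketch, not part of the commit, of why the two lookup paths are interchangeable; the sizes (max_len=8, n_model=16) are arbitrary and a throwaway nn.Embedding stands in for the old self.embed.

# Sketch (not from the commit): the old path called the nn.Embedding submodule,
# the new path indexes the inherited weight directly with torch.embedding.
import torch
import torch.nn as nn

max_len, n_model = 8, 16  # arbitrary small sizes for the check
embed = nn.Embedding(max_len, n_model)  # stands in for the old self.embed
with torch.no_grad():
    # same sinusoid initialization as reset_parameters in the diff above
    pos = embed.weight.new_tensor(range(max_len)).unsqueeze(-1)
    w = pos / 10000 ** (embed.weight.new_tensor(range(n_model)).div(2, rounding_mode='floor') * 2 / n_model)
    w[:, 0::2], w[:, 1::2] = w[:, 0::2].sin(), w[:, 1::2].cos()
    embed.weight.copy_(w)

indices = torch.arange(max_len)
old_path = embed(indices)                          # pre-commit: self.embed(...)
new_path = torch.embedding(embed.weight, indices)  # post-commit: torch.embedding(self.weight, ...)
assert torch.equal(old_path, new_path)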
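In RelativePositionalEmbedding, the rewritten forward first builds the index matrix and then performs a single lookup; the offset sum(divmod(max_len, 2)) shifts relative distances so that distance 0 lands in the middle of the table built over range(-max_len//2, max_len//2). A small illustration of that index arithmetic, not from the commit, using assumed toy sizes (max_len=8, three query and key positions):

import torch

max_len, q_len, k_len = 8, 3, 3   # assumed toy sizes
offset = sum(divmod(max_len, 2))  # = 4, the row holding relative distance 0
indices = (torch.arange(k_len) - torch.arange(q_len).unsqueeze(-1)) + offset
print(indices)
# tensor([[4, 5, 6],
#         [3, 4, 5],
#         [2, 3, 4]])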
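RotaryPositionalEmbedding.reset_parameters stores all sine values in the first half of the feature dimension and all cosines in the second half; forward recovers them with chunk(2, -1) and then duplicates each value over two adjacent channels via torch.stack(..., -1).view_as(pos). A standalone sketch of that duplication step, not from the commit (how the resulting sin/cos factors are applied to x lies past the end of the hunk shown above):

import torch

n_model = 8
pos = torch.arange(1., n_model + 1).view(1, 1, n_model)  # one toy position vector
sin, cos = pos.chunk(2, -1)                     # first half / second half of the features
sin = torch.stack((sin, sin), -1).view_as(pos)  # each value repeated over a channel pair
cos = torch.stack((cos, cos), -1).view_as(pos)
print(sin)  # tensor([[[1., 1., 2., 2., 3., 3., 4., 4.]]])
print(cos)  # tensor([[[5., 5., 6., 6., 7., 7., 8., 8.]]])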
