# layers.py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class InputEmbeddings(nn.Module):
    """Token embedding lookup scaled by sqrt(d_model), as in the original Transformer."""

    def __init__(self, vocab_size: int, d_model: int) -> None:
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        # Scale so embedding magnitudes are comparable to the positional encodings.
        return self.embedding(x) * math.sqrt(self.d_model)

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding:
    PE(pos, 2i) = sin(pos / 10000^(2i/d_model)), PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)).
    """

    def __init__(self, d_model, max_seq_length):
        super().__init__()
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        # 10000^(-2i/d_model), computed in log space for numerical stability.
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # Buffer: moves with the module (device/dtype) but is not a trainable parameter.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        # Add the encodings for the first x.size(1) positions.
        return x + self.pe[:, :x.size(1)]

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.d_model = d_model
        self.head_dim = d_model // num_heads
        self.query_linear = nn.Linear(d_model, d_model, bias=False)
        self.key_linear = nn.Linear(d_model, d_model, bias=False)
        self.value_linear = nn.Linear(d_model, d_model, bias=False)
        self.output_linear = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        # (batch, seq, d_model) -> (batch, num_heads, seq, head_dim)
        seq_length = x.size(1)
        x = x.view(batch_size, seq_length, self.num_heads, self.head_dim)
        return x.permute(0, 2, 1, 3)

    def compute_attention(self, query, key, value, mask=None):
        # Scaled dot-product attention: softmax(QK^T / sqrt(head_dim)) V.
        scores = torch.matmul(query, key.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            # Positions where mask == 0 are blocked from being attended to.
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention_weights = F.softmax(scores, dim=-1)
        return torch.matmul(attention_weights, value)

    def combine_heads(self, x, batch_size):
        # (batch, num_heads, seq, head_dim) -> (batch, seq, d_model)
        x = x.permute(0, 2, 1, 3).contiguous()
        return x.view(batch_size, -1, self.d_model)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        query = self.split_heads(self.query_linear(query), batch_size)
        key = self.split_heads(self.key_linear(key), batch_size)
        value = self.split_heads(self.value_linear(value), batch_size)
        attn_output = self.compute_attention(query, key, value, mask)
        output = self.combine_heads(attn_output, batch_size)
        return self.output_linear(output)

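# A minimal sketch of building a causal (look-ahead) mask compatible with the
# masked_fill(mask == 0, ...) convention above. The helper name and the
# broadcastable (1, 1, seq_len, seq_len) shape are assumptions, not part of
# the original file.
def make_causal_mask(seq_length: int) -> torch.Tensor:
    # Lower-triangular matrix: position i may attend to positions <= i.
    mask = torch.tril(torch.ones(seq_length, seq_length, dtype=torch.bool))
    # Add batch and head dimensions so the mask broadcasts over attention
    # scores of shape (batch_size, num_heads, seq_length, seq_length).
    return mask.unsqueeze(0).unsqueeze(0)
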
# Position-wise feed-forward sublayer (used in both encoder and decoder layers)
class FeedForwardSublayer(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        # Expand to d_ff, apply ReLU, project back to d_model.
        return self.fc2(self.relu(self.fc1(x)))

# Encoder layer
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ff_sublayer = FeedForwardSublayer(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, src_mask):
        # Self-attention sublayer with a post-norm residual connection.
        attn_output = self.self_attn(x, x, x, src_mask)
        x = self.norm1(x + self.dropout(attn_output))
        # Feed-forward sublayer with a post-norm residual connection.
        ff_output = self.ff_sublayer(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x

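# A minimal sketch of stacking several encoder layers into a full encoder.
# This class and its num_layers parameter are assumptions for illustration,
# not part of the original file.
class Encoder(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout, num_layers):
        super().__init__()
        self.layers = nn.ModuleList(
            [EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)]
        )

    def forward(self, x, src_mask):
        # Each layer refines the representation produced by the previous one.
        for layer in self.layers:
            x = layer(x, src_mask)
        return x
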
# Decoder layer
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.ff_sublayer = FeedForwardSublayer(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output, tgt_mask):
        # Masked self-attention over the target sequence.
        attn_output = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn_output))
        # Cross-attention over the encoder output (no source padding mask is applied here).
        cross_attn_output = self.cross_attn(x, encoder_output, encoder_output)
        x = self.norm2(x + self.dropout(cross_attn_output))
        ff_output = self.ff_sublayer(x)
        x = self.norm3(x + self.dropout(ff_output))
        return x
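

# A minimal end-to-end smoke test, assuming small illustrative hyperparameters
# and the make_causal_mask helper sketched above; this block is not part of the
# original module and only checks tensor shapes.
if __name__ == "__main__":
    vocab_size, d_model, num_heads, d_ff, max_len = 1000, 64, 4, 256, 128
    batch_size, src_len, tgt_len = 2, 10, 7

    embed = InputEmbeddings(vocab_size, d_model)
    pos_enc = PositionalEncoding(d_model, max_len)
    encoder = EncoderLayer(d_model, num_heads, d_ff, dropout=0.1)
    decoder = DecoderLayer(d_model, num_heads, d_ff, dropout=0.1)

    src = torch.randint(0, vocab_size, (batch_size, src_len))
    tgt = torch.randint(0, vocab_size, (batch_size, tgt_len))

    # Encode the source, then decode the target against the encoder output.
    memory = encoder(pos_enc(embed(src)), src_mask=None)
    tgt_mask = make_causal_mask(tgt_len)
    out = decoder(pos_enc(embed(tgt)), memory, tgt_mask)
    print(out.shape)  # expected: torch.Size([2, 7, 64])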