-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathcustom_layers.py
70 lines (56 loc) · 2.58 KB
/
custom_layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import copy
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.autograd import Variable
class TransformerEncoderLayer(nn.Module):
    """Transformer encoder layer that also returns its attention weights.

    The layer applies multi-head self-attention followed by a two-layer
    feed-forward network. Each sub-block is wrapped in dropout, a residual
    connection and a LayerNorm applied *after* the residual sum.

    Unlike a layer that returns only the transformed sequence, ``forward``
    returns the per-head attention weights as well (the attention module is
    called with ``average_attn_weights=False``).

    Args:
        d_model: Feature size of the input and output.
        nhead: Number of attention heads.
        dim_feedforward: Hidden size of the feed-forward network.
        dropout: Dropout probability used in attention, in the feed-forward
            network and on both residual branches.
        activation: Feed-forward non-linearity. Either the string ``"relu"``
            or ``"gelu"``, or any callable mapping a tensor to a tensor.
        batch_first: Forwarded unchanged to ``nn.MultiheadAttention``.

    Raises:
        ValueError: If ``activation`` is a string other than ``"relu"`` or
            ``"gelu"``.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", batch_first=False):
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        # Store a callable, never the raw name: forward() invokes
        # self.activation(...) directly, so keeping the string "relu" here
        # would raise "'str' object is not callable" on the first call.
        self.activation = self._resolve_activation(activation)

    @staticmethod
    def _resolve_activation(activation):
        """Map an activation name (or callable) to the callable to apply."""
        if callable(activation):
            return activation
        if activation == "relu":
            return F.relu
        if activation == "gelu":
            return F.gelu
        raise ValueError(
            "activation should be 'relu', 'gelu' or a callable, not {!r}".format(activation)
        )

    def __setstate__(self, state):
        if 'activation' not in state:
            state['activation'] = F.relu
        elif isinstance(state['activation'], str):
            # State saved while the attribute still held the activation name.
            state['activation'] = self._resolve_activation(state['activation'])
        super(TransformerEncoderLayer, self).__setstate__(state)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        # type: (Tensor, Optional[Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor]
        """Run the layer on ``src``.

        Args:
            src: Input sequence; used as query, key and value.
            src_mask: Optional mask passed as ``attn_mask`` to the attention.
            src_key_padding_mask: Optional mask passed as ``key_padding_mask``
                to the attention.

        Returns:
            A tuple ``(output, weights)``: the transformed sequence and the
            attention weights returned by ``nn.MultiheadAttention`` with
            ``average_attn_weights=False`` (i.e. not averaged over heads).
        """
        # Self-attention sub-block: residual + dropout, then LayerNorm.
        src2, weights = self.self_attn(src, src, src, attn_mask=src_mask,
                                       key_padding_mask=src_key_padding_mask,
                                       average_attn_weights=False)
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        # Feed-forward sub-block: residual + dropout, then LayerNorm.
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src, weights
def _get_clones(module, N):
    """Return an ``nn.ModuleList`` holding ``N`` deep copies of ``module``.

    Each entry is produced by ``copy.deepcopy``, so the copies do not share
    parameter objects with ``module`` or with each other.
    """
    clones = nn.ModuleList()
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return clones
class TransformerEncoder(nn.Module):
    """Stack of encoder layers that can also report per-layer attention.

    Args:
        encoder_layer: Layer to replicate. Calling it must return a pair
            ``(output, attention_weights)``.
        num_layers: Number of copies of ``encoder_layer`` in the stack.
        norm: Optional module applied to the output of the last layer.
    """

    __constants__ = ['norm']

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src, mask=None, src_key_padding_mask=None, return_att=False):
        """Pass ``src`` through every layer in order.

        Args:
            src: Input to the first layer.
            mask: Optional mask handed to each layer as ``src_mask``.
            src_key_padding_mask: Optional mask handed to each layer under
                the same keyword.
            return_att: When true, also return the attention weights.

        Returns:
            The final output, or ``(output, weights)`` when ``return_att`` is
            true, where ``weights`` is a list with one entry per layer in
            stack order.
        """
        hidden = src
        per_layer_attn = []
        for layer in self.layers:
            hidden, attn = layer(hidden, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
            per_layer_attn.append(attn)

        if self.norm is not None:
            hidden = self.norm(hidden)

        return (hidden, per_layer_attn) if return_att else hidden