encoder_with_elmo.py
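"""EncoderWithElmo: a sentence-pair encoder that augments (or replaces) word
embeddings with ELMo activations, runs an RNN merger over the result, and
optionally concatenates a second ELMo blend onto the RNN output.

ELMo layers are combined either by a learned softmax interpolation
('interpolate') or by concatenation plus a linear projection ('concat').
"""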
import sys
sys.path.insert(0, '../')

import torch
from torch import nn

from holder import *
from util import *
from elmo_encoder import *
from elmo_loader import *


class EncoderWithElmo(torch.nn.Module):
    def __init__(self, opt, shared):
        super(EncoderWithElmo, self).__init__()
        # bookkeeping
        self.opt = opt
        self.shared = shared

        self.elmo_drop = nn.Dropout(opt.elmo_dropout)
        self.drop = LockedDropout(opt.dropout)

        # either run ELMo on the fly or load precomputed activations
        if opt.dynamic_elmo == 1:
            self.elmo = ElmoEncoder(opt, shared)
        else:
            self.elmo = ElmoLoader(opt, shared)

        # rnn merger over [word embedding; elmo] (or over elmo alone)
        bidir = opt.birnn == 1
        rnn_in_size = (opt.word_vec_size + opt.elmo_in_size) if opt.use_elmo_only == 0 else opt.elmo_in_size
        rnn_hidden_size = opt.hidden_size if not bidir else opt.hidden_size // 2
        self.rnn = build_rnn(
            opt.rnn_type,
            input_size=rnn_in_size,
            hidden_size=rnn_hidden_size,
            num_layers=opt.rnn_layer,
            bias=True,
            batch_first=True,
            dropout=opt.dropout,
            bidirectional=bidir)

        if opt.elmo_blend == 'concat':
            # project the 3 concatenated ELMo layers back down to elmo_in_size
            self.sampler_pre = nn.Linear(opt.elmo_in_size * 3, opt.elmo_in_size)
            self.sampler_post = nn.Linear(opt.elmo_in_size * 3, opt.elmo_in_size)
        elif opt.elmo_blend == 'interpolate':
            # learned scalar mix: per-layer weights plus a global scale, one
            # set for the pre-rnn blend and one for the post-rnn blend;
            # skip_init presumably tells the project's initializer to leave
            # these parameters at their starting values
            self.gamma_pre = nn.Parameter(torch.ones(1), requires_grad=True)
            self.gamma_pre.skip_init = 1
            self.gamma_post = nn.Parameter(torch.ones(1), requires_grad=True)
            self.gamma_post.skip_init = 1
            self.w_pre = nn.Parameter(torch.ones(3), requires_grad=True)
            self.w_pre.skip_init = 1
            self.w_post = nn.Parameter(torch.ones(3), requires_grad=True)
            self.w_post.skip_init = 1
            self.softmax = nn.Softmax(0)

    def rnn_over(self, x):
        x = self.drop(x)
        x, h = self.rnn(x)
        return x, h

    def interpolate_elmo(self, elmo_layers1, elmo_layers2, w, gamma):
        # scalar mix: softmax-normalized layer weights, then a global scale
        weights = self.softmax(w)
        # interpolate
        if self.opt.elmo_layer == 3:
            sent1 = elmo_layers1[0] * weights[0] + elmo_layers1[1] * weights[1] + elmo_layers1[2] * weights[2]
            sent2 = elmo_layers2[0] * weights[0] + elmo_layers2[1] * weights[1] + elmo_layers2[2] * weights[2]
        elif self.opt.elmo_layer == 2:
            sent1 = elmo_layers1[0] * weights[0] + elmo_layers1[1] * weights[1]
            sent2 = elmo_layers2[0] * weights[0] + elmo_layers2[1] * weights[1]
        elif self.opt.elmo_layer == 1:
            sent1 = elmo_layers1[0] * weights[0]
            sent2 = elmo_layers2[0] * weights[0]
        else:
            raise ValueError('unsupported elmo_layer: {0}'.format(self.opt.elmo_layer))
        return sent1 * gamma, sent2 * gamma

    def concat_elmo(self, elmo_layers1, elmo_layers2):
        # stack all ELMo layers along the feature dimension
        return torch.cat(elmo_layers1, 2), torch.cat(elmo_layers2, 2)

    def sample_elmo(self, sampler, elmo1, elmo2):
        # linear projection from (batch_l, sent_l, elmo_in_size*3)
        # back down to (batch_l, sent_l, elmo_in_size)
        elmo1 = sampler(elmo1.view(-1, self.opt.elmo_in_size * 3)).view(self.shared.batch_l, self.shared.sent_l1, -1)
        elmo2 = sampler(elmo2.view(-1, self.opt.elmo_in_size * 3)).view(self.shared.batch_l, self.shared.sent_l2, -1)
        return elmo1, elmo2

    def forward(self, sent1, sent2):
        # elmo pass: per-layer activations for both sentences
        elmo1, elmo2 = self.elmo()

        # pre-rnn elmo blend
        elmo_pre1, elmo_pre2 = None, None
        if self.opt.elmo_blend == 'interpolate':
            elmo_pre1, elmo_pre2 = self.interpolate_elmo(elmo1, elmo2, self.w_pre, self.gamma_pre)
        elif self.opt.elmo_blend == 'concat':
            elmo_pre1, elmo_pre2 = self.concat_elmo(elmo1, elmo2)
            elmo_pre1, elmo_pre2 = self.sample_elmo(self.sampler_pre, elmo_pre1, elmo_pre2)
        elmo_pre1, elmo_pre2 = self.elmo_drop(elmo_pre1), self.elmo_drop(elmo_pre2)

        enc1, enc2 = elmo_pre1, elmo_pre2
        if self.opt.use_elmo_only == 0:
            enc1 = torch.cat([sent1, enc1], 2)
            enc2 = torch.cat([sent2, enc2], 2)

        # read
        enc1, _ = self.rnn_over(enc1)
        enc2, _ = self.rnn_over(enc2)

        # post-rnn elmo blend, concatenated onto the rnn output
        if self.opt.use_elmo_post == 1:
            elmo_post1, elmo_post2 = None, None
            if self.opt.elmo_blend == 'interpolate':
                elmo_post1, elmo_post2 = self.interpolate_elmo(elmo1, elmo2, self.w_post, self.gamma_post)
            elif self.opt.elmo_blend == 'concat':
                elmo_post1, elmo_post2 = self.concat_elmo(elmo1, elmo2)
                elmo_post1, elmo_post2 = self.sample_elmo(self.sampler_post, elmo_post1, elmo_post2)
            elmo_post1, elmo_post2 = self.elmo_drop(elmo_post1), self.elmo_drop(elmo_post2)
            enc1 = torch.cat([enc1, elmo_post1], 2)
            enc2 = torch.cat([enc2, elmo_post2], 2)

        # record
        # take lstm encoding as embeddings for classification
        # take post-lstm encoding as encodings for attention
        self.shared.input_emb1 = enc1
        self.shared.input_emb2 = enc2
        self.shared.input_enc1 = enc1
        self.shared.input_enc2 = enc2

        return [self.shared.input_emb1, self.shared.input_emb2, self.shared.input_enc1, self.shared.input_enc2]

    def begin_pass(self):
        pass

    def end_pass(self):
        pass


if __name__ == '__main__':
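    # Minimal smoke test of the scalar-mix math from interpolate_elmo, kept
    # standalone so it does not depend on the project's opt/shared objects or
    # on a real ELMo model. The shapes and layer count here are illustrative
    # assumptions, not values taken from the original file.
    num_layers, batch_l, sent_l, elmo_in_size = 3, 2, 7, 16
    layers = [torch.randn(batch_l, sent_l, elmo_in_size) for _ in range(num_layers)]
    w = torch.ones(num_layers)
    gamma = torch.ones(1)
    weights = nn.Softmax(0)(w)  # uniform weights (1/3 each) at init
    mixed = sum(l * wi for l, wi in zip(layers, weights)) * gamma
    print(mixed.shape)  # torch.Size([2, 7, 16])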