-
Notifications
You must be signed in to change notification settings - Fork 525
/
tc_lstm.py
42 lines (36 loc) · 1.74 KB
/
tc_lstm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# -*- coding: utf-8 -*-
# file: tc_lstm.py
# author: songyouwei <youwei0314@gmail.com>, ZhangYikai <zykhelloha@gmail.com>
# Copyright (C) 2018. All Rights Reserved.
from layers.dynamic_rnn import DynamicLSTM
import torch
import torch.nn as nn
class TC_LSTM(nn.Module):
def __init__(self, embedding_matrix, opt):
super(TC_LSTM, self).__init__()
self.embed = nn.Embedding.from_pretrained(torch.tensor(embedding_matrix, dtype=torch.float))
self.lstm_l = DynamicLSTM(opt.embed_dim * 2, opt.hidden_dim, num_layers=1, batch_first=True)
self.lstm_r = DynamicLSTM(opt.embed_dim * 2, opt.hidden_dim, num_layers=1, batch_first=True)
self.dense = nn.Linear(opt.hidden_dim*2, opt.polarities_dim)
def forward(self, inputs):
# Get the target and its length(target_len)
x_l, x_r, target = inputs[0], inputs[1], inputs[2]
x_l_len, x_r_len = torch.sum(x_l != 0, dim=-1), torch.sum(x_r != 0, dim=-1)
target_len = torch.sum(target != 0, dim=-1, dtype=torch.float)[:, None, None]
x_l, x_r, target = self.embed(x_l), self.embed(x_r), self.embed(target)
v_target = torch.div(target.sum(dim=1, keepdim=True),
target_len) # v_{target} in paper: average the target words
# the concatenation of word embedding and target vector v_{target}:
x_l = torch.cat(
(x_l, torch.cat(([v_target] * x_l.shape[1]), 1)),
2
)
x_r = torch.cat(
(x_r, torch.cat(([v_target] * x_r.shape[1]), 1)),
2
)
_, (h_n_l, _) = self.lstm_l(x_l, x_l_len)
_, (h_n_r, _) = self.lstm_r(x_r, x_r_len)
h_n = torch.cat((h_n_l[0], h_n_r[0]), dim=-1)
out = self.dense(h_n)
return out