-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathmodules.py
More file actions
85 lines (68 loc) · 3.02 KB
/
modules.py
File metadata and controls
85 lines (68 loc) · 3.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
#embedding
class Embed(nn.Module):
    """Trainable lookup table mapping integer ids to dense embedding vectors.

    Args:
        input_dim: vocabulary size (number of distinct ids, including padding).
        embedding_dim: dimensionality of each embedding vector.
        item: when True, index 0 is treated as the padding id.
    """
    def __init__(self, input_dim, embedding_dim, item = False):
        super(Embed, self).__init__()
        # For item embeddings, id 0 is padding. Using padding_idx (rather than
        # only zeroing the row at init) keeps the padding row at zero for the
        # entire training run — its gradient is masked — so padded positions
        # never contribute to average-pooled session representations.
        self.embedding_table = nn.Embedding(input_dim, embedding_dim,
                                            padding_idx=0 if item else None)
    def forward(self, input):
        # input: long tensor of ids, any shape; output adds a trailing
        # embedding_dim axis
        output = self.embedding_table(input)
        return output
#inter session RNN module
class Inter_RNN(nn.Module):
    """Single-layer GRU over a batch of (session-representation) sequences.

    Returns, per sequence, the hidden state at the index given by
    ``rep_indicies``, shaped for use as an initial hidden state downstream.

    Args:
        input_dim: size of each input vector.
        hidden_dim: GRU hidden size.
        dropout_rate: dropout applied to both the input and the output.
    """
    def __init__(self, input_dim, hidden_dim, dropout_rate):
        super(Inter_RNN, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.dropout = nn.Dropout(dropout_rate)
        self.gru = nn.GRU(self.input_dim, self.hidden_dim, batch_first=True)
    def forward(self, input, hidden, rep_indicies):
        # input: (batch, seq, input_dim); hidden: (1, batch, hidden_dim)
        # rep_indicies: (batch,) long tensor — presumably the index of each
        # sequence's last real (non-padding) step; confirm against caller.
        input = self.dropout(input)
        gru_output, _ = self.gru(input, hidden)
        # gather each sequence's hidden state at its rep index:
        # (batch, seq, hidden) -> (batch, 1, hidden)
        hidden_indices = rep_indicies.view(-1,1,1).expand(gru_output.size(0), 1, gru_output.size(2))
        hidden_out = torch.gather(gru_output,1,hidden_indices)
        # (batch, 1, hidden) -> (1, batch, hidden). squeeze(1), not squeeze():
        # a bare squeeze() would also drop the batch axis when batch_size == 1,
        # producing the wrong shape.
        hidden_out = hidden_out.squeeze(1).unsqueeze(0)
        hidden_out = self.dropout(hidden_out)
        return hidden_out
    def init_hidden(self, batch_size):
        # Allocate zeros on the same device/dtype as the module's parameters
        # instead of hard-coding torch.cuda.FloatTensor, so the module also
        # runs on CPU-only machines. (Variable is deprecated and unnecessary.)
        weight = next(self.parameters())
        return weight.new_zeros(1, batch_size, self.hidden_dim)
#intra session RNN module
class Intra_RNN(nn.Module):
    """Single-layer GRU over the events inside a session.

    Produces per-step output scores via a linear projection, plus each
    sequence's hidden state at the index given by ``lengths``.

    Args:
        input_dim: size of each input vector.
        hidden_dim: GRU hidden size.
        output_dim: size of the per-step output (e.g. item-score) layer.
        dropout_rate: dropout applied to the input and the GRU output.
    """
    def __init__(self, input_dim, hidden_dim, output_dim, dropout_rate):
        super(Intra_RNN, self).__init__()
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.dropout = nn.Dropout(dropout_rate)
        self.gru = nn.GRU(input_dim, hidden_dim, batch_first=True)
        self.linear = nn.Linear(hidden_dim, output_dim)
    def forward(self, input, hidden, lengths):
        # input: (batch, seq, input_dim); hidden: (1, batch, hidden_dim)
        # lengths: (batch,) long tensor — gather index of each sequence's
        # final hidden state (presumably length-1 or a rep index; confirm
        # against caller).
        input = self.dropout(input)
        gru_output, _ = self.gru(input, hidden)
        # per-step scores: (batch, seq, output_dim)
        output = self.dropout(gru_output)
        output = self.linear(output)
        # gather each sequence's hidden state at its index:
        # (batch, seq, hidden) -> (batch, 1, hidden)
        hidden_indices = lengths.view(-1,1,1).expand(gru_output.size(0), 1, gru_output.size(2))
        hidden_out = torch.gather(gru_output,1,hidden_indices)
        # (batch, 1, hidden) -> (1, batch, hidden). squeeze(1), not squeeze():
        # a bare squeeze() would also drop the batch axis when batch_size == 1.
        hidden_out = hidden_out.squeeze(1).unsqueeze(0)
        return output, hidden_out
#time loss module
class Time_Loss(nn.Module):
    """Negative log-likelihood style loss for time-gap prediction.

    Holds a single learnable scalar ``w`` (initialised to -0.1) that scales
    the target term inside the log-density.
    """
    def __init__(self):
        super(Time_Loss, self).__init__()
        # learnable scale; fixed init at -0.1
        self.w = nn.Parameter(torch.FloatTensor([-0.1]))
        #self.w.data.uniform_(-0.1,0.1)
    def forward(self, time, target, epsilon):
        # log-density: time + w*target^epsilon + exp(time)*(1-exp(w*t^eps))/w
        powered_target = torch.pow(target, epsilon)
        scaled_target = self.w * powered_target
        intensity = torch.exp(time)
        integral_term = intensity * (1 - torch.exp(scaled_target)) / self.w
        log_density = time + scaled_target + integral_term
        # negated so minimising the loss maximises the likelihood
        return -log_density
    def get_w(self):
        # accessor kept for callers that log/inspect the parameter
        return self.w