forked from manuvn/lpRNN-awd-lstm-lm
lpLSTM.py
import math
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import torch.jit as jit
"""
Reuse code from https://github.com/pytorch/pytorch/blob/master/benchmarks/fastrnns/custom_lstms.py
"""
# class lpLSTM(jit.ScriptModule):
class lpLSTM(nn.Module):
    """
    An implementation of Hochreiter & Schmidhuber's 'Long Short-Term Memory'
    (http://www.bioinf.jku.at/publications/older/2604.pdf) with dropout,
    weight dropout and low-pass filtering added.
    retention_ratio: per-unit coefficient for low-pass filtering the hidden state
    """
    def __init__(self, input_size, hidden_size, bias=True, dropout=0.0, wdropout=0.0,
                 activation='tanh', train_ret_ratio=False, set_retention_ratio=None):
        # super(lpLSTMCell, self).__init__(mode='LSTM', input_size=input_size, hidden_size=hidden_size)
        super(lpLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bool(bias)
        self.dropout = dropout
        self.wdropout = wdropout
        self.train_ret_ratio = bool(train_ret_ratio)
        if wdropout:  # weight dropout: note the mask is sampled once, at construction time
            self.raw_w_ih = th.randn(4 * hidden_size, input_size)
            self.weight_ih = Parameter(F.dropout(self.raw_w_ih, p=self.wdropout, training=self.training))
            self.raw_w_hh = th.randn(4 * hidden_size, hidden_size)
            self.weight_hh = Parameter(F.dropout(self.raw_w_hh, p=self.wdropout, training=self.training))
        else:
            self.weight_ih = Parameter(th.randn(4 * hidden_size, input_size))
            self.weight_hh = Parameter(th.randn(4 * hidden_size, hidden_size))
        self.bias_ih = Parameter(th.randn(4 * hidden_size), requires_grad=self.bias)
        self.bias_hh = Parameter(th.randn(4 * hidden_size), requires_grad=self.bias)
        # Recurrent activation
        if activation == 'tanh':
            self.activation = th.tanh
        else:
            self.activation = th.relu
        # Low-pass filtering factor (optionally trainable)
        if set_retention_ratio is not None:
            self.retention_ratio = nn.Parameter(set_retention_ratio * th.ones(self.hidden_size),
                                                requires_grad=self.train_ret_ratio)
        else:
            self.retention_ratio = nn.Parameter(th.FloatTensor(self.hidden_size).uniform_(0.001, 1),
                                                requires_grad=self.train_ret_ratio)
        self.reset_parameters()
    def reset_parameters(self):
        std = 1.0 / math.sqrt(self.hidden_size)
        for name, w in self.named_parameters():
            if name == 'retention_ratio':
                continue  # keep the retention ratio set in __init__ instead of re-initialising it
            w.data.uniform_(-std, std)
    def forward(self, input_, hidden=None):
        # input_ has shape (time_step, batch, input_size)
        if hidden is None:
            zeros = input_.new_zeros(input_.size(1), self.hidden_size)
            hidden = (zeros, zeros)
        outputs = []
        for x in th.unbind(input_, dim=0):
            h = self.forward_step(x, hidden)
            outputs.append(h[0].clone())
            hidden = h[1]
        op = th.stack(outputs)  # (time_step, batch, hidden_size)
        return op, hidden
    def forward_step(self, input, state):
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
        hx, cx = state
        # Drop a leading num_layers/num_directions dim if present, without
        # collapsing the batch dim (a full th.squeeze breaks batch_size 1).
        if hx.dim() == 3:
            hx, cx = hx.squeeze(0), cx.squeeze(0)
        gates = (th.mm(input, self.weight_ih.t()) + self.bias_ih +
                 th.mm(hx, self.weight_hh.t()) + self.bias_hh)
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
        ingate = th.sigmoid(ingate)
        forgetgate = th.sigmoid(forgetgate)
        cellgate = self.activation(cellgate)
        outgate = th.sigmoid(outgate)
        cy = (self.retention_ratio * forgetgate * cx) + (ingate * cellgate)
        # cy = (forgetgate * cx) + (ingate * cellgate)
        hy = outgate * self.activation(cy)
        # Low-pass filtering of the hidden state
        # hy = self.retention_ratio * hx + (1 - self.retention_ratio) * hy
        hy = self.retention_ratio * hx + hy
        # Dropout
        if self.dropout > 0.0:
            hy = F.dropout(hy, p=self.dropout, training=self.training)
        return hy, (hy, cy)
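
# Illustrative sketch of the filtering idea used above: the extra
# `retention_ratio * hx` term in forward_step makes the hidden state behave as
# a leaky integrator, h_t = r * h_{t-1} + o_t * act(c_t), i.e. a first-order
# low-pass filter over the usual LSTM output. `_retention_demo` is a
# hypothetical helper added only to show that behaviour in isolation.
def _retention_demo(r=0.5, steps=8):
    """Tiny standalone demo of the leaky update h_t = r * h_{t-1} + u_t."""
    h, u = th.zeros(1), th.ones(1)  # constant unit drive
    trace = []
    for _ in range(steps):
        h = r * h + u
        trace.append(round(h.item(), 4))
    return trace  # approaches u / (1 - r) = 2.0 for r = 0.5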
if __name__ == '__main__':
    rnn = lpLSTM(input_size=10, hidden_size=20)
    input = th.randn(5, 3, 10)
    h0 = th.randn(1, 3, 20)
    c0 = th.randn(1, 3, 20)
    # output, (hn, cn) = rnn(input, (h0, c0))
    x = rnn(input, (h0, c0))
    print(x)
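
    # A small follow-up sketch: feed the returned hidden state back in to
    # continue the sequence (the usual truncated-BPTT pattern). The names
    # below (more_input, hn, cn, ...) are illustrative only.
    output, (hn, cn) = x
    more_input = th.randn(5, 3, 10)
    output2, (hn2, cn2) = rnn(more_input, (hn, cn))
    print(output.shape, output2.shape)  # both should be (5, 3, 20)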