# weight_drop.py: weight drop implementation by sdraper-CS
# from https://github.com/salesforce/awd-lstm-lm/issues/86#issuecomment-447910610
import torch
from torch.nn import Parameter


class BackHook(torch.nn.Module):
    """Pass-through module that runs a callback during the backward pass.

    forward() returns its inputs unchanged; register_backward_hook (deprecated in
    recent PyTorch in favour of register_full_backward_hook) makes _backward fire
    once gradients flow back through this point, which then calls the callback
    supplied by WeightDrop. The hook only fires if the first input has a grad_fn,
    e.g. because it is the output of an embedding layer, as in awd-lstm-lm.
    """
    def __init__(self, hook):
        super(BackHook, self).__init__()
        self._hook = hook
        self.register_backward_hook(self._backward)

    def forward(self, *inp):
        return inp

    @staticmethod
    def _backward(self, grad_in, grad_out):
        # Called by autograd; delegate to the callback and leave gradients untouched.
        self._hook()
        return None
class WeightDrop(torch.nn.Module):
    """
    Implements drop-connect, as per Merity et al. https://arxiv.org/abs/1708.02182
    """
    def __init__(self, module, weights, dropout=0, variational=False):
        super(WeightDrop, self).__init__()
        self.module = module
        self.weights = weights
        self.dropout = dropout
        # Note: `variational` is stored but unused here; the mask below always
        # drops whole rows of the weight matrix.
        self.variational = variational
        self._setup()
        self.wdrop = BackHook(self._backward)

    def _setup(self):
        # Keep an un-dropped copy of each targeted weight as '<name>_raw';
        # the wrapped module's own weight is overwritten on every forward.
        for name_w in self.weights:
            print('Applying weight drop of {} to {}'.format(self.dropout, name_w))
            w = getattr(self.module, name_w)
            self.register_parameter(name_w + '_raw', Parameter(w.data))

    def _setweights(self):
        for name_w in self.weights:
            raw_w = getattr(self, name_w + '_raw')
            if self.training:
                # Row-wise dropout mask; F.dropout rescales kept rows by 1/(1 - dropout).
                mask = raw_w.new_ones((raw_w.size(0), 1))
                mask = torch.nn.functional.dropout(mask, p=self.dropout, training=True)
                w = mask.expand_as(raw_w) * raw_w
                setattr(self, name_w + "_mask", mask)
            else:
                w = raw_w
            rnn_w = getattr(self.module, name_w)
            rnn_w.data.copy_(w)

    def _backward(self):
        # Transfer gradients from the embedded RNN weights to the raw params,
        # applying the chain rule through the mask.
        for name_w in self.weights:
            raw_w = getattr(self, name_w + '_raw')
            rnn_w = getattr(self.module, name_w)
            raw_w.grad = rnn_w.grad * getattr(self, name_w + "_mask")

    def forward(self, *args):
        self._setweights()
        # Route the inputs through the BackHook so _backward runs after the
        # wrapped module's backward pass has populated its weight gradients.
        return self.module(*self.wdrop(*args))
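
# --- Added usage sketch (not part of the original file) -----------------------
# A minimal, hedged illustration of how WeightDrop is typically applied in an
# awd-lstm-lm style model: the LSTM input comes from an embedding, so it carries
# autograd history and the BackHook backward hook can fire during training.
# The class name, sizes and dropout value below are illustrative assumptions.
class ExampleWeightDroppedRNN(torch.nn.Module):
    def __init__(self, ntoken=100, ninp=32, nhid=32, wdrop=0.5):
        super(ExampleWeightDroppedRNN, self).__init__()
        self.encoder = torch.nn.Embedding(ntoken, ninp)
        # Drop rows of the LSTM's hidden-to-hidden weight matrix on each forward.
        self.rnn = WeightDrop(torch.nn.LSTM(ninp, nhid), ['weight_hh_l0'],
                              dropout=wdrop)

    def forward(self, tokens, hidden=None):
        emb = self.encoder(tokens)  # (seq, batch, ninp); non-leaf, has grad_fn
        return self.rnn(emb, hidden)
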
if __name__ == '__main__':
    # Quick self-test. The original script assumed CUDA; pick the device here so
    # the checks also run on CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Input is (seq, batch, input)
    x = torch.randn(2, 1, 10, device=device)
    h0 = None

    ###
    print('Testing WeightDrop')
    print('=-=-=-=-=-=-=-=-=-=')

    ###
    print('Testing WeightDrop with Linear')

    lin = WeightDrop(torch.nn.Linear(10, 10), ['weight'], dropout=0.9).to(device)

    run1 = [out.sum().item() for out in lin(x).data]
    run2 = [out.sum().item() for out in lin(x).data]

    print('All items should be different')
    print('Run 1:', run1)
    print('Run 2:', run2)

    # NB: with dropout=0.9 the two sampled masks occasionally coincide, so these
    # checks can (rarely) fail even though the implementation is working.
    assert run1[0] != run2[0]
    assert run1[1] != run2[1]

    print('---')

    ###
    print('Testing WeightDrop with LSTM')

    wdrnn = WeightDrop(torch.nn.LSTM(10, 10), ['weight_hh_l0'], dropout=0.9).to(device)

    run1 = [out.sum().item() for out in wdrnn(x, h0)[0].data]
    run2 = [out.sum().item() for out in wdrnn(x, h0)[0].data]

    print('First timesteps should be equal, all others should differ')
    print('Run 1:', run1)
    print('Run 2:', run2)

    # First time step, not influenced by hidden-to-hidden weights, should be equal
    assert run1[0] == run2[0]
    # Second step should not
    assert run1[1] != run2[1]

    print('---')
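
    ###
    # --- Added check (not part of the original file) ---------------------------
    # A minimal, hedged sanity check that gradients actually reach the '_raw'
    # parameter. It assumes the deprecated register_backward_hook fires as this
    # file expects, which requires the wrapped module's input to have a grad_fn
    # (in awd-lstm-lm that input is an embedding output; here we fake it).
    print('Testing gradient transfer to weight_hh_l0_raw')

    x_leaf = torch.randn(2, 1, 10, device=device, requires_grad=True)
    xg = x_leaf * 1.0  # non-leaf input with autograd history, like an embedding output

    wdrnn.zero_grad()
    out, _ = wdrnn(xg, None)
    out.sum().backward()

    # WeightDrop._backward() should have copied rnn_w.grad * mask onto the raw weight.
    assert wdrnn.weight_hh_l0_raw.grad is not None
    print('weight_hh_l0_raw.grad abs sum:', wdrnn.weight_hh_l0_raw.grad.abs().sum().item())
    print('---')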