
Commit

Fix weight drop code. Modify lpLSTM custom to do weight drop properly. LSTM performance matches the paper without finetuning.

mnair committed Jun 18, 2019
1 parent d9539cd commit 5ee4d0a
Showing 3 changed files with 38 additions and 153 deletions.
128 changes: 0 additions & 128 deletions lpLSTM_slow.py

This file was deleted.

4 changes: 3 additions & 1 deletion model.py
@@ -31,7 +31,9 @@ def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, dropouth=
                 self.rnns = [WeightDrop(rnn, ['weight_hh_l0'], dropout=wdrop) for rnn in self.rnns]
         elif rnn_type == 'lpLSTMc':
             from lpLSTM_custom import lpLSTM
-            self.rnns = [lpLSTM(ninp if l == 0 else nhid, nhid if l != nlayers - 1 else ninp, 1, dropout=0, wdropout=wdrop) for l in range(nlayers)]
+            self.rnns = [lpLSTM(ninp if l == 0 else nhid, nhid if l != nlayers - 1 else ninp, 1, dropout=0) for l in range(nlayers)]
+            if wdrop:
+                self.rnns = [WeightDrop(rnn, ['weight_hh'], dropout=wdrop) for rnn in self.rnns]
         elif rnn_type == 'QRNN':
             from torchqrnn import QRNNLayer
             self.rnns = [QRNNLayer(input_size=ninp if l == 0 else nhid, hidden_size=nhid if l != nlayers - 1 else (ninp if tie_weights else nhid), save_prev_x=True, zoneout=0, window=2 if l == 0 else 1, output_gate=True) for l in range(nlayers)]
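For reference, a minimal sketch of the layer-wrapping pattern introduced above. It uses stock torch.nn.LSTM layers as a stand-in for the repository's custom lpLSTM, so the hidden-to-hidden parameter is 'weight_hh_l0' rather than 'weight_hh', and the sizes are placeholder values, not taken from this commit.

    import torch.nn as nn
    from weight_drop import WeightDrop  # the class updated in this commit

    ninp, nhid, nlayers, wdrop = 400, 1150, 3, 0.5  # placeholder hyperparameters

    # Build one single-layer RNN per stacked layer, then conditionally wrap each
    # layer in WeightDrop so DropConnect is applied to its hidden-to-hidden weights.
    rnns = [nn.LSTM(ninp if l == 0 else nhid,
                    nhid if l != nlayers - 1 else ninp, 1)
            for l in range(nlayers)]
    if wdrop:
        rnns = [WeightDrop(rnn, ['weight_hh_l0'], dropout=wdrop) for rnn in rnns]
    rnns = nn.ModuleList(rnns)
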
59 changes: 35 additions & 24 deletions weight_drop.py
@@ -1,52 +1,63 @@
 import torch
 from torch.nn import Parameter
 from functools import wraps
+class BackHook(torch.nn.Module):
+    def __init__(self, hook):
+        super(BackHook, self).__init__()
+        self._hook = hook
+        self.register_backward_hook(self._backward)
+
+    def forward(self, *inp):
+        return inp
+
+    @staticmethod
+    def _backward(self, grad_in, grad_out):
+        self._hook()
+        return None
+

 class WeightDrop(torch.nn.Module):
     """
     Implements drop-connect, as per Merity et al https://arxiv.org/abs/1708.02182
     """
     def __init__(self, module, weights, dropout=0, variational=False):
         super(WeightDrop, self).__init__()
         self.module = module
         self.weights = weights
         self.dropout = dropout
-        self.variational = variational
         self._setup()
-
-    def widget_demagnetizer_y2k_edition(*args, **kwargs):
-        # We need to replace flatten_parameters with a nothing function
-        # It must be a function rather than a lambda as otherwise pickling explodes
-        # We can't write boring code though, so ... WIDGET DEMAGNETIZER Y2K EDITION!
-        # (╯°□°)╯︵ ┻━┻
-        return
+        self.hooker = BackHook(self._backward)

     def _setup(self):
-        # Terrible temporary solution to an issue regarding compacting weights re: CUDNN RNN
-        if issubclass(type(self.module), torch.nn.RNNBase):
-            self.module.flatten_parameters = self.widget_demagnetizer_y2k_edition
-
         for name_w in self.weights:
             print('Applying weight drop of {} to {}'.format(self.dropout, name_w))
             w = getattr(self.module, name_w)
-            del self.module._parameters[name_w]
-            self.module.register_parameter(name_w + '_raw', Parameter(w.data))
+            self.register_parameter(name_w + '_raw', Parameter(w.data))

     def _setweights(self):
         for name_w in self.weights:
-            raw_w = getattr(self.module, name_w + '_raw')
-            w = None
-            if self.variational:
-                mask = torch.autograd.Variable(torch.ones(raw_w.size(0), 1))
-                if raw_w.is_cuda: mask = mask.cuda()
+            raw_w = getattr(self, name_w + '_raw')
+            if self.training:
+                mask = raw_w.new_ones((raw_w.size(0), 1))
                 mask = torch.nn.functional.dropout(mask, p=self.dropout, training=True)
                 w = mask.expand_as(raw_w) * raw_w
+                setattr(self, name_w + "_mask", mask)
             else:
-                w = torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training)
-            if type(w) != torch.nn.Parameter:
-                w = torch.nn.Parameter(w)
-            setattr(self.module, name_w, w)
+                w = raw_w
+            rnn_w = getattr(self.module, name_w)
+            rnn_w.data.copy_(w)
+
+    def _backward(self):
+        # transfer gradients from embeddedRNN to raw params
+        for name_w in self.weights:
+            raw_w = getattr(self, name_w + '_raw')
+            rnn_w = getattr(self.module, name_w)
+            raw_w.grad = rnn_w.grad * getattr(self, name_w + "_mask")

     def forward(self, *args):
         self._setweights()
-        return self.module.forward(*args)
+        return self.module(*self.hooker(*args))

 if __name__ == '__main__':
     import torch
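For reference, a short usage sketch of the rewritten WeightDrop, not part of the commit. It assumes the updated weight_drop.py is importable and feeds the wrapped LSTM an embedding output rather than a raw leaf tensor, since the BackHook mechanism relies on the module input already having a grad history.

    import torch
    import torch.nn as nn
    from weight_drop import WeightDrop  # the module rewritten above

    torch.manual_seed(0)
    emb = nn.Embedding(100, 10)
    lstm = WeightDrop(nn.LSTM(10, 20), ['weight_hh_l0'], dropout=0.5)
    lstm.train()

    tokens = torch.randint(0, 100, (5, 3))  # (seq_len, batch)
    x = emb(tokens)                          # embedding output carries a grad_fn
    out, _ = lstm(x)
    out.sum().backward()

    # In training mode, _setweights() copies the masked weights into the wrapped
    # LSTM in place, and _backward() is meant to route the resulting gradient
    # (scaled by the same mask) onto the '_raw' parameter kept on the wrapper.
    print(lstm.weight_hh_l0_raw.grad is not None)
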
