Allow backprop through cuDNN RNN in eval mode
Handling of dropout descriptors has been improved too.
apaszke committed Mar 1, 2017
1 parent 977630b commit 1487278
Showing 3 changed files with 78 additions and 29 deletions.
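In practical terms, the headline change lets a cuDNN-backed RNN that is in `eval()` mode still be backpropagated through, as long as the call requires grad. A minimal sketch of the newly allowed pattern (editorial illustration, not part of the commit; sizes are arbitrary and a CUDA build with cuDNN is assumed):

```python
import torch
import torch.nn as nn
from torch.autograd import Variable

rnn = nn.RNN(10, 20, 2).cuda()
rnn.eval()  # dropout disabled, but gradients should still flow

# An input that requires grad makes the cuDNN function keep its training
# reserve buffer, so backward now works even though the module is in eval mode.
x = Variable(torch.randn(5, 3, 10).cuda(), requires_grad=True)
output, hy = rnn(x)
output.sum().backward()   # raised "can only be called when training!" before this commit
print(x.grad.size())      # torch.Size([5, 3, 10])
```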
36 changes: 36 additions & 0 deletions test/test_nn.py
@@ -1511,6 +1511,42 @@ def test_RNN_dropout_state(self):
         self.assertNotEqual(hy1, hy2)
         self.assertNotEqual(hy1, hy3)
 
+    @unittest.skipIf(not (TEST_CUDNN and TEST_CUDNN_VERSION >= 5103), "needs cudnn >= 5.1")
+    def test_RNN_change_dropout(self):
+        for train, cuda in product((True, False), repeat=2):
+            rnn = nn.RNN(100, 100, 2, dropout=0, nonlinearity='relu')
+            input = Variable(torch.Tensor(3, 2, 100).uniform_())
+            if cuda:
+                input.data = input.data.cuda()
+                rnn.cuda()
+
+            if train:
+                rnn.train()
+            else:
+                rnn.eval()
+
+            prev_output = None
+            for p in (0, 0.5, 0, 0.7, 0.2, 1, 0.2, 0):
+                rnn.dropout = p
+                output1, hy1 = rnn(input)
+                output2, hy2 = rnn(input)
+
+                if p == 0 or p == 1 or not train:
+                    self.assertEqual(output1, output2)
+                    self.assertEqual(hy1, hy2)
+                else:
+                    self.assertNotEqual(output1, output2)
+                    self.assertNotEqual(hy1, hy2)
+
+                if prev_output is not None:
+                    if not train:
+                        self.assertEqual(output1.data, prev_output)
+                        self.assertEqual(output2.data, prev_output)
+                    else:
+                        self.assertNotEqual(output1.data, prev_output)
+                        self.assertNotEqual(output2.data, prev_output)
+                prev_output = output1.data
+
     def _verify_pixel_shuffle(self, input, output, upscale_factor):
         for c in range(output.size(1)):
             for h in range(output.size(2)):
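The new test above boils down to two behaviours: assigning to `rnn.dropout` on an existing cuDNN RNN takes effect on the next forward pass, and outputs are deterministic whenever dropout is effectively off (p = 0, p = 1, or eval mode). A condensed illustration of the same idea outside the test harness (editorial sketch; sizes mirror the test but are otherwise arbitrary):

```python
import torch
import torch.nn as nn
from torch.autograd import Variable

rnn = nn.RNN(100, 100, 2, dropout=0, nonlinearity='relu').cuda()
rnn.train()
x = Variable(torch.Tensor(3, 2, 100).uniform_().cuda())

rnn.dropout = 0.5     # re-programs the cached cuDNN dropout descriptor on the next call
a, _ = rnn(x)
b, _ = rnn(x)         # differs from `a`: each pass draws fresh dropout masks

rnn.dropout = 0       # dropout off again
c, _ = rnn(x)
d, _ = rnn(x)         # identical to `c`: the pass is deterministic
```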
38 changes: 28 additions & 10 deletions torch/backends/cudnn/__init__.py
@@ -243,24 +243,42 @@ class DropoutDescriptor(object):
     def __init__(self, handle, dropout, seed):
         ptr = ctypes.c_void_p()
         check_error(lib.cudnnCreateDropoutDescriptor(ctypes.byref(ptr)))
-        self._as_parameter_ = ptr
-
-        dropout_states_size = ctypes.c_long()
-        check_error(lib.cudnnDropoutGetStatesSize(
-            handle,
-            ctypes.byref(dropout_states_size)))
-
-        self.state = torch.cuda.ByteTensor(dropout_states_size.value)
+        self._as_parameter_ = ptr
+        self.state = None
+        self.dropout = dropout
+        self.handle = handle
+
+        self._set(dropout, seed)
+
+    def set_dropout(self, dropout, seed):
+        if dropout != self.dropout:
+            self._set(dropout, seed)
+
+    def _set(self, dropout, seed):
+        if self.state is None and dropout > 0:
+            dropout_states_size = ctypes.c_long()
+            check_error(lib.cudnnDropoutGetStatesSize(
+                self.handle,
+                ctypes.byref(dropout_states_size)))
+            self.state = torch.cuda.ByteTensor(dropout_states_size.value)
+            state_ptr = self.state.data_ptr()
+            state_size = self.state.size(0)
+        else:
+            state_ptr = None
+            state_size = 0
 
         check_error(lib.cudnnSetDropoutDescriptor(
             self,
-            handle,
+            self.handle,
             ctypes.c_float(dropout),
-            ctypes.c_void_p(self.state.data_ptr()),
-            ctypes.c_size_t(self.state.size(0)),
+            ctypes.c_void_p(state_ptr),
+            ctypes.c_size_t(state_size),
             ctypes.c_ulonglong(seed),
         ))
 
+        self.dropout = dropout
+
     def __del__(self):
         check_error(lib.cudnnDestroyDropoutDescriptor(self))
 
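The reworked `DropoutDescriptor` above allocates its state buffer lazily, only once dropout actually becomes non-zero, and `set_dropout` re-issues `cudnnSetDropoutDescriptor` only when the probability changes. A rough sketch of how the class is meant to be driven (hypothetical standalone use; it assumes CUDA and the cuDNN bindings are already initialized, and that `cudnn.get_handle()` returns the current device's handle as it does elsewhere in this backend):

```python
import torch
import torch.backends.cudnn as cudnn

_ = torch.cuda.FloatTensor(1)                    # make sure a CUDA context exists
handle = cudnn.get_handle()                      # per-device cuDNN handle
desc = cudnn.DropoutDescriptor(handle, 0, 42)    # p == 0: no state buffer is allocated yet

desc.set_dropout(0.5, 42)   # first non-zero p: sizes and allocates the state ByteTensor
desc.set_dropout(0.5, 42)   # same p: returns immediately, no cuDNN call
desc.set_dropout(0.2, 42)   # p changed: descriptor is updated without reallocating the state
```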
33 changes: 14 additions & 19 deletions torch/backends/cudnn/rnn.py
@@ -34,20 +34,20 @@ def __setstate__(self, state):
         self.inner = None
 
 
-def init_dropout_descriptor(fn, handle):
-    return cudnn.DropoutDescriptor(
-        handle,
-        fn.dropout,
-        fn.dropout_seed
-    )
-
-
 def init_rnn_descriptor(fn, handle):
+    dropout_desc_name = 'desc_' + str(torch.cuda.current_device())
+    dropout_p = fn.dropout if fn.train else 0
+    if (dropout_desc_name not in fn.dropout_state) or (fn.dropout_state[dropout_desc_name].get() is None):
+        fn.dropout_state[dropout_desc_name] = Unserializable(
+            cudnn.DropoutDescriptor(handle, dropout_p, fn.dropout_seed)
+        )
+    dropout_desc = fn.dropout_state[dropout_desc_name].get()
+    dropout_desc.set_dropout(dropout_p, fn.dropout_seed)
     return cudnn.RNNDescriptor(
         handle,
         fn.hidden_size,
         fn.num_layers,
-        fn.dropout_state['desc_' + str(torch.cuda.current_device())].get(),
+        dropout_desc,
         fn.input_mode,
         fn.bidirectional,
         fn.mode,
@@ -229,11 +229,6 @@ def forward(fn, input, hx, weight, output, hy):
         y = output
 
         # init descriptors
-        desc_name = 'desc_' + str(torch.cuda.current_device())
-        if (desc_name not in fn.dropout_state) or (fn.dropout_state[desc_name].get() is None):
-            fn.dropout_state[desc_name] = Unserializable(
-                init_dropout_descriptor(fn, handle)
-            )
         fn.rnn_desc = init_rnn_descriptor(fn, handle)
         if is_input_packed:
             fn.x_descs = cudnn.descriptor_sequence(x, fn.batch_sizes)
@@ -275,7 +270,7 @@ def forward(fn, input, hx, weight, output, hy):
             ctypes.byref(workspace_size)
         ))
         fn.workspace = torch.cuda.ByteTensor(workspace_size.value)
-        if fn.train:
+        if fn.requires_grad:
             reserve_size = ctypes.c_long()
             check_error(lib.cudnnGetRNNTrainingReserveSize(
                 handle,
@@ -354,8 +349,8 @@ def backward_grad(fn, input, hx, weight, output, grad_output, grad_hy, grad_inpu
 
         if fn.dropout != 0 and cudnn.version() < 5103:
             raise RuntimeError('dropout supported only in cudnn v 5.1 and above')
-        if not fn.train:
-            raise RuntimeError('backward_grad can only be called when training!')
+        if not fn.requires_grad:
+            raise RuntimeError('backward_grad can only be called when the function requires grad!')
         if tuple(input.size()) != input_size:
             raise RuntimeError('Expected input size {}, got {}'.format(
                 input_size, tuple(input.size())))
@@ -427,8 +422,8 @@ def backward_weight(fn, input, hx, output, weight, grad_weight):
             output = output.transpose(0, 1)
         input_size = _input_size(fn, input)
         hidden_size = _hidden_size(fn)
-        if not fn.train:
-            raise RuntimeError('backward_weight can only be called when training!')
+        if not fn.requires_grad:
+            raise RuntimeError('backward_weight can only be called when the function requires grad!')
         if fn.dropout != 0 and cudnn.version() < 5103:
             raise RuntimeError('dropout supported only in cudnn v 5.1 and above')
         if tuple(input.size()) != input_size:
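Two things change in this file: `init_rnn_descriptor` now owns the per-device dropout descriptor cache and calls `set_dropout` on every forward pass (with the probability forced to 0 in eval mode), and the training reserve buffer plus the backward entry points are gated on `fn.requires_grad` instead of `fn.train`. A sketch of the resulting behaviour (editorial illustration; `volatile` was the era's way of requesting inference without gradients):

```python
import torch
import torch.nn as nn
from torch.autograd import Variable

rnn = nn.RNN(10, 20, 2, dropout=0.5).cuda()
rnn.eval()   # the RNN descriptor is built with dropout_p = 0 in eval mode

# Nothing requires grad: fn.requires_grad is False, so the training reserve
# buffer is skipped, exactly as plain inference behaved before this commit.
x_infer = Variable(torch.randn(5, 3, 10).cuda(), volatile=True)
y_infer, _ = rnn(x_infer)

# Gradients requested: fn.requires_grad is True, the reserve buffer is sized via
# cudnnGetRNNTrainingReserveSize, and backward_grad/backward_weight no longer raise.
x_grad = Variable(torch.randn(5, 3, 10).cuda(), requires_grad=True)
y_grad, _ = rnn(x_grad)
y_grad.sum().backward()
```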
