This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Residual unroll #6397

Merged 5 commits on May 26, 2017
31 changes: 30 additions & 1 deletion python/mxnet/rnn/rnn_cell.py
@@ -913,6 +913,26 @@ def __call__(self, inputs, states):
output = symbol.elemwise_add(output, inputs, name="%s_plus_residual" % output.name)
return output, states

def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None):
self.reset()

self.base_cell._modified = False
outputs, states = self.base_cell.unroll(length, inputs=inputs, begin_state=begin_state,
layout=layout, merge_outputs=merge_outputs)
self.base_cell._modified = True

merge_outputs = isinstance(outputs, symbol.Symbol) if merge_outputs is None else \
merge_outputs
inputs, _ = _normalize_sequence(length, inputs, layout, merge_outputs)
if merge_outputs:
outputs = symbol.elemwise_add(outputs, inputs, name="%s_plus_residual" % outputs.name)
else:
outputs = [symbol.elemwise_add(output_sym, input_sym,
name="%s_plus_residual" % output_sym.name)
for output_sym, input_sym in zip(outputs, inputs)]

return outputs, states


class BidirectionalCell(BaseRNNCell):
"""Bidirectional RNN cell.
@@ -928,9 +948,18 @@ class BidirectionalCell(BaseRNNCell):
"""
def __init__(self, l_cell, r_cell, params=None, output_prefix='bi_'):
super(BidirectionalCell, self).__init__('', params=params)
self._output_prefix = output_prefix
self._override_cell_params = params is not None

if self._override_cell_params:
assert l_cell._own_params and r_cell._own_params, \
"Either specify params for BidirectionalCell " \
"or child cells, not both."
l_cell.params._params.update(self.params._params)
r_cell.params._params.update(self.params._params)
self.params._params.update(l_cell.params._params)
self.params._params.update(r_cell.params._params)
self._cells = [l_cell, r_cell]
self._output_prefix = output_prefix

def unpack_weights(self, args):
return _cells_unpack_weights(self._cells, args)
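Note on the change: ResidualCell.__call__ adds the residual one step at a time; the new unroll override lets the wrapped cell unroll the whole sequence in a single call (the base cell's _modified flag is temporarily cleared so its unroll can be invoked), then applies the residual either as one elemwise_add on the merged (N, T, C) output or once per time step. The BidirectionalCell change shares explicitly passed params with both child cells, asserting that the children own their own parameters. A minimal usage sketch, mirroring the test_residual setup in the unit tests (the element-wise add requires the cell's num_hidden to match the input feature size):

import mxnet as mx

# Minimal sketch: each unrolled step computes gru(x_t) + x_t. num_hidden (50)
# must equal the input feature size for the element-wise add to be shape-valid.
cell = mx.rnn.ResidualCell(mx.rnn.GRUCell(50, prefix='rnn_'))
inputs = [mx.sym.Variable('rnn_t%d_data' % i) for i in range(2)]

# merge_outputs=True returns a single (N, T, C) symbol, so the residual is one
# elemwise_add; merge_outputs=False returns per-step symbols, added one by one.
outputs, states = cell.unroll(2, inputs, merge_outputs=True)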
20 changes: 20 additions & 0 deletions tests/python/gpu/test_operator_gpu.py
@@ -1093,6 +1093,25 @@ def test_unfuse():
check_rnn_consistency(fused, stack)
check_rnn_consistency(stack, fused)

def test_residual_fused():
cell = mx.rnn.ResidualCell(
mx.rnn.FusedRNNCell(50, num_layers=3, mode='lstm',
prefix='rnn_', dropout=0.5))

inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(2)]
outputs, _ = cell.unroll(2, inputs, merge_outputs=None)
assert sorted(cell.params._params.keys()) == \
['rnn_parameters']

args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10, 50), rnn_t1_data=(10, 50))
assert outs == [(10, 2, 50)]
outputs = outputs.eval(ctx=mx.gpu(0),
rnn_t0_data=mx.nd.ones((10, 50), ctx=mx.gpu(0))+5,
rnn_t1_data=mx.nd.ones((10, 50), ctx=mx.gpu(0))+5,
rnn_parameters=mx.nd.zeros((61200,), ctx=mx.gpu(0)))
expected_outputs = np.ones((10, 2, 50))+5
assert np.array_equal(outputs[0].asnumpy(), expected_outputs)

if __name__ == '__main__':
test_countsketch()
test_ifft()
@@ -1103,6 +1122,7 @@ def test_unfuse():
test_gru()
test_rnn()
test_unfuse()
test_residual_fused()
test_convolution_options()
test_convolution_versions()
test_convolution_with_type()
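Why the test expects all 6s: with rnn_parameters zeroed, every LSTM gate pre-activation is 0, so the cell state stays at 0 and each hidden output is sigmoid(0) * tanh(0) = 0; the residual connection then passes the (ones + 5) inputs through unchanged, and the dropout is inactive because eval runs a non-training forward. The 61200 element count is the packed parameter size of the fused 3-layer LSTM; a sanity-check sketch (the packing arithmetic is an assumption about FusedRNNCell, and all three layers have the same size here only because the input size equals num_hidden):

# 4 LSTM gates, each with i2h and h2h weights plus two bias vectors per layer.
num_hidden, input_size, num_layers, num_gates = 50, 50, 3, 4
per_layer = num_gates * (input_size * num_hidden     # i2h weight
                         + num_hidden * num_hidden   # h2h weight
                         + 2 * num_hidden)           # i2h and h2h biases
assert num_layers * per_layer == 61200  # layers 1 and 2 also see 50-dim input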
34 changes: 32 additions & 2 deletions tests/python/unittest/test_rnn.py
@@ -72,8 +72,6 @@ def test_residual():

args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10, 50), rnn_t1_data=(10, 50))
assert outs == [(10, 50), (10, 50)]
print(args)
print(outputs.list_arguments())
outputs = outputs.eval(rnn_t0_data=mx.nd.ones((10, 50)),
rnn_t1_data=mx.nd.ones((10, 50)),
rnn_i2h_weight=mx.nd.zeros((150, 50)),
@@ -85,6 +83,38 @@
assert np.array_equal(outputs[1].asnumpy(), expected_outputs)


def test_residual_bidirectional():
cell = mx.rnn.ResidualCell(
mx.rnn.BidirectionalCell(
mx.rnn.GRUCell(25, prefix='rnn_l_'),
mx.rnn.GRUCell(25, prefix='rnn_r_')))

inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(2)]
outputs, _ = cell.unroll(2, inputs, merge_outputs=False)
outputs = mx.sym.Group(outputs)
assert sorted(cell.params._params.keys()) == \
['rnn_l_h2h_bias', 'rnn_l_h2h_weight', 'rnn_l_i2h_bias', 'rnn_l_i2h_weight',
'rnn_r_h2h_bias', 'rnn_r_h2h_weight', 'rnn_r_i2h_bias', 'rnn_r_i2h_weight']
assert outputs.list_outputs() == \
['bi_t0_plus_residual_output', 'bi_t1_plus_residual_output']

args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10, 50), rnn_t1_data=(10, 50))
assert outs == [(10, 50), (10, 50)]
outputs = outputs.eval(rnn_t0_data=mx.nd.ones((10, 50))+5,
rnn_t1_data=mx.nd.ones((10, 50))+5,
rnn_l_i2h_weight=mx.nd.zeros((75, 50)),
rnn_l_i2h_bias=mx.nd.zeros((75,)),
rnn_l_h2h_weight=mx.nd.zeros((75, 25)),
rnn_l_h2h_bias=mx.nd.zeros((75,)),
rnn_r_i2h_weight=mx.nd.zeros((75, 50)),
rnn_r_i2h_bias=mx.nd.zeros((75,)),
rnn_r_h2h_weight=mx.nd.zeros((75, 25)),
rnn_r_h2h_bias=mx.nd.zeros((75,)))
expected_outputs = np.ones((10, 50))+5
assert np.array_equal(outputs[0].asnumpy(), expected_outputs)
assert np.array_equal(outputs[1].asnumpy(), expected_outputs)


def test_stack():
cell = mx.rnn.SequentialRNNCell()
for i in range(5):
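As in the fused GPU test, zero weights make both GRU directions output zeros (the update and reset gates sit at sigmoid(0) = 0.5, but the candidate state is tanh(0) = 0 and the initial state is zero), so the residual add returns the inputs untouched. A sketch of the shape bookkeeping behind the zero arrays fed to eval above:

# GRU packs 3 gates (reset, update, candidate) into each weight matrix.
num_hidden, input_size, num_gates = 25, 50, 3
i2h_shape = (num_gates * num_hidden, input_size)  # (75, 50)
h2h_shape = (num_gates * num_hidden, num_hidden)  # (75, 25)
# BidirectionalCell concatenates the two 25-unit directions into a 50-dim
# output, matching the 50-dim inputs that the residual add requires.
assert i2h_shape == (75, 50) and h2h_shape == (75, 25)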