From 776b239b191fb27fa763998ef0b71407ed299c47 Mon Sep 17 00:00:00 2001 From: mrkumar83 Date: Tue, 5 Jun 2018 15:52:15 -0700 Subject: [PATCH] Make inner transform activation configurable for LSTMCell (#10957) * Make inner activation gate configurable for LSTMCell * Adding pr feedback * Adding a recurrent_activation and activation similar to Keras * Fixing all pylint issues in the file * Adding initial pr feedback * Adding cr feedback * Adding softsign support --- python/mxnet/gluon/rnn/rnn_cell.py | 54 ++++++++++++++++++++++-------- tests/python/unittest/test_rnn.py | 22 +++++++----- 2 files changed, 53 insertions(+), 23 deletions(-) diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py index 281aba452579..f318b10812a6 100644 --- a/python/mxnet/gluon/rnn/rnn_cell.py +++ b/python/mxnet/gluon/rnn/rnn_cell.py @@ -224,6 +224,7 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N The new state of this RNN after this unrolling. The type of this symbol is same as the output of `begin_state()`. """ + # pylint: disable=too-many-locals self.reset() inputs, axis, F, batch_size = _format_sequence(length, inputs, layout, False) @@ -251,12 +252,19 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N #pylint: disable=no-self-use def _get_activation(self, F, inputs, activation, **kwargs): """Get activation function. Convert if is string""" - if isinstance(activation, string_types): + if activation == 'tanh': + return F.tanh(inputs, **kwargs) + elif activation == 'sigmoid': + return F.sigmoid(inputs, **kwargs) + elif activation == 'relu': + return F.relu(inputs, **kwargs) + elif activation == 'softsign': + return F.softsign(inputs, **kwargs) + elif isinstance(activation, string_types): return F.Activation(inputs, act_type=activation, **kwargs) elif isinstance(activation, LeakyReLU): return F.LeakyReLU(inputs, act_type='leaky', slope=activation._alpha, **kwargs) - else: - return activation(inputs, **kwargs) + return activation(inputs, **kwargs) def forward(self, inputs, states): """Unrolls the recurrent cell for one time step. @@ -441,7 +449,12 @@ class LSTMCell(HybridRecurrentCell): params : Parameter or None Container for weight sharing between cells. Created if `None`. - + activation : str + Activation type to use. See nd/symbol Activation + for supported types. + recurrent_activation : str + Activation type to use for the recurrent step. See nd/symbol Activation + for supported types. Inputs: - **data**: input tensor with shape `(batch_size, input_size)`. @@ -453,10 +466,12 @@ class LSTMCell(HybridRecurrentCell): - **next_states**: a list of two output recurrent state tensors. Each has the same shape as `states`. 
""" + # pylint: disable=too-many-instance-attributes def __init__(self, hidden_size, i2h_weight_initializer=None, h2h_weight_initializer=None, i2h_bias_initializer='zeros', h2h_bias_initializer='zeros', - input_size=0, prefix=None, params=None): + input_size=0, prefix=None, params=None, activation='tanh', + recurrent_activation='sigmoid'): super(LSTMCell, self).__init__(prefix=prefix, params=params) self._hidden_size = hidden_size @@ -473,6 +488,9 @@ def __init__(self, hidden_size, self.h2h_bias = self.params.get('h2h_bias', shape=(4*hidden_size,), init=h2h_bias_initializer, allow_deferred_init=True) + self._activation = activation + self._recurrent_activation = recurrent_activation + def state_info(self, batch_size=0): return [{'shape': (batch_size, self._hidden_size), '__layout__': 'NC'}, @@ -491,6 +509,7 @@ def __repr__(self): def hybrid_forward(self, F, inputs, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias): + # pylint: disable=too-many-locals prefix = 't%d_'%self._counter i2h = F.FullyConnected(data=inputs, weight=i2h_weight, bias=i2h_bias, num_hidden=self._hidden_size*4, name=prefix+'i2h') @@ -498,13 +517,17 @@ def hybrid_forward(self, F, inputs, states, i2h_weight, num_hidden=self._hidden_size*4, name=prefix+'h2h') gates = i2h + h2h slice_gates = F.SliceChannel(gates, num_outputs=4, name=prefix+'slice') - in_gate = F.Activation(slice_gates[0], act_type="sigmoid", name=prefix+'i') - forget_gate = F.Activation(slice_gates[1], act_type="sigmoid", name=prefix+'f') - in_transform = F.Activation(slice_gates[2], act_type="tanh", name=prefix+'c') - out_gate = F.Activation(slice_gates[3], act_type="sigmoid", name=prefix+'o') + in_gate = self._get_activation( + F, slice_gates[0], self._recurrent_activation, name=prefix+'i') + forget_gate = self._get_activation( + F, slice_gates[1], self._recurrent_activation, name=prefix+'f') + in_transform = self._get_activation( + F, slice_gates[2], self._activation, name=prefix+'c') + out_gate = self._get_activation( + F, slice_gates[3], self._recurrent_activation, name=prefix+'o') next_c = F._internal._plus(forget_gate * states[1], in_gate * in_transform, name=prefix+'state') - next_h = F._internal._mul(out_gate, F.Activation(next_c, act_type="tanh"), + next_h = F._internal._mul(out_gate, F.Activation(next_c, act_type=self._activation), name=prefix+'out') return next_h, [next_h, next_c] @@ -675,6 +698,7 @@ def __call__(self, inputs, states): def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None, valid_length=None): + # pylint: disable=too-many-locals self.reset() inputs, _, F, batch_size = _format_sequence(length, inputs, layout, None) @@ -702,6 +726,7 @@ def __len__(self): return len(self._children) def hybrid_forward(self, *args, **kwargs): + # pylint: disable=missing-docstring raise NotImplementedError @@ -755,10 +780,9 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N inputs, _, F, _ = _format_sequence(length, inputs, layout, merge_outputs) if isinstance(inputs, tensor_types): return self.hybrid_forward(F, inputs, begin_state if begin_state else []) - else: - return super(DropoutCell, self).unroll( - length, inputs, begin_state=begin_state, layout=layout, - merge_outputs=merge_outputs, valid_length=None) + return super(DropoutCell, self).unroll( + length, inputs, begin_state=begin_state, layout=layout, + merge_outputs=merge_outputs, valid_length=None) class ModifierCell(HybridRecurrentCell): @@ -856,6 +880,7 @@ class ResidualCell(ModifierCell): """ def __init__(self, base_cell): + # 
pylint: disable=useless-super-delegation super(ResidualCell, self).__init__(base_cell) def hybrid_forward(self, F, inputs, states): @@ -924,6 +949,7 @@ def begin_state(self, **kwargs): def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None, valid_length=None): + # pylint: disable=too-many-locals self.reset() inputs, axis, F, batch_size = _format_sequence(length, inputs, layout, False) diff --git a/tests/python/unittest/test_rnn.py b/tests/python/unittest/test_rnn.py index 9fe22ae72df6..52a3dcf99342 100644 --- a/tests/python/unittest/test_rnn.py +++ b/tests/python/unittest/test_rnn.py @@ -92,15 +92,19 @@ def test_rnn(): def test_lstm(): - cell = mx.rnn.LSTMCell(100, prefix='rnn_', forget_bias=1.0) - inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)] - outputs, _ = cell.unroll(3, inputs) - outputs = mx.sym.Group(outputs) - assert sorted(cell.params._params.keys()) == ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight'] - assert outputs.list_outputs() == ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output'] - - args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50)) - assert outs == [(10, 100), (10, 100), (10, 100)] + for activation_type in ['', 'relu', 'sigmoid', 'softrelu', 'tanh', 'softsign']: + if activation_type == '': + cell = mx.gluon.rnn.LSTMCell(100, prefix='rnn_') + else: + cell = mx.gluon.rnn.LSTMCell(100, prefix='rnn_', activation=activation_type, recurrent_activation=activation_type) + inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)] + outputs, _ = cell.unroll(3, inputs) + outputs = mx.sym.Group(outputs) + assert sorted(cell.collect_params().keys()) == ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight'] + assert outputs.list_outputs() == ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output'] + + args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50)) + assert outs == [(10, 100), (10, 100), (10, 100)] def test_lstm_forget_bias():
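
A minimal usage sketch (not part of the patch itself): it exercises the new `activation` and `recurrent_activation` keyword arguments that this change adds to mxnet.gluon.rnn.LSTMCell, assuming an MXNet build that includes the patch. The hidden size, batch size, and input size below are arbitrary illustration values.

    import mxnet as mx
    from mxnet import nd

    # 'softsign' is one of the strings special-cased by _get_activation in this patch;
    # 'tanh', 'sigmoid', and 'relu' are dispatched the same way, while any other
    # string (e.g. 'softrelu') falls through to F.Activation.
    cell = mx.gluon.rnn.LSTMCell(100, activation='softsign',
                                 recurrent_activation='sigmoid')
    cell.initialize()

    x = nd.random.uniform(shape=(32, 50))      # (batch_size, input_size)
    states = cell.begin_state(batch_size=32)   # [h, c], each (batch_size, hidden_size)

    output, new_states = cell(x, states)       # one time step
    print(output.shape)                        # (32, 100)

The symbolic path is exercised by the updated test_lstm above, which unrolls the cell for three time steps over a set of activation strings.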