Make inner transform activation configurable for LSTMCell (apache#10957)
* Make inner activation gate configurable for LSTMCell

* Adding pr feedback

* Adding a recurrent_activation and activation similar to Keras

* Fixing all pylint issues in the file

* Adding initial pr feedback

* Adding cr feedback

* Adding softsign support
mrkumar83 authored and szha committed Jun 5, 2018
1 parent d552640 commit 776b239
Showing 2 changed files with 53 additions and 23 deletions.
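For orientation, a minimal sketch of how the two new constructor arguments are used (assuming the Gluon API at the time of this commit; the shapes simply mirror the updated unit test below):

```python
import mxnet as mx

# activation squashes the candidate transform and the new cell state;
# recurrent_activation squashes the in/forget/out gates.
cell = mx.gluon.rnn.LSTMCell(100, activation='softsign',
                             recurrent_activation='sigmoid')
cell.initialize()

x = mx.nd.random.uniform(shape=(10, 50))                    # (batch_size, input_size)
states = [mx.nd.zeros((10, 100)), mx.nd.zeros((10, 100))]   # [h, c]

output, (next_h, next_c) = cell(x, states)
print(output.shape)  # (10, 100)
```

Omitting both arguments keeps the previous behaviour: tanh for the transform and sigmoid for the gates.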
54 changes: 40 additions & 14 deletions python/mxnet/gluon/rnn/rnn_cell.py
@@ -224,6 +224,7 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N
The new state of this RNN after this unrolling.
The type of this symbol is same as the output of `begin_state()`.
"""
# pylint: disable=too-many-locals
self.reset()

inputs, axis, F, batch_size = _format_sequence(length, inputs, layout, False)
@@ -251,12 +252,19 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N
#pylint: disable=no-self-use
def _get_activation(self, F, inputs, activation, **kwargs):
"""Get activation function. Convert if is string"""
if isinstance(activation, string_types):
if activation == 'tanh':
return F.tanh(inputs, **kwargs)
elif activation == 'sigmoid':
return F.sigmoid(inputs, **kwargs)
elif activation == 'relu':
return F.relu(inputs, **kwargs)
elif activation == 'softsign':
return F.softsign(inputs, **kwargs)
elif isinstance(activation, string_types):
return F.Activation(inputs, act_type=activation, **kwargs)
elif isinstance(activation, LeakyReLU):
return F.LeakyReLU(inputs, act_type='leaky', slope=activation._alpha, **kwargs)
else:
return activation(inputs, **kwargs)
return activation(inputs, **kwargs)
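To see what the new string dispatch does in practice, a small illustration that calls the helper directly (illustrative only: `_get_activation` is an internal method, and the input values here are made up):

```python
import mxnet as mx

x = mx.nd.array([[-2.0, 0.0, 2.0]])
cell = mx.gluon.rnn.LSTMCell(3)

# One of the four special-cased names -> dedicated operator (F.softsign).
print(cell._get_activation(mx.nd, x, 'softsign'))
# Any other string -> falls back to the generic F.Activation(act_type=...).
print(cell._get_activation(mx.nd, x, 'softrelu'))
# A callable is applied as-is.
print(cell._get_activation(mx.nd, x, mx.nd.tanh))
```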

def forward(self, inputs, states):
"""Unrolls the recurrent cell for one time step.
@@ -441,7 +449,12 @@ class LSTMCell(HybridRecurrentCell):
params : Parameter or None
Container for weight sharing between cells.
Created if `None`.
activation : str
Activation type to use. See nd/symbol Activation
for supported types.
recurrent_activation : str
Activation type to use for the recurrent step. See nd/symbol Activation
for supported types.
Inputs:
- **data**: input tensor with shape `(batch_size, input_size)`.
@@ -453,10 +466,12 @@ class LSTMCell(HybridRecurrentCell):
- **next_states**: a list of two output recurrent state tensors. Each has
the same shape as `states`.
"""
# pylint: disable=too-many-instance-attributes
def __init__(self, hidden_size,
i2h_weight_initializer=None, h2h_weight_initializer=None,
i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
input_size=0, prefix=None, params=None):
input_size=0, prefix=None, params=None, activation='tanh',
recurrent_activation='sigmoid'):
super(LSTMCell, self).__init__(prefix=prefix, params=params)

self._hidden_size = hidden_size
@@ -473,6 +488,9 @@ def __init__(self, hidden_size,
self.h2h_bias = self.params.get('h2h_bias', shape=(4*hidden_size,),
init=h2h_bias_initializer,
allow_deferred_init=True)
self._activation = activation
self._recurrent_activation = recurrent_activation


def state_info(self, batch_size=0):
return [{'shape': (batch_size, self._hidden_size), '__layout__': 'NC'},
@@ -491,20 +509,25 @@ def __repr__(self):

def hybrid_forward(self, F, inputs, states, i2h_weight,
h2h_weight, i2h_bias, h2h_bias):
# pylint: disable=too-many-locals
prefix = 't%d_'%self._counter
i2h = F.FullyConnected(data=inputs, weight=i2h_weight, bias=i2h_bias,
num_hidden=self._hidden_size*4, name=prefix+'i2h')
h2h = F.FullyConnected(data=states[0], weight=h2h_weight, bias=h2h_bias,
num_hidden=self._hidden_size*4, name=prefix+'h2h')
gates = i2h + h2h
slice_gates = F.SliceChannel(gates, num_outputs=4, name=prefix+'slice')
in_gate = F.Activation(slice_gates[0], act_type="sigmoid", name=prefix+'i')
forget_gate = F.Activation(slice_gates[1], act_type="sigmoid", name=prefix+'f')
in_transform = F.Activation(slice_gates[2], act_type="tanh", name=prefix+'c')
out_gate = F.Activation(slice_gates[3], act_type="sigmoid", name=prefix+'o')
in_gate = self._get_activation(
F, slice_gates[0], self._recurrent_activation, name=prefix+'i')
forget_gate = self._get_activation(
F, slice_gates[1], self._recurrent_activation, name=prefix+'f')
in_transform = self._get_activation(
F, slice_gates[2], self._activation, name=prefix+'c')
out_gate = self._get_activation(
F, slice_gates[3], self._recurrent_activation, name=prefix+'o')
next_c = F._internal._plus(forget_gate * states[1], in_gate * in_transform,
name=prefix+'state')
next_h = F._internal._mul(out_gate, F.Activation(next_c, act_type="tanh"),
next_h = F._internal._mul(out_gate, F.Activation(next_c, act_type=self._activation),
name=prefix+'out')

return next_h, [next_h, next_c]
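For clarity, the same step in plain NumPy (a sketch, not MXNet code), showing which activation goes where, with the gate order used by the slice above — in, forget, candidate, out:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x, h, c, w_i2h, b_i2h, w_h2h, b_h2h,
              activation=np.tanh, recurrent_activation=sigmoid):
    # Fused input-to-hidden and hidden-to-hidden projections, then split
    # into the four gates in the order (in, forget, candidate, out).
    gates = x @ w_i2h.T + b_i2h + h @ w_h2h.T + b_h2h
    i, f, g, o = np.split(gates, 4, axis=-1)

    in_gate = recurrent_activation(i)        # recurrent_activation on the gates
    forget_gate = recurrent_activation(f)
    in_transform = activation(g)             # activation on the candidate
    out_gate = recurrent_activation(o)

    next_c = forget_gate * c + in_gate * in_transform
    next_h = out_gate * activation(next_c)   # activation also squashes next_c
    return next_h, next_c
```

Note that the last line of `hybrid_forward` now applies `self._activation` to the new cell state as well, where the previous code hard-coded tanh.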
@@ -675,6 +698,7 @@ def __call__(self, inputs, states):

def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
valid_length=None):
# pylint: disable=too-many-locals
self.reset()

inputs, _, F, batch_size = _format_sequence(length, inputs, layout, None)
@@ -702,6 +726,7 @@ def __len__(self):
return len(self._children)

def hybrid_forward(self, *args, **kwargs):
# pylint: disable=missing-docstring
raise NotImplementedError


@@ -755,10 +780,9 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N
inputs, _, F, _ = _format_sequence(length, inputs, layout, merge_outputs)
if isinstance(inputs, tensor_types):
return self.hybrid_forward(F, inputs, begin_state if begin_state else [])
else:
return super(DropoutCell, self).unroll(
length, inputs, begin_state=begin_state, layout=layout,
merge_outputs=merge_outputs, valid_length=None)
return super(DropoutCell, self).unroll(
length, inputs, begin_state=begin_state, layout=layout,
merge_outputs=merge_outputs, valid_length=None)


class ModifierCell(HybridRecurrentCell):
@@ -856,6 +880,7 @@ class ResidualCell(ModifierCell):
"""

def __init__(self, base_cell):
# pylint: disable=useless-super-delegation
super(ResidualCell, self).__init__(base_cell)

def hybrid_forward(self, F, inputs, states):
@@ -924,6 +949,7 @@ def begin_state(self, **kwargs):

def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None,
valid_length=None):
# pylint: disable=too-many-locals
self.reset()

inputs, axis, F, batch_size = _format_sequence(length, inputs, layout, False)
22 changes: 13 additions & 9 deletions tests/python/unittest/test_rnn.py
@@ -92,15 +92,19 @@ def test_rnn():


def test_lstm():
cell = mx.rnn.LSTMCell(100, prefix='rnn_', forget_bias=1.0)
inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
outputs, _ = cell.unroll(3, inputs)
outputs = mx.sym.Group(outputs)
assert sorted(cell.params._params.keys()) == ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']
assert outputs.list_outputs() == ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']

args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50))
assert outs == [(10, 100), (10, 100), (10, 100)]
for activation_type in ['', 'relu', 'sigmoid', 'softrelu', 'tanh', 'softsign']:
if activation_type == '':
cell = mx.gluon.rnn.LSTMCell(100, prefix='rnn_')
else:
cell = mx.gluon.rnn.LSTMCell(100, prefix='rnn_', activation=activation_type, recurrent_activation=activation_type)
inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
outputs, _ = cell.unroll(3, inputs)
outputs = mx.sym.Group(outputs)
assert sorted(cell.collect_params().keys()) == ['rnn_h2h_bias', 'rnn_h2h_weight', 'rnn_i2h_bias', 'rnn_i2h_weight']
assert outputs.list_outputs() == ['rnn_t0_out_output', 'rnn_t1_out_output', 'rnn_t2_out_output']

args, outs, auxs = outputs.infer_shape(rnn_t0_data=(10,50), rnn_t1_data=(10,50), rnn_t2_data=(10,50))
assert outs == [(10, 100), (10, 100), (10, 100)]


def test_lstm_forget_bias():
