Make inner transform activation configurable for LSTMCell #10957
Conversation
python/mxnet/gluon/rnn/rnn_cell.py
Outdated
@@ -441,6 +441,9 @@ class LSTMCell(HybridRecurrentCell):
    params : Parameter or None
        Container for weight sharing between cells.
        Created if `None`.
    in_transform_activation_type : str
This name is too verbose. They are usually called activation and recurrent activation.
`activation` is applied to both the input transform and the next c.
Added parameters recurrent_activation and activation.
But F.Activation only supports 4 types of activation functions. Many other activation functions (some with parameters) cannot be passed as a string the way TensorFlow allows, such as elu/selu/prelu/leakyrelu/hard_sigmoid, etc.
Will take a look shortly. Maybe it's worth having a utility function that wraps all activation types in the most efficient way (e.g. F.tanh instead of F.Activation(act_type='tanh')) so that it can be reused everywhere.
python/mxnet/gluon/rnn/rnn_cell.py
Outdated
@@ -473,6 +480,9 @@ def __init__(self, hidden_size,
        self.h2h_bias = self.params.get('h2h_bias', shape=(4*hidden_size,),
                                        init=h2h_bias_initializer,
                                        allow_deferred_init=True)
        self.activation = activation
        self.recurrent_activation = recurrent_activation
Use `_activation`, `_recurrent_activation` instead.
python/mxnet/gluon/rnn/rnn_cell.py
Outdated
        F.Activation(slice_gates[1], act_type=self.recurrent_activation, name=prefix+'f')
        in_transform = F.Activation(
            slice_gates[2], act_type=self.activation, name=prefix+'c')
        out_gate = F.Activation(slice_gates[3], act_type=self.recurrent_activation, name=prefix+'o')
use _get_activation
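As a rough sketch (not the merged code), the gate computations in hybrid_forward could route through _get_activation; the underscore-prefixed attribute names are assumed from the earlier comment, and slice_gates/prefix come from the surrounding method:

# Sketch only: slice_gates and prefix are from LSTMCell.hybrid_forward;
# _activation and _recurrent_activation names are assumed from this review thread.
in_gate = self._get_activation(F, slice_gates[0], self._recurrent_activation, name=prefix+'i')
forget_gate = self._get_activation(F, slice_gates[1], self._recurrent_activation, name=prefix+'f')
in_transform = self._get_activation(F, slice_gates[2], self._activation, name=prefix+'c')
out_gate = self._get_activation(F, slice_gates[3], self._recurrent_activation, name=prefix+'o')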
@@ -255,8 +256,7 @@ def _get_activation(self, F, inputs, activation, **kwargs):
        return F.Activation(inputs, act_type=activation, **kwargs)
For string type, map the string to the most efficient operator. For example, if the string is 'tanh', instead of doing F.Activation(act_type='tanh'), do F.tanh, which doesn't require parsing the string at each call.
@mrkumar83 @szha Any updates? @szha if the original author doesn't respond, could you take this over?
@piiswrong |
        elif activation == 'sigmoid':
            return F.sigmoid(inputs, **kwargs)
        elif activation == 'relu':
            return F.relu(inputs, **kwargs)
add softsign
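For illustration, a minimal sketch of the string-to-operator dispatch with softsign included (written here as a standalone helper for brevity; the real code is a method on the cell and may differ):

# Sketch only: map common activation strings to dedicated operators so the
# string is not re-parsed by F.Activation on every call.
def _get_activation(F, inputs, activation, **kwargs):
    if activation == 'tanh':
        return F.tanh(inputs, **kwargs)
    elif activation == 'sigmoid':
        return F.sigmoid(inputs, **kwargs)
    elif activation == 'relu':
        return F.relu(inputs, **kwargs)
    elif activation == 'softsign':
        return F.softsign(inputs, **kwargs)
    # Fall back to the generic operator for any other supported string.
    return F.Activation(inputs, act_type=activation, **kwargs)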
* Make inner activation gate configurable for LSTMCell
* Adding pr feedback
* Adding a recurrent_activation and activation similar to Keras
* Fixing all pylint issues in the file
* Adding initial pr feedback
* Adding cr feedback
* Adding softsign support
Description
Some papers recommend using sigmoid for the inner activation gate for an LSTM.
Other frameworks such as TensorFlow allow this:
https://www.tensorflow.org/api_docs/python/tf/contrib/rnn/BasicLSTMCell
where they have an activation parameter.
Wanted to provide something similar in MXNet.
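For example, a hypothetical use of the new parameters (final names assumed from the review discussion above, not the merged API):

import mxnet as mx
from mxnet import gluon

# Hypothetical usage; parameter names taken from the review discussion.
cell = gluon.rnn.LSTMCell(hidden_size=64, activation='tanh', recurrent_activation='sigmoid')
cell.initialize()
inputs = mx.nd.random.uniform(shape=(8, 32))   # (batch_size, input_size)
states = cell.begin_state(batch_size=8)        # [h, c]
output, new_states = cell(inputs, states)
print(output.shape)                            # (8, 64)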
Checklist
Essentials
Changes