Skip to content

Commit

Permalink
Have RNN classes pass their dtypes to their cells.
Browse files Browse the repository at this point in the history
In TF 2, this makes RNNs work properly when a non-float32 dtype is passed to them. ConvLSTM2D is still broken with non-float32 dtypes however, as it calls tf.zeros() in various places without passing the correct dtype.

PiperOrigin-RevId: 261368927
  • Loading branch information
reedwm authored and tensorflower-gardener committed Aug 2, 2019
1 parent 9274e93 commit d90e521
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 4 deletions.
3 changes: 2 additions & 1 deletion tensorflow/python/keras/layers/convolutional_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -921,7 +921,8 @@ def __init__(self,
recurrent_constraint=recurrent_constraint,
bias_constraint=bias_constraint,
dropout=dropout,
recurrent_dropout=recurrent_dropout)
recurrent_dropout=recurrent_dropout,
dtype=kwargs.get('dtype'))
super(ConvLSTM2D, self).__init__(cell,
return_sequences=return_sequences,
go_backwards=go_backwards,
Expand Down
13 changes: 13 additions & 0 deletions tensorflow/python/keras/layers/gru_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,19 @@ def test_return_sequences_GRU(self):
'return_sequences': True},
input_shape=(num_samples, timesteps, embedding_dim))

def test_float64_GRU(self):
    # Verify GRU builds and runs end-to-end when constructed with a
    # non-default dtype; layer_test feeds float64 inputs to match.
    batch, steps, input_dim = 2, 3, 4
    cell_units = 2
    testing_utils.layer_test(
        keras.layers.GRU,
        kwargs=dict(
            units=cell_units,
            return_sequences=True,
            dtype='float64'),
        input_shape=(batch, steps, input_dim),
        input_dtype='float64')

def test_dynamic_behavior_GRU(self):
num_samples = 2
timesteps = 3
Expand Down
13 changes: 13 additions & 0 deletions tensorflow/python/keras/layers/gru_v2_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,19 @@ def test_return_sequences_GRU(self):
'return_sequences': True},
input_shape=(num_samples, timesteps, embedding_dim))

def test_float64_GRU(self):
    # Regression check: a GRU (v2) given dtype='float64' must accept
    # float64 inputs without dtype mismatches inside the cell.
    sample_count = 2
    seq_len = 3
    feature_dim = 4
    testing_utils.layer_test(
        rnn.GRU,
        kwargs=dict(units=2, return_sequences=True, dtype='float64'),
        input_shape=(sample_count, seq_len, feature_dim),
        input_dtype='float64')

def test_return_states_GRU(self):
layer_class = rnn.GRU
x = np.random.random((2, 3, 4))
Expand Down
13 changes: 13 additions & 0 deletions tensorflow/python/keras/layers/lstm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,19 @@ def test_return_sequences_LSTM(self):
'return_sequences': True},
input_shape=(num_samples, timesteps, embedding_dim))

def test_float64_LSTM(self):
    # Exercise the LSTM layer with dtype='float64' so the dtype is
    # propagated to its cell and float64 inputs run cleanly.
    shape = (2, 3, 4)  # (num_samples, timesteps, embedding_dim)
    testing_utils.layer_test(
        keras.layers.LSTM,
        kwargs=dict(
            units=2,
            return_sequences=True,
            dtype='float64'),
        input_shape=shape,
        input_dtype='float64')

def test_static_shape_inference_LSTM(self):
# Github issue: 15165
timesteps = 3
Expand Down
15 changes: 15 additions & 0 deletions tensorflow/python/keras/layers/lstm_v2_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,21 @@ def test_return_sequences_LSTM(self):
},
input_shape=(num_samples, timesteps, embedding_dim))

def test_float64_LSTM(self):
    # Smoke-test the v2 LSTM under a non-float32 dtype: construct with
    # dtype='float64' and drive it with float64 input data.
    n_batch = 2
    n_steps = 3
    n_features = 4
    n_units = 2
    layer_kwargs = {
        'units': n_units,
        'return_sequences': True,
        'dtype': 'float64',
    }
    testing_utils.layer_test(
        rnn.LSTM,
        kwargs=layer_kwargs,
        input_shape=(n_batch, n_steps, n_features),
        input_dtype='float64')

def test_regularizers_LSTM(self):
embedding_dim = 4
layer_class = rnn.LSTM
Expand Down
9 changes: 6 additions & 3 deletions tensorflow/python/keras/layers/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1362,7 +1362,8 @@ def __init__(self,
recurrent_constraint=recurrent_constraint,
bias_constraint=bias_constraint,
dropout=dropout,
recurrent_dropout=recurrent_dropout)
recurrent_dropout=recurrent_dropout,
dtype=kwargs.get('dtype'))
super(SimpleRNN, self).__init__(
cell,
return_sequences=return_sequences,
Expand Down Expand Up @@ -1890,7 +1891,8 @@ def __init__(self,
dropout=dropout,
recurrent_dropout=recurrent_dropout,
implementation=implementation,
reset_after=reset_after)
reset_after=reset_after,
dtype=kwargs.get('dtype'))
super(GRU, self).__init__(
cell,
return_sequences=return_sequences,
Expand Down Expand Up @@ -2516,7 +2518,8 @@ def __init__(self,
bias_constraint=bias_constraint,
dropout=dropout,
recurrent_dropout=recurrent_dropout,
implementation=implementation)
implementation=implementation,
dtype=kwargs.get('dtype'))
super(LSTM, self).__init__(
cell,
return_sequences=return_sequences,
Expand Down
13 changes: 13 additions & 0 deletions tensorflow/python/keras/layers/simplernn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,19 @@ def test_return_sequences_SimpleRNN(self):
'return_sequences': True},
input_shape=(num_samples, timesteps, embedding_dim))

def test_float64_SimpleRNN(self):
    # Confirm SimpleRNN forwards dtype='float64' to its cell so a
    # float64 input batch is processed without error.
    dims = dict(num_samples=2, timesteps=3, embedding_dim=4)
    testing_utils.layer_test(
        keras.layers.SimpleRNN,
        kwargs=dict(units=2, return_sequences=True, dtype='float64'),
        input_shape=(
            dims['num_samples'], dims['timesteps'], dims['embedding_dim']),
        input_dtype='float64')

def test_dynamic_behavior_SimpleRNN(self):
num_samples = 2
timesteps = 3
Expand Down

0 comments on commit d90e521

Please sign in to comment.