From b3370c0da4430367761e6762e85de7ab4962ff62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Chollet?= Date: Mon, 9 Oct 2017 14:01:18 -0700 Subject: [PATCH] Add CuDNN GRU and LSTM layers. (#8094) * temp version of CuDNNGRU * Add CuDNN GRU and LSTM layers. * Restrict tests to GPU configs. * Fix test skipping * Fix test flake with CNTK. --- keras/layers/__init__.py | 1 + keras/layers/cudnn_recurrent.py | 521 +++++++++++++++++++++ tests/keras/layers/cudnn_recurrent_test.py | 353 ++++++++++++++ tests/keras/layers/wrappers_test.py | 2 + 4 files changed, 877 insertions(+) create mode 100644 keras/layers/cudnn_recurrent.py create mode 100644 tests/keras/layers/cudnn_recurrent_test.py diff --git a/keras/layers/__init__.py b/keras/layers/__init__.py index 0ce64c1c79d..00494f13249 100644 --- a/keras/layers/__init__.py +++ b/keras/layers/__init__.py @@ -11,6 +11,7 @@ from .pooling import * from .local import * from .recurrent import * +from .cudnn_recurrent import * from .normalization import * from .embeddings import * from .noise import * diff --git a/keras/layers/cudnn_recurrent.py b/keras/layers/cudnn_recurrent.py new file mode 100644 index 00000000000..cdf2ef52a77 --- /dev/null +++ b/keras/layers/cudnn_recurrent.py @@ -0,0 +1,521 @@ +from .. import backend as K +from .. import initializers +from .. import regularizers +from .. import constraints +from .recurrent import RNN +from ..layers import InputSpec + +from collections import namedtuple + + +class _CuDNNRNN(RNN): + """Private base class for CuDNNGRU and CuDNNLSTM. + + # Arguments + return_sequences: Boolean. Whether to return the last output. + in the output sequence, or the full sequence. + return_state: Boolean. Whether to return the last state + in addition to the output. + stateful: Boolean (default False). If True, the last state + for each sample at index i in a batch will be used as initial + state for the sample of index i in the following batch. + """ + + def __init__(self, + return_sequences=False, + return_state=False, + stateful=False, + **kwargs): + if K.backend() != 'tensorflow': + raise RuntimeError('CuDNN RNNs are only available ' + 'with the TensorFlow backend.') + super(RNN, self).__init__(**kwargs) + self.return_sequences = return_sequences + self.return_state = return_state + self.stateful = stateful + self.supports_masking = False + self.input_spec = [InputSpec(ndim=3)] + if hasattr(self.cell.state_size, '__len__'): + self.state_spec = [InputSpec(shape=(None, dim)) + for dim in self.cell.state_size] + else: + self.state_spec = InputSpec(shape=(None, self.cell.state_size)) + self._states = None + + def _canonical_to_params(self, weights, biases): + import tensorflow as tf + weights = [tf.reshape(x, (-1,)) for x in weights] + biases = [tf.reshape(x, (-1,)) for x in biases] + return tf.concat(weights + biases, 0) + + def call(self, inputs, mask=None, training=None, initial_state=None): + if isinstance(mask, list): + mask = mask[0] + if mask is not None: + raise ValueError('Masking is not supported for CuDNN RNNs.') + + # input shape: `(samples, time (padded with zeros), input_dim)` + # note that the .build() method of subclasses MUST define + # self.input_spec and self.state_spec with complete input shapes. 
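+        # Initial states are resolved in the following priority order:
+        # 1. state tensors passed alongside the inputs (`inputs` is a list),
+        # 2. the `initial_state` argument of this call,
+        # 3. the layer's recorded states when `stateful=True`,
+        # 4. otherwise, freshly created all-zero states.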
+ if isinstance(inputs, list): + initial_state = inputs[1:] + inputs = inputs[0] + elif initial_state is not None: + pass + elif self.stateful: + initial_state = self.states + else: + initial_state = self.get_initial_state(inputs) + + if len(initial_state) != len(self.states): + raise ValueError('Layer has ' + str(len(self.states)) + + ' states but was passed ' + + str(len(initial_state)) + + ' initial states.') + + output, states = self._process_batch(inputs, initial_state) + + if self.stateful: + updates = [] + for i in range(len(states)): + updates.append((self.states[i], states[i])) + self.add_update(updates, inputs) + + if self.return_state: + return [output] + states + else: + return output + + def get_config(self): + config = {'return_sequences': self.return_sequences, + 'return_state': self.return_state, + 'stateful': self.stateful} + base_config = super(RNN, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config): + return cls(**config) + + @property + def trainable_weights(self): + if self.trainable and self.built: + return [self.kernel, self.recurrent_kernel, self.bias] + return [] + + @property + def non_trainable_weights(self): + if not self.trainable and self.built: + return [self.kernel, self.recurrent_kernel, self.bias] + return [] + + @property + def losses(self): + return super(RNN, self).losses + + def get_losses_for(self, inputs=None): + return super(RNN, self).get_losses_for(inputs=inputs) + + +class CuDNNGRU(_CuDNNRNN): + """Fast GRU implementation backed by CuDNN. + + Can only be run on GPU. + + # Arguments + units: Positive integer, dimensionality of the output space. + kernel_initializer: Initializer for the `kernel` weights matrix, + used for the linear transformation of the inputs. + (see [initializers](../initializers.md)). + recurrent_initializer: Initializer for the `recurrent_kernel` + weights matrix, + used for the linear transformation of the recurrent state. + (see [initializers](../initializers.md)). + bias_initializer: Initializer for the bias vector + (see [initializers](../initializers.md)). + kernel_regularizer: Regularizer function applied to + the `kernel` weights matrix + (see [regularizer](../regularizers.md)). + recurrent_regularizer: Regularizer function applied to + the `recurrent_kernel` weights matrix + (see [regularizer](../regularizers.md)). + bias_regularizer: Regularizer function applied to the bias vector + (see [regularizer](../regularizers.md)). + activity_regularizer: Regularizer function applied to + the output of the layer (its "activation"). + (see [regularizer](../regularizers.md)). + kernel_constraint: Constraint function applied to + the `kernel` weights matrix + (see [constraints](../constraints.md)). + recurrent_constraint: Constraint function applied to + the `recurrent_kernel` weights matrix + (see [constraints](../constraints.md)). + bias_constraint: Constraint function applied to the bias vector + (see [constraints](../constraints.md)). + return_sequences: Boolean. Whether to return the last output. + in the output sequence, or the full sequence. + return_state: Boolean. Whether to return the last state + in addition to the output. + stateful: Boolean (default False). If True, the last state + for each sample at index i in a batch will be used as initial + state for the sample of index i in the following batch. 
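+
+    # Example
+
+    A minimal usage sketch showing the layer built and trained like any other
+    Keras recurrent layer; the shapes, layer size and random training data
+    below are illustrative assumptions. Requires a GPU and the TensorFlow
+    backend:
+
+    ```python
+        import numpy as np
+        from keras.models import Sequential
+        from keras.layers import CuDNNGRU
+
+        timesteps, input_dim, units = 6, 10, 32
+
+        model = Sequential()
+        model.add(CuDNNGRU(units, input_shape=(timesteps, input_dim)))
+        model.compile(optimizer='sgd', loss='mse')
+
+        x = np.random.random((8, timesteps, input_dim))
+        y = np.random.random((8, units))
+        model.fit(x, y, epochs=1, batch_size=8)
+    ```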
+ """ + + def __init__(self, units, + kernel_initializer='glorot_uniform', + recurrent_initializer='orthogonal', + bias_initializer='zeros', + kernel_regularizer=None, + recurrent_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + recurrent_constraint=None, + bias_constraint=None, + return_sequences=False, + return_state=False, + stateful=False, + **kwargs): + self.units = units + super(CuDNNGRU, self).__init__( + return_sequences=return_sequences, + return_state=return_state, + stateful=stateful, + **kwargs) + + self.kernel_initializer = initializers.get(kernel_initializer) + self.recurrent_initializer = initializers.get(recurrent_initializer) + self.bias_initializer = initializers.get(bias_initializer) + + self.kernel_regularizer = regularizers.get(kernel_regularizer) + self.recurrent_regularizer = regularizers.get(recurrent_regularizer) + self.bias_regularizer = regularizers.get(bias_regularizer) + self.activity_regularizer = regularizers.get(activity_regularizer) + + self.kernel_constraint = constraints.get(kernel_constraint) + self.recurrent_constraint = constraints.get(recurrent_constraint) + self.bias_constraint = constraints.get(bias_constraint) + + @property + def cell(self): + Cell = namedtuple('cell', 'state_size') + cell = Cell(state_size=self.units) + return cell + + def build(self, input_shape): + super(CuDNNGRU, self).build(input_shape) + if isinstance(input_shape, list): + input_shape = input_shape[0] + input_dim = input_shape[-1] + + from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops + self._cudnn_gru = cudnn_rnn_ops.CudnnGRU( + num_layers=1, + num_units=self.units, + input_size=input_dim, + input_mode='linear_input') + + self.kernel = self.add_weight(shape=(input_dim, self.units * 3), + name='kernel', + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint) + self.recurrent_kernel = self.add_weight( + shape=(self.units, self.units * 3), + name='recurrent_kernel', + initializer=self.recurrent_initializer, + regularizer=self.recurrent_regularizer, + constraint=self.recurrent_constraint) + + self.bias = self.add_weight(shape=(self.units * 6,), + name='bias', + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint) + + self.kernel_z = self.kernel[:, :self.units] + self.recurrent_kernel_z = self.recurrent_kernel[:, :self.units] + self.kernel_r = self.kernel[:, self.units: self.units * 2] + self.recurrent_kernel_r = self.recurrent_kernel[:, + self.units: + self.units * 2] + self.kernel_h = self.kernel[:, self.units * 2:] + self.recurrent_kernel_h = self.recurrent_kernel[:, self.units * 2:] + + self.bias_z_i = self.bias[:self.units] + self.bias_r_i = self.bias[self.units: self.units * 2] + self.bias_h_i = self.bias[self.units * 2: self.units * 3] + self.bias_z = self.bias[self.units * 3: self.units * 4] + self.bias_r = self.bias[self.units * 4: self.units * 5] + self.bias_h = self.bias[self.units * 5:] + + self.built = True + + def _process_batch(self, inputs, initial_state): + import tensorflow as tf + inputs = tf.transpose(inputs, (1, 0, 2)) + input_h = initial_state[0] + input_h = tf.expand_dims(input_h, axis=0) + + params = self._canonical_to_params( + weights=[ + self.kernel_r, + self.kernel_z, + self.kernel_h, + self.recurrent_kernel_r, + self.recurrent_kernel_z, + self.recurrent_kernel_h, + ], + biases=[ + self.bias_r_i, + self.bias_z_i, + self.bias_h_i, + self.bias_r, + self.bias_z, + self.bias_h, + 
], + ) + outputs, h = self._cudnn_gru( + inputs, + input_h=input_h, + params=params, + is_training=True) + + if self.stateful or self.return_state: + h = h[0] + if self.return_sequences: + output = tf.transpose(outputs, (1, 0, 2)) + else: + output = outputs[-1] + return output, [h] + + def get_config(self): + config = { + 'units': self.units, + 'kernel_initializer': initializers.serialize(self.kernel_initializer), + 'recurrent_initializer': initializers.serialize(self.recurrent_initializer), + 'bias_initializer': initializers.serialize(self.bias_initializer), + 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), + 'recurrent_regularizer': regularizers.serialize(self.recurrent_regularizer), + 'bias_regularizer': regularizers.serialize(self.bias_regularizer), + 'activity_regularizer': regularizers.serialize(self.activity_regularizer), + 'kernel_constraint': constraints.serialize(self.kernel_constraint), + 'recurrent_constraint': constraints.serialize(self.recurrent_constraint), + 'bias_constraint': constraints.serialize(self.bias_constraint)} + base_config = super(CuDNNGRU, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + +class CuDNNLSTM(_CuDNNRNN): + """Fast LSTM implementation backed by CuDNN. + + Can only be run on GPU. + + # Arguments + units: Positive integer, dimensionality of the output space. + kernel_initializer: Initializer for the `kernel` weights matrix, + used for the linear transformation of the inputs. + (see [initializers](../initializers.md)). + unit_forget_bias: Boolean. + If True, add 1 to the bias of the forget gate at initialization. + Setting it to true will also force `bias_initializer="zeros"`. + This is recommended in [Jozefowicz et al.](http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf) + recurrent_initializer: Initializer for the `recurrent_kernel` + weights matrix, + used for the linear transformation of the recurrent state. + (see [initializers](../initializers.md)). + bias_initializer: Initializer for the bias vector + (see [initializers](../initializers.md)). + kernel_regularizer: Regularizer function applied to + the `kernel` weights matrix + (see [regularizer](../regularizers.md)). + recurrent_regularizer: Regularizer function applied to + the `recurrent_kernel` weights matrix + (see [regularizer](../regularizers.md)). + bias_regularizer: Regularizer function applied to the bias vector + (see [regularizer](../regularizers.md)). + activity_regularizer: Regularizer function applied to + the output of the layer (its "activation"). + (see [regularizer](../regularizers.md)). + kernel_constraint: Constraint function applied to + the `kernel` weights matrix + (see [constraints](../constraints.md)). + recurrent_constraint: Constraint function applied to + the `recurrent_kernel` weights matrix + (see [constraints](../constraints.md)). + bias_constraint: Constraint function applied to the bias vector + (see [constraints](../constraints.md)). + return_sequences: Boolean. Whether to return the last output. + in the output sequence, or the full sequence. + return_state: Boolean. Whether to return the last state + in addition to the output. + stateful: Boolean (default False). If True, the last state + for each sample at index i in a batch will be used as initial + state for the sample of index i in the following batch. 
+ """ + def __init__(self, units, + kernel_initializer='glorot_uniform', + recurrent_initializer='orthogonal', + bias_initializer='zeros', + unit_forget_bias=True, + kernel_regularizer=None, + recurrent_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + kernel_constraint=None, + recurrent_constraint=None, + bias_constraint=None, + return_sequences=False, + return_state=False, + stateful=False, + **kwargs): + self.units = units + super(CuDNNLSTM, self).__init__( + return_sequences=return_sequences, + return_state=return_state, + stateful=stateful, + **kwargs) + + self.kernel_initializer = initializers.get(kernel_initializer) + self.recurrent_initializer = initializers.get(recurrent_initializer) + self.bias_initializer = initializers.get(bias_initializer) + self.unit_forget_bias = unit_forget_bias + + self.kernel_regularizer = regularizers.get(kernel_regularizer) + self.recurrent_regularizer = regularizers.get(recurrent_regularizer) + self.bias_regularizer = regularizers.get(bias_regularizer) + self.activity_regularizer = regularizers.get(activity_regularizer) + + self.kernel_constraint = constraints.get(kernel_constraint) + self.recurrent_constraint = constraints.get(recurrent_constraint) + self.bias_constraint = constraints.get(bias_constraint) + + @property + def cell(self): + Cell = namedtuple('cell', 'state_size') + cell = Cell(state_size=(self.units, self.units)) + return cell + + def build(self, input_shape): + super(CuDNNLSTM, self).build(input_shape) + if isinstance(input_shape, list): + input_shape = input_shape[0] + input_dim = input_shape[-1] + + from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops + self._cudnn_lstm = cudnn_rnn_ops.CudnnLSTM( + num_layers=1, + num_units=self.units, + input_size=input_dim, + input_mode='linear_input') + + self.kernel = self.add_weight(shape=(input_dim, self.units * 4), + name='kernel', + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + constraint=self.kernel_constraint) + self.recurrent_kernel = self.add_weight( + shape=(self.units, self.units * 4), + name='recurrent_kernel', + initializer=self.recurrent_initializer, + regularizer=self.recurrent_regularizer, + constraint=self.recurrent_constraint) + + if self.unit_forget_bias: + def bias_initializer(shape, *args, **kwargs): + return K.concatenate([ + self.bias_initializer((self.units * 5,), *args, **kwargs), + initializers.Ones()((self.units,), *args, **kwargs), + self.bias_initializer((self.units * 2,), *args, **kwargs), + ]) + else: + bias_initializer = self.bias_initializer + self.bias = self.add_weight(shape=(self.units * 8,), + name='bias', + initializer=bias_initializer, + regularizer=self.bias_regularizer, + constraint=self.bias_constraint) + + self.kernel_i = self.kernel[:, :self.units] + self.kernel_f = self.kernel[:, self.units: self.units * 2] + self.kernel_c = self.kernel[:, self.units * 2: self.units * 3] + self.kernel_o = self.kernel[:, self.units * 3:] + + self.recurrent_kernel_i = self.recurrent_kernel[:, :self.units] + self.recurrent_kernel_f = self.recurrent_kernel[:, self.units: self.units * 2] + self.recurrent_kernel_c = self.recurrent_kernel[:, self.units * 2: self.units * 3] + self.recurrent_kernel_o = self.recurrent_kernel[:, self.units * 3:] + + self.bias_i_i = self.bias[:self.units] + self.bias_f_i = self.bias[self.units: self.units * 2] + self.bias_c_i = self.bias[self.units * 2: self.units * 3] + self.bias_o_i = self.bias[self.units * 3: self.units * 4] + self.bias_i = self.bias[self.units * 4: self.units 
* 5] + self.bias_f = self.bias[self.units * 5: self.units * 6] + self.bias_c = self.bias[self.units * 6: self.units * 7] + self.bias_o = self.bias[self.units * 7:] + + self.built = True + + def _process_batch(self, inputs, initial_state): + import tensorflow as tf + inputs = tf.transpose(inputs, (1, 0, 2)) + input_h = initial_state[0] + input_c = initial_state[1] + input_h = tf.expand_dims(input_h, axis=0) + input_c = tf.expand_dims(input_c, axis=0) + + params = self._canonical_to_params( + weights=[ + self.kernel_i, + self.kernel_f, + self.kernel_c, + self.kernel_o, + self.recurrent_kernel_i, + self.recurrent_kernel_f, + self.recurrent_kernel_c, + self.recurrent_kernel_o, + ], + biases=[ + self.bias_i_i, + self.bias_f_i, + self.bias_c_i, + self.bias_o_i, + self.bias_i, + self.bias_f, + self.bias_c, + self.bias_o, + ], + ) + outputs, h, c = self._cudnn_lstm( + inputs, + input_h=input_h, + input_c=input_c, + params=params, + is_training=True) + + if self.stateful or self.return_state: + h = h[0] + c = c[0] + if self.return_sequences: + output = tf.transpose(outputs, (1, 0, 2)) + else: + output = outputs[-1] + return output, [h, c] + + def get_config(self): + config = { + 'units': self.units, + 'kernel_initializer': initializers.serialize(self.kernel_initializer), + 'recurrent_initializer': initializers.serialize(self.recurrent_initializer), + 'bias_initializer': initializers.serialize(self.bias_initializer), + 'unit_forget_bias': self.unit_forget_bias, + 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), + 'recurrent_regularizer': regularizers.serialize(self.recurrent_regularizer), + 'bias_regularizer': regularizers.serialize(self.bias_regularizer), + 'activity_regularizer': regularizers.serialize(self.activity_regularizer), + 'kernel_constraint': constraints.serialize(self.kernel_constraint), + 'recurrent_constraint': constraints.serialize(self.recurrent_constraint), + 'bias_constraint': constraints.serialize(self.bias_constraint)} + base_config = super(CuDNNLSTM, self).get_config() + return dict(list(base_config.items()) + list(config.items())) diff --git a/tests/keras/layers/cudnn_recurrent_test.py b/tests/keras/layers/cudnn_recurrent_test.py new file mode 100644 index 00000000000..24ea8daef21 --- /dev/null +++ b/tests/keras/layers/cudnn_recurrent_test.py @@ -0,0 +1,353 @@ +import pytest +import numpy as np +from numpy.testing import assert_allclose +import keras +from keras.utils.test_utils import layer_test +from keras.utils.test_utils import keras_test +import time + + +@keras_test +@pytest.mark.skipif((keras.backend.backend() != 'tensorflow'), + reason='Requires TensorFlow backend') +@pytest.mark.skipif(not keras.backend.tensorflow_backend._get_available_gpus(), + reason='Requires GPU') +def test_cudnn_rnn_canonical_to_params_lstm(): + units = 1 + input_size = 1 + layer = keras.layers.CuDNNLSTM(units) + layer.build((None, None, input_size)) + + params = layer._canonical_to_params( + weights=[ + layer.kernel_i, + layer.kernel_f, + layer.kernel_c, + layer.kernel_o, + layer.recurrent_kernel_i, + layer.recurrent_kernel_f, + layer.recurrent_kernel_c, + layer.recurrent_kernel_o, + ], + biases=[ + layer.bias_i_i, + layer.bias_f_i, + layer.bias_c_i, + layer.bias_o_i, + layer.bias_i, + layer.bias_f, + layer.bias_c, + layer.bias_o, + ], + ) + ref_params = layer._cudnn_lstm.canonical_to_params( + weights=[ + layer.kernel_i, + layer.kernel_f, + layer.kernel_c, + layer.kernel_o, + layer.recurrent_kernel_i, + layer.recurrent_kernel_f, + layer.recurrent_kernel_c, + 
layer.recurrent_kernel_o, + ], + biases=[ + layer.bias_i_i, + layer.bias_f_i, + layer.bias_c_i, + layer.bias_o_i, + layer.bias_i, + layer.bias_f, + layer.bias_c, + layer.bias_o, + ], + ) + ref_params_value = keras.backend.get_value(ref_params) + params_value = keras.backend.get_value(params) + diff = np.mean(ref_params_value - params_value) + assert diff < 1e-8 + + +@keras_test +@pytest.mark.skipif((keras.backend.backend() != 'tensorflow'), + reason='Requires TensorFlow backend') +@pytest.mark.skipif(not keras.backend.tensorflow_backend._get_available_gpus(), + reason='Requires GPU') +def test_cudnn_rnn_canonical_to_params_gru(): + units = 7 + input_size = 9 + layer = keras.layers.CuDNNGRU(units) + layer.build((None, None, input_size)) + + ref_params = layer._cudnn_gru.canonical_to_params( + weights=[ + layer.kernel_r, + layer.kernel_z, + layer.kernel_h, + layer.recurrent_kernel_r, + layer.recurrent_kernel_z, + layer.recurrent_kernel_h, + ], + biases=[ + layer.bias_r_i, + layer.bias_z_i, + layer.bias_h_i, + layer.bias_r, + layer.bias_z, + layer.bias_h, + ], + ) + params = layer._canonical_to_params( + weights=[ + layer.kernel_r, + layer.kernel_z, + layer.kernel_h, + layer.recurrent_kernel_r, + layer.recurrent_kernel_z, + layer.recurrent_kernel_h, + ], + biases=[ + layer.bias_r_i, + layer.bias_z_i, + layer.bias_h_i, + layer.bias_r, + layer.bias_z, + layer.bias_h, + ], + ) + ref_params_value = keras.backend.get_value(ref_params) + params_value = keras.backend.get_value(params) + diff = np.mean(ref_params_value - params_value) + assert diff < 1e-8 + + +@keras_test +@pytest.mark.skipif((keras.backend.backend() != 'tensorflow'), + reason='Requires TensorFlow backend') +@pytest.mark.skipif(not keras.backend.tensorflow_backend._get_available_gpus(), + reason='Requires GPU') +def test_cudnn_rnn_timing(): + input_size = 1000 + timesteps = 60 + units = 256 + num_samples = 10000 + + times = [] + for rnn_type in ['lstm', 'gru']: + for use_cudnn in [True, False]: + start_time = time.time() + inputs = keras.layers.Input(shape=(None, input_size)) + if use_cudnn: + if rnn_type == 'lstm': + layer = keras.layers.CuDNNLSTM(units) + else: + layer = keras.layers.CuDNNGRU(units) + else: + if rnn_type == 'lstm': + layer = keras.layers.LSTM(units) + else: + layer = keras.layers.GRU(units) + outputs = layer(inputs) + + model = keras.models.Model(inputs, outputs) + model.compile('sgd', 'mse') + + x = np.random.random((num_samples, timesteps, input_size)) + y = np.random.random((num_samples, units)) + model.fit(x, y, epochs=4, batch_size=32) + + times.append(time.time() - start_time) + + speedup = times[1] / times[0] + print(rnn_type, 'speedup', speedup) + assert speedup > 3 + keras.backend.clear_session() + + +@keras_test +@pytest.mark.skipif((keras.backend.backend() != 'tensorflow'), + reason='Requires TensorFlow backend') +@pytest.mark.skipif(not keras.backend.tensorflow_backend._get_available_gpus(), + reason='Requires GPU') +def test_cudnn_rnn_basics(): + input_size = 10 + timesteps = 6 + units = 2 + num_samples = 32 + for layer_class in [keras.layers.CuDNNGRU, keras.layers.CuDNNLSTM]: + for return_sequences in [True, False]: + with keras.utils.CustomObjectScope( + {'keras.layers.CuDNNGRU': keras.layers.CuDNNGRU, + 'keras.layers.CuDNNLSTM': keras.layers.CuDNNLSTM}): + layer_test( + layer_class, + kwargs={'units': units, + 'return_sequences': return_sequences}, + input_shape=(num_samples, timesteps, input_size)) + + +@keras_test +@pytest.mark.skipif((keras.backend.backend() != 'tensorflow'), + reason='Requires 
TensorFlow backend') +@pytest.mark.skipif(not keras.backend.tensorflow_backend._get_available_gpus(), + reason='Requires GPU') +def test_trainability(): + input_size = 10 + units = 2 + for layer_class in [keras.layers.CuDNNGRU, keras.layers.CuDNNLSTM]: + layer = layer_class(units) + layer.build((None, None, input_size)) + assert len(layer.weights) == 3 + assert len(layer.trainable_weights) == 3 + assert len(layer.non_trainable_weights) == 0 + layer.trainable = False + assert len(layer.weights) == 3 + assert len(layer.non_trainable_weights) == 3 + assert len(layer.trainable_weights) == 0 + layer.trainable = True + assert len(layer.weights) == 3 + assert len(layer.trainable_weights) == 3 + assert len(layer.non_trainable_weights) == 0 + + +@keras_test +@pytest.mark.skipif((keras.backend.backend() != 'tensorflow'), + reason='Requires TensorFlow backend') +@pytest.mark.skipif(not keras.backend.tensorflow_backend._get_available_gpus(), + reason='Requires GPU') +def test_regularizer(): + input_size = 10 + timesteps = 6 + units = 2 + num_samples = 32 + for layer_class in [keras.layers.CuDNNGRU, keras.layers.CuDNNLSTM]: + layer = layer_class(units, return_sequences=False, + input_shape=(timesteps, input_size), + kernel_regularizer=keras.regularizers.l1(0.01), + recurrent_regularizer=keras.regularizers.l1(0.01), + bias_regularizer='l2') + layer.build((None, None, input_size)) + assert len(layer.losses) == 3 + + layer = layer_class(units, return_sequences=False, + input_shape=(timesteps, input_size), + activity_regularizer='l2') + assert layer.activity_regularizer + x = keras.backend.variable(np.ones((num_samples, + timesteps, + input_size))) + layer(x) + assert len(layer.get_losses_for(x)) == 1 + + +@keras_test +@pytest.mark.skipif((keras.backend.backend() != 'tensorflow'), + reason='Requires TensorFlow backend') +@pytest.mark.skipif(not keras.backend.tensorflow_backend._get_available_gpus(), + reason='Requires GPU') +def test_return_state(): + input_size = 10 + timesteps = 6 + units = 2 + num_samples = 32 + + for layer_class in [keras.layers.CuDNNGRU, keras.layers.CuDNNLSTM]: + num_states = 2 if layer_class is keras.layers.CuDNNLSTM else 1 + + inputs = keras.Input(batch_shape=(num_samples, timesteps, input_size)) + layer = layer_class(units, return_state=True, stateful=True) + outputs = layer(inputs) + output, state = outputs[0], outputs[1:] + assert len(state) == num_states + model = keras.models.Model(inputs, state[0]) + + inputs = np.random.random((num_samples, timesteps, input_size)) + state = model.predict(inputs) + np.testing.assert_allclose( + keras.backend.eval(layer.states[0]), state, atol=1e-4) + + +@keras_test +@pytest.mark.skipif((keras.backend.backend() != 'tensorflow'), + reason='Requires TensorFlow backend') +@pytest.mark.skipif(not keras.backend.tensorflow_backend._get_available_gpus(), + reason='Requires GPU') +def test_specify_initial_state_keras_tensor(): + input_size = 10 + timesteps = 6 + units = 2 + num_samples = 32 + for layer_class in [keras.layers.CuDNNGRU, keras.layers.CuDNNLSTM]: + num_states = 2 if layer_class is keras.layers.CuDNNLSTM else 1 + + inputs = keras.Input((timesteps, input_size)) + initial_state = [keras.Input((units,)) for _ in range(num_states)] + layer = layer_class(units) + if len(initial_state) == 1: + output = layer(inputs, initial_state=initial_state[0]) + else: + output = layer(inputs, initial_state=initial_state) + assert initial_state[0] in layer.inbound_nodes[0].input_tensors + + model = keras.models.Model([inputs] + initial_state, output) + 
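+        # The initial-state tensors become extra model inputs here; at call
+        # time they reach _CuDNNRNN.call as `inputs[1:]` and are used as the
+        # starting h state (and c state, for CuDNNLSTM).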
model.compile(loss='categorical_crossentropy', optimizer='adam') + + inputs = np.random.random((num_samples, timesteps, input_size)) + initial_state = [np.random.random((num_samples, units)) + for _ in range(num_states)] + targets = np.random.random((num_samples, units)) + model.fit([inputs] + initial_state, targets) + + +@keras_test +@pytest.mark.skipif((keras.backend.backend() != 'tensorflow'), + reason='Requires TensorFlow backend') +@pytest.mark.skipif(not keras.backend.tensorflow_backend._get_available_gpus(), + reason='Requires GPU') +def test_statefulness(): + input_size = 10 + timesteps = 6 + units = 2 + num_samples = 32 + + for layer_class in [keras.layers.CuDNNGRU, keras.layers.CuDNNLSTM]: + model = keras.models.Sequential() + model.add(keras.layers.Embedding(10, input_size, + input_length=timesteps, + batch_input_shape=(num_samples, + timesteps))) + layer = layer_class(units, + return_sequences=False, + stateful=True, + weights=None) + model.add(layer) + model.compile(optimizer='sgd', loss='mse') + out1 = model.predict(np.ones((num_samples, timesteps))) + assert(out1.shape == (num_samples, units)) + + # train once so that the states change + model.train_on_batch(np.ones((num_samples, timesteps)), + np.ones((num_samples, units))) + out2 = model.predict(np.ones((num_samples, timesteps))) + + # if the state is not reset, output should be different + assert(out1.max() != out2.max()) + + # check that output changes after states are reset + # (even though the model itself didn't change) + layer.reset_states() + out3 = model.predict(np.ones((num_samples, timesteps))) + assert(out2.max() != out3.max()) + + # check that container-level reset_states() works + model.reset_states() + out4 = model.predict(np.ones((num_samples, timesteps))) + assert_allclose(out3, out4, atol=1e-5) + + # check that the call to `predict` updated the states + out5 = model.predict(np.ones((num_samples, timesteps))) + assert(out4.max() != out5.max()) + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/tests/keras/layers/wrappers_test.py b/tests/keras/layers/wrappers_test.py index 8545ee75123..45bffc7c6e7 100644 --- a/tests/keras/layers/wrappers_test.py +++ b/tests/keras/layers/wrappers_test.py @@ -111,6 +111,8 @@ def test_TimeDistributed(): @keras_test +@pytest.mark.skipif((K.backend() == 'cntk'), + reason='Flaky with CNTK backend') def test_TimeDistributed_learning_phase(): # test layers that need learning_phase to be set np.random.seed(1234)