Skip to content

Commit

Permalink
Unbroadcast batch size dimension in Theano backend (keras-team#7957)
Browse files Browse the repository at this point in the history
* Unbroadcast batch size dimension in Theano backend

Unbroadcasting the batch size dimension is important when the batch size is 1, because Theano tends to mark shape-1 dimensions of the initial states as broadcastable.

* Add an RNN unit test that uses a batch size of 1
  • Loading branch information
MPiecuch authored and fchollet committed Sep 22, 2017
1 parent 2dfba02 commit 710898f
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
2 changes: 1 addition & 1 deletion keras/backend/theano_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1412,7 +1412,7 @@ def _step(inputs, *states):
# Theano likes to make shape==1 dimensions
# in the initial states (outputs_info) broadcastable
if len(initial_states) > 0:
initial_states[0] = T.unbroadcast(initial_states[0], 1)
initial_states[0] = T.unbroadcast(initial_states[0], 0, 1)

results, _ = theano.scan(
_step,
Expand Down
12 changes: 12 additions & 0 deletions tests/keras/layers/recurrent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,5 +552,17 @@ def test_stacked_rnn_attributes():
assert layer.get_losses_for(x) == [y]


@rnn_test
def test_batch_size_equal_one(layer_class):
    """Train a recurrent layer for one step with a fixed batch size of 1.

    The model is built from an input whose batch dimension is pinned to 1,
    then a single `train_on_batch` call is made on random data; the test
    passes as long as building, compiling and training raise no error.
    """
    # Same shape is used for the symbolic input and for the sample batch.
    batch_shape = (1, timesteps, embedding_dim)

    inputs = Input(batch_shape=batch_shape)
    outputs = layer_class(units)(inputs)

    model = Model(inputs, outputs)
    model.compile('sgd', 'mse')

    x = np.random.random(batch_shape)
    y = np.random.random((1, units))
    model.train_on_batch(x, y)


if __name__ == '__main__':
    # Executed as a script: hand this file to pytest so only its tests run.
    pytest.main([__file__])

0 comments on commit 710898f

Please sign in to comment.