Skip to content

Commit

Permalink
Allow to reshape unknown tensors. (#7960)
Browse files Browse the repository at this point in the history
* Allow to reshape unknown tensors.

* Add test for reshaping unknown shapes.
  • Loading branch information
hgaiser authored and fchollet committed Sep 27, 2017
1 parent fb4a084 commit 408de9b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 16 deletions.
25 changes: 9 additions & 16 deletions keras/layers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,24 +372,17 @@ def _fix_unknown_dimension(self, input_shape, output_shape):
return tuple(output_shape)

def compute_output_shape(self, input_shape):
return (input_shape[0],) + self._fix_unknown_dimension(
input_shape[1:], self.target_shape)
if None in input_shape[1:]:
# input shape (partially) unknown? replace -1's with None's
return ((input_shape[0],) +
tuple(s if s != -1 else None for s in self.target_shape))
else:
# input shape known? then we can compute the output shape
return (input_shape[0],) + self._fix_unknown_dimension(
input_shape[1:], self.target_shape)

def call(self, inputs):
# In case the target shape is not fully defined,
# we need access to the shape of `inputs`.
# solution: rely on `K.int_shape`.
target_shape = self.target_shape
if -1 in target_shape:
# Target shape not fully defined.
input_shape = None
try:
input_shape = K.int_shape(inputs)
except TypeError:
pass
if input_shape is not None:
target_shape = self.compute_output_shape(input_shape)[1:]
return K.reshape(inputs, (-1,) + target_shape)
return K.reshape(inputs, (K.shape(inputs)[0],) + self.target_shape)

def get_config(self):
config = {'target_shape': self.target_shape}
Expand Down
4 changes: 4 additions & 0 deletions tests/keras/layers/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ def test_reshape():
kwargs={'target_shape': (1, -1)},
input_shape=(3, 2, 4))

layer_test(layers.Reshape,
kwargs={'target_shape': (-1, 1)},
input_shape=(None, None, 4))


@keras_test
def test_permute():
Expand Down

0 comments on commit 408de9b

Please sign in to comment.