diff --git a/keras/layers/core.py b/keras/layers/core.py index 547032e36087..3bbf726b79d9 100644 --- a/keras/layers/core.py +++ b/keras/layers/core.py @@ -372,24 +372,17 @@ def _fix_unknown_dimension(self, input_shape, output_shape): return tuple(output_shape) def compute_output_shape(self, input_shape): - return (input_shape[0],) + self._fix_unknown_dimension( - input_shape[1:], self.target_shape) + if None in input_shape[1:]: + # input shape (partially) unknown? replace -1's with None's + return ((input_shape[0],) + + tuple(s if s != -1 else None for s in self.target_shape)) + else: + # input shape known? then we can compute the output shape + return (input_shape[0],) + self._fix_unknown_dimension( + input_shape[1:], self.target_shape) def call(self, inputs): - # In case the target shape is not fully defined, - # we need access to the shape of `inputs`. - # solution: rely on `K.int_shape`. - target_shape = self.target_shape - if -1 in target_shape: - # Target shape not fully defined. - input_shape = None - try: - input_shape = K.int_shape(inputs) - except TypeError: - pass - if input_shape is not None: - target_shape = self.compute_output_shape(input_shape)[1:] - return K.reshape(inputs, (-1,) + target_shape) + return K.reshape(inputs, (K.shape(inputs)[0],) + self.target_shape) def get_config(self): config = {'target_shape': self.target_shape} diff --git a/tests/keras/layers/core_test.py b/tests/keras/layers/core_test.py index b1283edeabed..2e4de733d39e 100644 --- a/tests/keras/layers/core_test.py +++ b/tests/keras/layers/core_test.py @@ -79,6 +79,10 @@ def test_reshape(): kwargs={'target_shape': (1, -1)}, input_shape=(3, 2, 4)) + layer_test(layers.Reshape, + kwargs={'target_shape': (-1, 1)}, + input_shape=(None, None, 4)) + @keras_test def test_permute():