Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Created a function to_data_format to abstract the shape and data_format handling. #10781

Merged
merged 4 commits into from
Aug 5, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 5 additions & 11 deletions keras/backend/cntk_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from .common import epsilon
from .common import image_data_format
from .common import normalize_data_format
from ..utils.generic_utils import transpose_shape
from collections import defaultdict
from contextlib import contextmanager
import warnings
Expand Down Expand Up @@ -1684,12 +1685,8 @@ def conv3d_transpose(x, kernel, output_shape, strides=(1, 1, 1),
output_shape = output_shape[1:]
# in keras2, need handle output shape in different format
if data_format == 'channels_last':
shape = list(output_shape)
shape[0] = output_shape[3]
shape[1] = output_shape[0]
shape[2] = output_shape[1]
shape[3] = output_shape[2]
output_shape = tuple(shape)
output_shape = transpose_shape(output_shape, 'channels_first',
spatial_axes=(0, 1, 2))

x = C.convolution_transpose(
kernel,
Expand Down Expand Up @@ -2201,11 +2198,8 @@ def conv2d_transpose(x, kernel, output_shape, strides=(1, 1),
output_shape = output_shape[1:]
# in keras2, need handle output shape in different format
if data_format == 'channels_last':
shape = list(output_shape)
shape[0] = output_shape[2]
shape[1] = output_shape[0]
shape[2] = output_shape[1]
output_shape = tuple(shape)
output_shape = transpose_shape(output_shape, 'channels_first',
spatial_axes=(0, 1))

x = C.convolution_transpose(
kernel,
Expand Down
15 changes: 6 additions & 9 deletions keras/backend/tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .common import floatx
from .common import epsilon
from .common import normalize_data_format
from ..utils.generic_utils import transpose_shape
from ..utils.generic_utils import has_arg

# Legacy functions
Expand Down Expand Up @@ -2238,15 +2239,11 @@ def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
assert len(padding[1]) == 2
data_format = normalize_data_format(data_format)

if data_format == 'channels_first':
pattern = [[0, 0],
[0, 0],
list(padding[0]),
list(padding[1])]
else:
pattern = [[0, 0],
list(padding[0]), list(padding[1]),
[0, 0]]
pattern = [[0, 0],
list(padding[0]),
list(padding[1]),
[0, 0]]
pattern = transpose_shape(pattern, data_format, spatial_axes=(1, 2))
return tf.pad(x, pattern)


Expand Down
5 changes: 3 additions & 2 deletions keras/backend/theano_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .common import floatx
from .common import epsilon
from .common import normalize_data_format
from ..utils.generic_utils import transpose_shape
from ..utils.generic_utils import has_arg
# Legacy functions
from .common import set_image_dim_ordering, image_dim_ordering
Expand Down Expand Up @@ -1823,8 +1824,8 @@ def int_or_none(value):
return None
if data_format == 'channels_last':
if image_shape:
image_shape = (image_shape[0], image_shape[3],
image_shape[1], image_shape[2])
image_shape = transpose_shape(image_shape, 'channels_first',
spatial_axes=(1, 2))
if image_shape is not None:
image_shape = tuple(int_or_none(v) for v in image_shape)
return image_shape
Expand Down
17 changes: 7 additions & 10 deletions keras/layers/convolutional_recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from ..legacy.layers import Recurrent, ConvRecurrent2D
from .recurrent import RNN
from ..utils.generic_utils import has_arg
from ..utils.generic_utils import transpose_shape


class ConvRNN2D(RNN):
Expand Down Expand Up @@ -169,22 +170,18 @@ def compute_output_shape(self, input_shape):
stride=cell.strides[1],
dilation=cell.dilation_rate[1])

if cell.data_format == 'channels_first':
output_shape = input_shape[:2] + (cell.filters, rows, cols)
elif cell.data_format == 'channels_last':
output_shape = input_shape[:2] + (rows, cols, cell.filters)
output_shape = input_shape[:2] + (rows, cols, cell.filters)
output_shape = transpose_shape(output_shape, cell.data_format,
spatial_axes=(2, 3))

if not self.return_sequences:
output_shape = output_shape[:1] + output_shape[2:]

if self.return_state:
output_shape = [output_shape]
if cell.data_format == 'channels_first':
output_shape += [(input_shape[0], cell.filters, rows, cols)
for _ in range(2)]
elif cell.data_format == 'channels_last':
output_shape += [(input_shape[0], rows, cols, cell.filters)
for _ in range(2)]
base = (input_shape[0], rows, cols, cell.filters)
base = transpose_shape(base, cell.data_format, spatial_axes=(1, 2))
output_shape += [base[:] for _ in range(2)]
return output_shape

def build(self, input_shape):
Expand Down
49 changes: 49 additions & 0 deletions keras/utils/generic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,3 +532,52 @@ def slice_arrays(arrays, start=None, stop=None):
return arrays[start:stop]
else:
return [None]


def transpose_shape(shape, target_format, spatial_axes):
"""Converts a tuple or a list to the correct `data_format`.

It does so by switching the positions of its elements.

# Arguments
shape: Tuple or list, often representing shape,
corresponding to `'channels_last'`.
target_format: A string, either `'channels_first'` or `'channels_last'`.
spatial_axes: A tuple of integers.
Correspond to the indexes of the spatial axes.
For example, if you pass a shape
representing (batch_size, timesteps, rows, cols, channels),
then `spatial_axes=(2, 3)`.

# Returns
A tuple or list, with the elements permuted according
to `target_format`.

# Example
```python
>>> from keras.utils.generic_utils import transpose_shape
>>> transpose_shape((16, 128, 128, 32),'channels_first', spatial_axes=(1, 2))
(16, 32, 128, 128)
>>> transpose_shape((16, 128, 128, 32), 'channels_last', spatial_axes=(1, 2))
(16, 128, 128, 32)
>>> transpose_shape((128, 128, 32), 'channels_first', spatial_axes=(0, 1))
(32, 128, 128)
```

# Raises
ValueError: if `value` or the global `data_format` invalid.
"""
if target_format == 'channels_first':
new_values = shape[:spatial_axes[0]]
new_values += (shape[-1],)
new_values += tuple(shape[x] for x in spatial_axes)

if isinstance(shape, list):
return list(new_values)
return new_values
elif target_format == 'channels_last':
return shape
else:
raise ValueError('The `data_format` argument must be one of '
'"channels_first", "channels_last". Received: ' +
str(target_format))