Skip to content

Commit

Permalink
Unify preprocess_input implementation in applications
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Sep 9, 2017
1 parent 0f1c855 commit 0b3dc16
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 25 deletions.
15 changes: 14 additions & 1 deletion keras/applications/imagenet_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,19 @@
CLASS_INDEX_PATH = 'https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json'


def preprocess_input(x, data_format=None):
def preprocess_input(x, data_format=None, mode='caffe'):
"""Preprocesses a tensor encoding a batch of images.
# Arguments
x: input Numpy tensor, 4D.
data_format: data format of the image tensor.
mode: One of "caffe", "tf".
- caffe: will convert the images from RGB to BGR,
then will zero-center each color channel with
respect to the ImageNet dataset,
without scaling.
- tf: will scale pixels between -1 and 1,
sample-wise.
# Returns
Preprocessed tensor.
Expand All @@ -22,6 +29,12 @@ def preprocess_input(x, data_format=None):
data_format = K.image_data_format()
assert data_format in {'channels_last', 'channels_first'}

if mode == 'tf':
x /= 255.
x -= 0.5
x *= 2.
return x

if data_format == 'channels_first':
if x.ndim == 3:
# 'RGB'->'BGR'
Expand Down
14 changes: 4 additions & 10 deletions keras/applications/inception_resnet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@
from ..layers import MaxPooling2D
from ..utils.data_utils import get_file
from ..engine.topology import get_source_inputs
from ..applications.imagenet_utils import _obtain_input_shape
from ..applications.imagenet_utils import decode_predictions
from . import imagenet_utils
from .imagenet_utils import _obtain_input_shape
from .imagenet_utils import decode_predictions
from .. import backend as K


Expand All @@ -43,20 +44,13 @@
def preprocess_input(x):
"""Preprocesses a numpy array encoding a batch of images.
This function applies the "Inception" preprocessing which converts
the RGB values from [0, 255] to [-1, 1]. Note that this preprocessing
function is different from `imagenet_utils.preprocess_input()`.
# Arguments
x: a 4D numpy array consists of RGB values within [0, 255].
# Returns
Preprocessed array.
"""
x /= 255.
x -= 0.5
x *= 2.
return x
return imagenet_utils.preprocess_input(x, mode='tf')


def conv2d_bn(x,
Expand Down
14 changes: 10 additions & 4 deletions keras/applications/inception_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from ..engine.topology import get_source_inputs
from ..utils.data_utils import get_file
from .. import backend as K
from . import imagenet_utils
from .imagenet_utils import decode_predictions
from .imagenet_utils import _obtain_input_shape

Expand Down Expand Up @@ -388,7 +389,12 @@ def InceptionV3(include_top=True,


def preprocess_input(x):
x /= 255.
x -= 0.5
x *= 2.
return x
"""Preprocesses a numpy array encoding a batch of images.
# Arguments
x: a 4D numpy array consists of RGB values within [0, 255].
# Returns
Preprocessed array.
"""
return imagenet_utils.preprocess_input(x, mode='tf')
18 changes: 12 additions & 6 deletions keras/applications/mobilenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,9 @@
from ..utils.data_utils import get_file
from ..engine.topology import get_source_inputs
from ..engine import InputSpec
from ..applications.imagenet_utils import _obtain_input_shape
from ..applications.imagenet_utils import decode_predictions
from . import imagenet_utils
from .imagenet_utils import _obtain_input_shape
from .imagenet_utils import decode_predictions
from .. import backend as K


Expand All @@ -84,10 +85,15 @@ def relu6(x):


def preprocess_input(x):
x /= 255.
x -= 0.5
x *= 2.
return x
"""Preprocesses a numpy array encoding a batch of images.
# Arguments
x: a 4D numpy array consists of RGB values within [0, 255].
# Returns
Preprocessed array.
"""
return imagenet_utils.preprocess_input(x, mode='tf')


class DepthwiseConv2D(Conv2D):
Expand Down
14 changes: 10 additions & 4 deletions keras/applications/xception.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from ..engine.topology import get_source_inputs
from ..utils.data_utils import get_file
from .. import backend as K
from . import imagenet_utils
from .imagenet_utils import decode_predictions
from .imagenet_utils import _obtain_input_shape

Expand Down Expand Up @@ -261,7 +262,12 @@ def Xception(include_top=True, weights='imagenet',


def preprocess_input(x):
x /= 255.
x -= 0.5
x *= 2.
return x
"""Preprocesses a numpy array encoding a batch of images.
# Arguments
x: a 4D numpy array consists of RGB values within [0, 255].
# Returns
Preprocessed array.
"""
return imagenet_utils.preprocess_input(x, mode='tf')

0 comments on commit 0b3dc16

Please sign in to comment.