diff --git a/keras/applications/imagenet_utils.py b/keras/applications/imagenet_utils.py index 09ffee3ff74..7c131667951 100644 --- a/keras/applications/imagenet_utils.py +++ b/keras/applications/imagenet_utils.py @@ -8,12 +8,19 @@ CLASS_INDEX_PATH = 'https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json' -def preprocess_input(x, data_format=None): +def preprocess_input(x, data_format=None, mode='caffe'): """Preprocesses a tensor encoding a batch of images. # Arguments x: input Numpy tensor, 4D. data_format: data format of the image tensor. + mode: One of "caffe", "tf". + - caffe: will convert the images from RGB to BGR, + then will zero-center each color channel with + respect to the ImageNet dataset, + without scaling. + - tf: will scale pixels between -1 and 1, + sample-wise. # Returns Preprocessed tensor. @@ -22,6 +29,12 @@ def preprocess_input(x, data_format=None): data_format = K.image_data_format() assert data_format in {'channels_last', 'channels_first'} + if mode == 'tf': + x /= 255. + x -= 0.5 + x *= 2. + return x + if data_format == 'channels_first': if x.ndim == 3: # 'RGB'->'BGR' diff --git a/keras/applications/inception_resnet_v2.py b/keras/applications/inception_resnet_v2.py index 699486fedc6..dd5a1284479 100644 --- a/keras/applications/inception_resnet_v2.py +++ b/keras/applications/inception_resnet_v2.py @@ -32,8 +32,9 @@ from ..layers import MaxPooling2D from ..utils.data_utils import get_file from ..engine.topology import get_source_inputs -from ..applications.imagenet_utils import _obtain_input_shape -from ..applications.imagenet_utils import decode_predictions +from . import imagenet_utils +from .imagenet_utils import _obtain_input_shape +from .imagenet_utils import decode_predictions from .. import backend as K @@ -43,20 +44,13 @@ def preprocess_input(x): """Preprocesses a numpy array encoding a batch of images. - This function applies the "Inception" preprocessing which converts - the RGB values from [0, 255] to [-1, 1]. Note that this preprocessing - function is different from `imagenet_utils.preprocess_input()`. - # Arguments x: a 4D numpy array consists of RGB values within [0, 255]. # Returns Preprocessed array. """ - x /= 255. - x -= 0.5 - x *= 2. - return x + return imagenet_utils.preprocess_input(x, mode='tf') def conv2d_bn(x, diff --git a/keras/applications/inception_v3.py b/keras/applications/inception_v3.py index a7fc259fb78..4625722afb1 100644 --- a/keras/applications/inception_v3.py +++ b/keras/applications/inception_v3.py @@ -29,6 +29,7 @@ from ..engine.topology import get_source_inputs from ..utils.data_utils import get_file from .. import backend as K +from . import imagenet_utils from .imagenet_utils import decode_predictions from .imagenet_utils import _obtain_input_shape @@ -388,7 +389,12 @@ def InceptionV3(include_top=True, def preprocess_input(x): - x /= 255. - x -= 0.5 - x *= 2. - return x + """Preprocesses a numpy array encoding a batch of images. + + # Arguments + x: a 4D numpy array consists of RGB values within [0, 255]. + + # Returns + Preprocessed array. + """ + return imagenet_utils.preprocess_input(x, mode='tf') diff --git a/keras/applications/mobilenet.py b/keras/applications/mobilenet.py index 3342625e526..b309a42d1dd 100644 --- a/keras/applications/mobilenet.py +++ b/keras/applications/mobilenet.py @@ -71,8 +71,9 @@ from ..utils.data_utils import get_file from ..engine.topology import get_source_inputs from ..engine import InputSpec -from ..applications.imagenet_utils import _obtain_input_shape -from ..applications.imagenet_utils import decode_predictions +from . import imagenet_utils +from .imagenet_utils import _obtain_input_shape +from .imagenet_utils import decode_predictions from .. import backend as K @@ -84,10 +85,15 @@ def relu6(x): def preprocess_input(x): - x /= 255. - x -= 0.5 - x *= 2. - return x + """Preprocesses a numpy array encoding a batch of images. + + # Arguments + x: a 4D numpy array consists of RGB values within [0, 255]. + + # Returns + Preprocessed array. + """ + return imagenet_utils.preprocess_input(x, mode='tf') class DepthwiseConv2D(Conv2D): diff --git a/keras/applications/xception.py b/keras/applications/xception.py index 2982ef2c423..c559776afc7 100644 --- a/keras/applications/xception.py +++ b/keras/applications/xception.py @@ -36,6 +36,7 @@ from ..engine.topology import get_source_inputs from ..utils.data_utils import get_file from .. import backend as K +from . import imagenet_utils from .imagenet_utils import decode_predictions from .imagenet_utils import _obtain_input_shape @@ -261,7 +262,12 @@ def Xception(include_top=True, weights='imagenet', def preprocess_input(x): - x /= 255. - x -= 0.5 - x *= 2. - return x + """Preprocesses a numpy array encoding a batch of images. + + # Arguments + x: a 4D numpy array consists of RGB values within [0, 255]. + + # Returns + Preprocessed array. + """ + return imagenet_utils.preprocess_input(x, mode='tf')