Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to specify resampling method in load_img #7975

Merged
merged 10 commits into from
Oct 8, 2017
41 changes: 36 additions & 5 deletions keras/preprocessing/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,22 @@
pil_image = None


if pil_image is not None:
_PIL_INTERPOLATION_METHODS = {
'nearest': pil_image.NEAREST,
'bilinear': pil_image.BILINEAR,
'bicubic': pil_image.BICUBIC,
}
# These methods were only introduced in version 3.4.0 (2016).
if hasattr(pil_image, 'HAMMING'):
_PIL_INTERPOLATION_METHODS['hamming'] = pil_image.HAMMING
if hasattr(pil_image, 'BOX'):
_PIL_INTERPOLATION_METHODS['box'] = pil_image.BOX
# This method is new in version 1.1.3 (2013).
if hasattr(pil_image, 'LANCZOS'):
_PIL_INTERPOLATION_METHODS['lanczos'] = pil_image.LANCZOS


def random_rotation(x, rg, row_axis=1, col_axis=2, channel_axis=0,
fill_mode='nearest', cval=0.):
"""Performs a random rotation of a Numpy image tensor.
Expand Down Expand Up @@ -302,20 +318,28 @@ def img_to_array(img, data_format=None):
return x


def load_img(path, grayscale=False, target_size=None):
def load_img(path, grayscale=False, target_size=None,
interpolation='bilinear'):
"""Loads an image into PIL format.

# Arguments
path: Path to image file
grayscale: Boolean, whether to load the image as grayscale.
target_size: Either `None` (default to original size)
or tuple of ints `(img_height, img_width)`.
interpolation: Interpolation method used to resample the image if the
target size is different from that of the loaded image.
Supported methods are "nearest", "bilinear", and "bicubic".
If PIL version 1.1.3 or newer is installed, "lanczos" is also
supported. If PIL version 3.4.0 or newer is installed, "box" and
"hamming" are also supported. By default, "bilinear" is used.

# Returns
A PIL Image instance.

# Raises
ImportError: if PIL is not available.
ValueError: if interpolation method is not supported.
"""
if pil_image is None:
raise ImportError('Could not import PIL.Image. '
Expand All @@ -327,10 +351,17 @@ def load_img(path, grayscale=False, target_size=None):
else:
if img.mode != 'RGB':
img = img.convert('RGB')
if target_size:
hw_tuple = (target_size[1], target_size[0])
if img.size != hw_tuple:
img = img.resize(hw_tuple)
if target_size is not None:
width_height_tuple = (target_size[1], target_size[0])
if img.size != width_height_tuple:
if interpolation not in _PIL_INTERPOLATION_METHODS:
raise ValueError(
'Invalid interpolation method {} specified. Supported '
'methods are {}'.format(
interpolation,
", ".join(_PIL_INTERPOLATION_METHODS.keys())))
resample = _PIL_INTERPOLATION_METHODS[interpolation]
img = img.resize(width_height_tuple, resample)
return img


Expand Down
59 changes: 59 additions & 0 deletions tests/keras/preprocessing/image_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,65 @@ def test_batch_standardize(self):
transformed[i] = generator.random_transform(im)
transformed = generator.standardize(transformed)

def test_load_img(self, tmpdir):
filename = str(tmpdir / 'image.png')

original_im_array = np.array(255 * np.random.rand(100, 100, 3),
dtype=np.uint8)
original_im = image.array_to_img(original_im_array, scale=False)
original_im.save(filename)

# Test that loaded image is exactly equal to original.

loaded_im = image.load_img(filename)
loaded_im_array = image.img_to_array(loaded_im)
assert loaded_im_array.shape == original_im_array.shape
assert np.all(loaded_im_array == original_im_array)

loaded_im = image.load_img(filename, grayscale=True)
loaded_im_array = image.img_to_array(loaded_im)
assert loaded_im_array.shape == (original_im_array.shape[0],
original_im_array.shape[1], 1)

# Test that nothing is changed when target size is equal to original.

loaded_im = image.load_img(filename, target_size=(100, 100))
loaded_im_array = image.img_to_array(loaded_im)
assert loaded_im_array.shape == original_im_array.shape
assert np.all(loaded_im_array == original_im_array)

loaded_im = image.load_img(filename, grayscale=True,
target_size=(100, 100))
loaded_im_array = image.img_to_array(loaded_im)
assert loaded_im_array.shape == (original_im_array.shape[0],
original_im_array.shape[1], 1)

# Test down-sampling with bilinear interpolation.

loaded_im = image.load_img(filename, target_size=(25, 25))
loaded_im_array = image.img_to_array(loaded_im)
assert loaded_im_array.shape == (25, 25, 3)

loaded_im = image.load_img(filename, grayscale=True,
target_size=(25, 25))
loaded_im_array = image.img_to_array(loaded_im)
assert loaded_im_array.shape == (25, 25, 1)

# Test down-sampling with nearest neighbor interpolation.

loaded_im_nearest = image.load_img(filename, target_size=(25, 25),
interpolation="nearest")
loaded_im_array_nearest = image.img_to_array(loaded_im_nearest)
assert loaded_im_array_nearest.shape == (25, 25, 3)
assert np.any(loaded_im_array_nearest != loaded_im_array)

# Check that exception is raised if interpolation not supported.

loaded_im = image.load_img(filename, interpolation="unsupported")
with pytest.raises(ValueError):
loaded_im = image.load_img(filename, target_size=(25, 25),
interpolation="unsupported")


if __name__ == '__main__':
pytest.main([__file__])