Skip to content

Commit

Permalink
Add option to specify resampling method in load_img (keras-team#7975)
Browse files Browse the repository at this point in the history
* Add option to specify resampling method in load_img

* Fix docstring to state that bilinear is the default

* Specify interpolation as string rather than PIL constants

* Use single quotes for strings

* Show the supported interpolation methods if invalid method is specified

* Only support lanczos resampling if available

* Fix invalid PIL Image name

* Add unit tests for load_img function

* Fix indentation to comply with PEP8

* Disable explicit nearest neighbor test which is platform dependent
  • Loading branch information
ahojnnes authored and fchollet committed Oct 8, 2017
1 parent ff9118f commit c034825
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 5 deletions.
41 changes: 36 additions & 5 deletions keras/preprocessing/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,22 @@
pil_image = None


if pil_image is not None:
_PIL_INTERPOLATION_METHODS = {
'nearest': pil_image.NEAREST,
'bilinear': pil_image.BILINEAR,
'bicubic': pil_image.BICUBIC,
}
# These methods were only introduced in version 3.4.0 (2016).
if hasattr(pil_image, 'HAMMING'):
_PIL_INTERPOLATION_METHODS['hamming'] = pil_image.HAMMING
if hasattr(pil_image, 'BOX'):
_PIL_INTERPOLATION_METHODS['box'] = pil_image.BOX
# This method is new in version 1.1.3 (2013).
if hasattr(pil_image, 'LANCZOS'):
_PIL_INTERPOLATION_METHODS['lanczos'] = pil_image.LANCZOS


def random_rotation(x, rg, row_axis=1, col_axis=2, channel_axis=0,
fill_mode='nearest', cval=0.):
"""Performs a random rotation of a Numpy image tensor.
Expand Down Expand Up @@ -302,20 +318,28 @@ def img_to_array(img, data_format=None):
return x


def load_img(path, grayscale=False, target_size=None):
def load_img(path, grayscale=False, target_size=None,
interpolation='bilinear'):
"""Loads an image into PIL format.
# Arguments
path: Path to image file
grayscale: Boolean, whether to load the image as grayscale.
target_size: Either `None` (default to original size)
or tuple of ints `(img_height, img_width)`.
interpolation: Interpolation method used to resample the image if the
target size is different from that of the loaded image.
Supported methods are "nearest", "bilinear", and "bicubic".
If PIL version 1.1.3 or newer is installed, "lanczos" is also
supported. If PIL version 3.4.0 or newer is installed, "box" and
"hamming" are also supported. By default, "bilinear" is used.
# Returns
A PIL Image instance.
# Raises
ImportError: if PIL is not available.
ValueError: if interpolation method is not supported.
"""
if pil_image is None:
raise ImportError('Could not import PIL.Image. '
Expand All @@ -327,10 +351,17 @@ def load_img(path, grayscale=False, target_size=None):
else:
if img.mode != 'RGB':
img = img.convert('RGB')
if target_size:
hw_tuple = (target_size[1], target_size[0])
if img.size != hw_tuple:
img = img.resize(hw_tuple)
if target_size is not None:
width_height_tuple = (target_size[1], target_size[0])
if img.size != width_height_tuple:
if interpolation not in _PIL_INTERPOLATION_METHODS:
raise ValueError(
'Invalid interpolation method {} specified. Supported '
'methods are {}'.format(
interpolation,
", ".join(_PIL_INTERPOLATION_METHODS.keys())))
resample = _PIL_INTERPOLATION_METHODS[interpolation]
img = img.resize(width_height_tuple, resample)
return img


Expand Down
59 changes: 59 additions & 0 deletions tests/keras/preprocessing/image_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,65 @@ def test_batch_standardize(self):
transformed[i] = generator.random_transform(im)
transformed = generator.standardize(transformed)

def test_load_img(self, tmpdir):
filename = str(tmpdir / 'image.png')

original_im_array = np.array(255 * np.random.rand(100, 100, 3),
dtype=np.uint8)
original_im = image.array_to_img(original_im_array, scale=False)
original_im.save(filename)

# Test that loaded image is exactly equal to original.

loaded_im = image.load_img(filename)
loaded_im_array = image.img_to_array(loaded_im)
assert loaded_im_array.shape == original_im_array.shape
assert np.all(loaded_im_array == original_im_array)

loaded_im = image.load_img(filename, grayscale=True)
loaded_im_array = image.img_to_array(loaded_im)
assert loaded_im_array.shape == (original_im_array.shape[0],
original_im_array.shape[1], 1)

# Test that nothing is changed when target size is equal to original.

loaded_im = image.load_img(filename, target_size=(100, 100))
loaded_im_array = image.img_to_array(loaded_im)
assert loaded_im_array.shape == original_im_array.shape
assert np.all(loaded_im_array == original_im_array)

loaded_im = image.load_img(filename, grayscale=True,
target_size=(100, 100))
loaded_im_array = image.img_to_array(loaded_im)
assert loaded_im_array.shape == (original_im_array.shape[0],
original_im_array.shape[1], 1)

# Test down-sampling with bilinear interpolation.

loaded_im = image.load_img(filename, target_size=(25, 25))
loaded_im_array = image.img_to_array(loaded_im)
assert loaded_im_array.shape == (25, 25, 3)

loaded_im = image.load_img(filename, grayscale=True,
target_size=(25, 25))
loaded_im_array = image.img_to_array(loaded_im)
assert loaded_im_array.shape == (25, 25, 1)

# Test down-sampling with nearest neighbor interpolation.

loaded_im_nearest = image.load_img(filename, target_size=(25, 25),
interpolation="nearest")
loaded_im_array_nearest = image.img_to_array(loaded_im_nearest)
assert loaded_im_array_nearest.shape == (25, 25, 3)
assert np.any(loaded_im_array_nearest != loaded_im_array)

# Check that exception is raised if interpolation not supported.

loaded_im = image.load_img(filename, interpolation="unsupported")
with pytest.raises(ValueError):
loaded_im = image.load_img(filename, target_size=(25, 25),
interpolation="unsupported")


if __name__ == '__main__':
pytest.main([__file__])

0 comments on commit c034825

Please sign in to comment.