Skip to content

Commit

Permalink
Add decode_image Op (tensorflow#4222)
Browse files Browse the repository at this point in the history
* Add decode_image Op

* Update decode_image to use tf.substr
* Update JPEG magic numbers
* Add GIF functionality
* Add GIF tests, documentation
* Rearrange imports to be alphabetical

* Adjust decode_image, tests

* Tests now use files in tensorflow/core/lib
* decode_image now uses `cond` operations instead of a `case`
* There is now only one substr Op for the header of the image contents

It's not clean yet (we're still getting errors), but I want to push this
before I start making more changes.

* Move from image_ops.py => image_ops_impl.py

* Add gen_image_ops import

* Clean implementation, add rgb_to_grayscale for gif

* Prevent GIF channels being set to 1, add test

* Add test for invalid channel value

* Update args/returns documentation

* Make changes to try and fix Windows build
  • Loading branch information
samjabrahams authored and girving committed Dec 1, 2016
1 parent fb04d0f commit db49056
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 1 deletion.
2 changes: 1 addition & 1 deletion tensorflow/contrib/cmake/tf_tests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ function(AddPythonTests)
endif(_AT_DEPENDS)

foreach(sourcefile ${_AT_SOURCES})
add_test(NAME ${sourcefile} COMMAND ${PYTHON_EXECUTABLE} ${sourcefile})
add_test(NAME ${sourcefile} COMMAND ${PYTHON_EXECUTABLE} ${sourcefile} WORKING_DIRECTORY ${tensorflow_source_dir})
if (_AT_DEPENDS)
add_dependencies(${_AT_TARGET} ${_AT_DEPENDS})
endif()
Expand Down
8 changes: 8 additions & 0 deletions tensorflow/python/kernel_tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,14 @@ tf_py_test(
additional_deps = ["//tensorflow:tensorflow_py"],
)

# Python test target for decode_image_op_test.py.
# NOTE(review): the `data` dependency is presumably what provides the GIF/JPEG/
# PNG fixtures the test reads from tensorflow/core/lib/*/testdata -- confirm
# against the definition of //tensorflow/core:image_testdata.
tf_py_test(
    name = "decode_image_op_test",
    size = "small",
    srcs = ["decode_image_op_test.py"],
    additional_deps = ["//tensorflow:tensorflow_py"],
    data = ["//tensorflow/core:image_testdata"],
)

tf_py_test(
name = "decode_raw_op_test",
size = "small",
Expand Down
106 changes: 106 additions & 0 deletions tensorflow/python/kernel_tests/decode_image_op_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for decode_image."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path
import numpy as np
import tensorflow as tf


class DecodeImageOpTest(tf.test.TestCase):
  """Tests that `tf.image.decode_image` matches the format-specific decoders.

  All fixture paths are relative, so the tests expect to be run with the
  repository root as the working directory.
  """

  def testGif(self):
    """Decodes a GIF and compares it with `decode_gif` and expected frames."""
    # Read some real GIFs
    path = os.path.join('tensorflow', 'core', 'lib', 'gif', 'testdata',
                        'scan.gif')
    WIDTH = 20
    HEIGHT = 40
    STRIDE = 5
    shape = (12, HEIGHT, WIDTH, 3)

    with self.test_session(use_gpu=True) as sess:
      gif0 = tf.read_file(path)
      image0 = tf.image.decode_image(gif0)
      image1 = tf.image.decode_gif(gif0)
      # After this call the three names hold evaluated values, not tensors.
      gif0, image0, image1 = sess.run([gif0, image0, image1])

      self.assertEqual(image0.shape, shape)
      self.assertAllEqual(image0, image1)

      for frame_idx, frame in enumerate(image0):
        # Expected frame: a STRIDE-wide white band on a black background.
        gt = np.zeros(shape[1:], dtype=np.uint8)
        start = frame_idx * STRIDE
        end = (frame_idx + 1) * STRIDE
        if end <= WIDTH:
          # Early frames: the band covers columns [start, end).
          gt[:, start:end, :] = 255
        else:
          # Later frames: the band covers rows instead, offset by WIDTH.
          start -= WIDTH
          end -= WIDTH
          gt[start:end, :, :] = 255

        self.assertAllClose(frame, gt)

      # Requesting a single channel for GIF data must fail at run time. The
      # check does not depend on the frame, so it runs once; it stays inside
      # the session block because `eval()` needs a default session.
      bad_channels = tf.image.decode_image(gif0, channels=1)
      with self.assertRaises(tf.errors.InvalidArgumentError):
        bad_channels.eval()

  def testJpeg(self):
    """Decodes a JPEG and compares it with `decode_jpeg`."""
    # Read a real jpeg and verify shape
    path = os.path.join('tensorflow', 'core', 'lib', 'jpeg', 'testdata',
                        'jpeg_merge_test1.jpg')
    with self.test_session(use_gpu=True) as sess:
      jpeg0 = tf.read_file(path)
      image0 = tf.image.decode_image(jpeg0)
      image1 = tf.image.decode_jpeg(jpeg0)
      jpeg0, image0, image1 = sess.run([jpeg0, image0, image1])
      self.assertEqual(len(jpeg0), 3771)
      self.assertEqual(image0.shape, (256, 128, 3))
      self.assertAllEqual(image0, image1)

  def testPng(self):
    """Decodes a PNG with several `channels` values, matching `decode_png`."""
    # Read some real PNGs, converting to different channel numbers
    prefix = ['tensorflow', 'core', 'lib', 'png', 'testdata']
    # Each entry is (number of channels stored in the file, file name).
    inputs = [(1, 'lena_gray.png')]
    for channels_in, filename in inputs:
      for channels in 0, 1, 3:
        with self.test_session(use_gpu=True) as sess:
          path = prefix + [filename]
          png0 = tf.read_file(os.path.join(*path))
          image0 = tf.image.decode_image(png0, channels=channels)
          image1 = tf.image.decode_png(png0, channels=channels)
          png0, image0, image1 = sess.run([png0, image0, image1])
          # channels == 0 means "use the file's own channel count".
          self.assertEqual(image0.shape, (26, 51, channels or channels_in))
          self.assertAllEqual(image0, image1)

  def testInvalidBytes(self):
    """Bytes that match no known image header fail when the op is run."""
    image_bytes = b'ThisIsNotAnImage!'
    decode = tf.image.decode_image(image_bytes)
    with self.test_session():
      with self.assertRaises(tf.errors.InvalidArgumentError):
        decode.eval()

  def testInvalidChannels(self):
    """An unsupported `channels` value is rejected at graph-construction time."""
    image_bytes = b'unused'
    with self.assertRaises(ValueError):
      tf.image.decode_image(image_bytes, channels=4)


# Run the tests in this file when it is executed directly as a script.
if __name__ == "__main__":
  tf.test.main()
2 changes: 2 additions & 0 deletions tensorflow/python/ops/image_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
@@decode_png
@@encode_png
@@decode_image
## Resizing
The resizing Ops accept input images as tensors of several types. They always
Expand Down
56 changes: 56 additions & 0 deletions tensorflow/python/ops/image_ops_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_image_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
Expand Down Expand Up @@ -1180,3 +1181,58 @@ def adjust_saturation(image, saturation_factor, name=None):
rgb_altered = gen_image_ops.hsv_to_rgb(hsv_altered)

return convert_image_dtype(rgb_altered, orig_dtype)


def decode_image(contents, channels=None, name=None):
  """Convenience function for `decode_gif`, `decode_jpeg`, and `decode_png`.

  Detects whether an image is a GIF, JPEG, or PNG by inspecting its leading
  bytes, and performs the appropriate operation to convert the input bytes
  `string` into a `Tensor` of type `uint8`.

  Note: `decode_gif` returns a 4-D array `[num_frames, height, width, 3]`, as
  opposed to `decode_jpeg` and `decode_png`, which return 3-D arrays
  `[height, width, num_channels]`. Make sure to take this into account when
  constructing your graph if you are intermixing GIF files with JPEG and/or PNG
  files.

  Args:
    contents: 0-D `string`. The encoded image bytes.
    channels: An optional `int`. Defaults to `0`. Number of color channels for
      the decoded image. Must be one of `None`, `0`, `1` or `3`; `1` is
      rejected when the op is run on GIF data.
    name: A name for the operation (optional)

  Returns:
    `Tensor` with type `uint8` with shape `[height, width, num_channels]` for
    JPEG and PNG images and shape `[num_frames, height, width, 3]` for GIF
    images.

  Raises:
    ValueError: If `channels` is not one of `None`, `0`, `1` or `3`. This is
      raised while the graph is being built. Unrecognized image bytes, and
      `channels=1` with GIF data, are instead reported by assert ops when the
      returned tensor is evaluated.
  """
  with ops.name_scope(name, 'decode_image'):
    if channels not in (None, 0, 1, 3):
      raise ValueError('channels must be in (None, 0, 1, 3)')
    # One 4-byte read of the header feeds the PNG and GIF signature checks.
    substr = string_ops.substr(contents, 0, 4)
    # Every JPEG stream starts with the SOI marker (0xFFD8) followed by the
    # 0xFF that opens the next marker. The fourth byte differs between
    # variants (0xE0 for JFIF, 0xE1 for Exif, ...), so only the first three
    # bytes are compared. Slicing the header that was already read keeps the
    # behaviour for short inputs the same as before.
    jpeg_prefix = string_ops.substr(substr, 0, 3)

    def _gif():
      # Last branch of the dispatch: anything that is not a GIF here matched
      # none of the supported formats, so fail with a clear message.
      is_gif = math_ops.equal(substr, b'\x47\x49\x46\x38', name='is_gif')
      decode_msg = 'Unable to decode bytes as JPEG, PNG, or GIF'
      assert_decode = control_flow_ops.Assert(is_gif, [decode_msg])
      # Create assert to make sure that channels is not set to 1
      # Already checked above that channels is in (None, 0, 1, 3)
      gif_channels = 0 if channels is None else channels
      good_channels = math_ops.not_equal(gif_channels, 1, name='check_channels')
      channels_msg = 'Channels must be in (None, 0, 3) when decoding GIF images'
      assert_channels = control_flow_ops.Assert(good_channels, [channels_msg])
      # Both asserts must pass before the GIF decoder is allowed to run.
      with ops.control_dependencies([assert_decode, assert_channels]):
        return gen_image_ops.decode_gif(contents)

    def _png():
      return gen_image_ops.decode_png(contents, channels)

    def check_png():
      # Not a JPEG: try PNG next and fall back to the GIF branch.
      is_png = math_ops.equal(substr, b'\211PNG', name='is_png')
      return control_flow_ops.cond(is_png, _png, _gif, name='cond_png')

    def _jpeg():
      return gen_image_ops.decode_jpeg(contents, channels)

    is_jpeg = math_ops.equal(jpeg_prefix, b'\xff\xd8\xff', name='is_jpeg')
    return control_flow_ops.cond(is_jpeg, _jpeg, check_png, name='cond_jpeg')

0 comments on commit db49056

Please sign in to comment.