Skip to content

Commit

Permalink
Refactor image preprocessing iterators to subclass Sequence. (keras-team#7853)
Browse files Browse the repository at this point in the history

* Refactor image preprocessing iterators to subclass Sequence.

* Add more tests for image preprocessing Sequences
  • Loading branch information
fchollet authored Sep 8, 2017
1 parent 19862b0 commit ebd0f08
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 65 deletions.
16 changes: 10 additions & 6 deletions examples/cifar10_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@
batch_size=batch_size),
steps_per_epoch=x_train.shape[0] // batch_size,
epochs=epochs,
validation_data=(x_test, y_test))
validation_data=(x_test, y_test),
workers=4)

# Save model and weights
if not os.path.isdir(save_dir):
Expand All @@ -129,14 +130,17 @@

# Evaluate model with test data set and share sample prediction results
evaluation = model.evaluate_generator(datagen.flow(x_test, y_test,
batch_size=batch_size),
steps=x_test.shape[0] // batch_size)

batch_size=batch_size,
shuffle=False),
steps=x_test.shape[0] // batch_size,
workers=4)
print('Model Accuracy = %.2f' % (evaluation[1]))

predict_gen = model.predict_generator(datagen.flow(x_test, y_test,
batch_size=batch_size),
steps=x_test.shape[0] // batch_size)
batch_size=batch_size,
shuffle=False),
steps=x_test.shape[0] // batch_size,
workers=4)

for predict_index, predicted_y in enumerate(predict_gen):
actual_label = labels['label_names'][np.argmax(y_test[predict_index])]
Expand Down
116 changes: 74 additions & 42 deletions keras/preprocessing/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from functools import partial

from .. import backend as K
from ..utils.data_utils import Sequence

try:
from PIL import Image as pil_image
Expand Down Expand Up @@ -684,7 +685,7 @@ def fit(self, x,
self.principal_components = np.dot(np.dot(u, np.diag(1. / np.sqrt(s + self.zca_epsilon))), u.T)


class Iterator(object):
class Iterator(Sequence):
"""Abstract base class for image data iterators.
# Arguments
Expand All @@ -697,36 +698,60 @@ class Iterator(object):
def __init__(self, n, batch_size, shuffle, seed):
self.n = n
self.batch_size = batch_size
self.seed = seed
self.shuffle = shuffle
self.batch_index = 0
self.total_batches_seen = 0
self.lock = threading.Lock()
self.index_generator = self._flow_index(n, batch_size, shuffle, seed)
self.index_array = None
self.index_generator = self._flow_index()

# Build the order in which samples are visited and store it on
# `self.index_array`: the identity order 0..n-1 by default, replaced by a
# random permutation of the same range when `self.shuffle` is truthy.
def _set_index_array(self):
self.index_array = np.arange(self.n)
if self.shuffle:
# The permutation is drawn through `np.random`, i.e. the same generator
# that callers seed with `np.random.seed`.
self.index_array = np.random.permutation(self.n)

# Sequence-style access: return batch number `idx`.
#
# Raises ValueError when `idx` is not below `len(self)`. Otherwise takes the
# slice [batch_size * idx, batch_size * (idx + 1)) of `self.index_array`
# (building the array on first use) and returns whatever
# `_get_batches_of_transformed_samples` produces for those indices.
def __getitem__(self, idx):
if idx >= len(self):
raise ValueError('Asked to retrieve element {idx}, '
'but the Sequence '
'has length {length}'.format(idx=idx,
length=len(self)))
# Re-seed `np.random` from the base seed plus the number of batches served
# so far. The counter below is bumped on every call, seeded or not.
if self.seed is not None:
np.random.seed(self.seed + self.total_batches_seen)
self.total_batches_seen += 1
# Lazily create the index order the first time a batch is requested.
if self.index_array is None:
self._set_index_array()
# The slice is clipped at the end of the array, so the final batch may hold
# fewer than `batch_size` indices.
index_array = self.index_array[self.batch_size * idx:
self.batch_size * (idx + 1)]
return self._get_batches_of_transformed_samples(index_array)

# Number of batches in one pass over the data: `n / batch_size` rounded up,
# so a trailing partial batch counts as one batch.
def __len__(self):
return int(np.ceil(self.n / float(self.batch_size)))

# Rebuild `self.index_array` through `_set_index_array`, which draws a fresh
# permutation when `self.shuffle` is set.
# NOTE(review): presumably called by the Sequence consumer between epochs —
# confirm against the code that drives the Sequence.
def on_epoch_end(self):
self._set_index_array()

# Rewind `self.batch_index`, the batch counter read by `_flow_index`, to 0.
def reset(self):
self.batch_index = 0

def _flow_index(self, n, batch_size=32, shuffle=False, seed=None):
def _flow_index(self):
# Ensure self.batch_index is 0.
self.reset()
while 1:
if seed is not None:
np.random.seed(seed + self.total_batches_seen)
if self.seed is not None:
np.random.seed(self.seed + self.total_batches_seen)
if self.batch_index == 0:
index_array = np.arange(n)
if shuffle:
index_array = np.random.permutation(n)
self._set_index_array()

current_index = (self.batch_index * batch_size) % n
if n > current_index + batch_size:
current_batch_size = batch_size
current_index = (self.batch_index * self.batch_size) % self.n
if self.n > current_index + self.batch_size:
self.batch_index += 1
else:
current_batch_size = n - current_index
self.batch_index = 0
self.total_batches_seen += 1
yield (index_array[current_index: current_index + current_batch_size],
current_index, current_batch_size)
yield self.index_array[current_index:
current_index + self.batch_size]

def __iter__(self):
# Needed if we want to do something like:
Expand Down Expand Up @@ -796,29 +821,19 @@ def __init__(self, x, y, image_data_generator,
self.save_format = save_format
super(NumpyArrayIterator, self).__init__(x.shape[0], batch_size, shuffle, seed)

def next(self):
"""For python 2.x.
# Returns
The next batch.
"""
# Keeps under lock only the mechanism which advances
# the indexing of each batch.
with self.lock:
index_array, current_index, current_batch_size = next(self.index_generator)
# The transformation of images is not under thread lock
# so it can be done in parallel
batch_x = np.zeros(tuple([current_batch_size] + list(self.x.shape)[1:]), dtype=K.floatx())
def _get_batches_of_transformed_samples(self, index_array):
batch_x = np.zeros(tuple([len(index_array)] + list(self.x.shape)[1:]),
dtype=K.floatx())
for i, j in enumerate(index_array):
x = self.x[j]
x = self.image_data_generator.random_transform(x.astype(K.floatx()))
x = self.image_data_generator.standardize(x)
batch_x[i] = x
if self.save_to_dir:
for i in range(current_batch_size):
for i, j in enumerate(index_array):
img = array_to_img(batch_x[i], self.data_format, scale=True)
fname = '{prefix}_{index}_{hash}.{format}'.format(prefix=self.save_prefix,
index=current_index + i,
index=j,
hash=np.random.randint(1e4),
format=self.save_format)
img.save(os.path.join(self.save_to_dir, fname))
Expand All @@ -827,6 +842,20 @@ def next(self):
batch_y = self.y[index_array]
return batch_x, batch_y

# Iterator-style access: pull the next array of sample indices from
# `self.index_generator` and return the batch that
# `_get_batches_of_transformed_samples` builds from it.
def next(self):
"""For python 2.x.
# Returns
The next batch.
"""
# Keeps under lock only the mechanism which advances
# the indexing of each batch.
with self.lock:
index_array = next(self.index_generator)
# The transformation of images is not under thread lock
# so it can be done in parallel
return self._get_batches_of_transformed_samples(index_array)


def _count_valid_files_in_directory(directory, white_list_formats, follow_links):
"""Count files with extension in `white_list_formats` contained in a directory.
Expand Down Expand Up @@ -1013,17 +1042,8 @@ def __init__(self, directory, image_data_generator,
pool.join()
super(DirectoryIterator, self).__init__(self.samples, batch_size, shuffle, seed)

def next(self):
"""For python 2.x.
# Returns
The next batch.
"""
with self.lock:
index_array, current_index, current_batch_size = next(self.index_generator)
# The transformation of images is not under thread lock
# so it can be done in parallel
batch_x = np.zeros((current_batch_size,) + self.image_shape, dtype=K.floatx())
def _get_batches_of_transformed_samples(self, index_array):
batch_x = np.zeros((len(index_array),) + self.image_shape, dtype=K.floatx())
grayscale = self.color_mode == 'grayscale'
# build batch of image data
for i, j in enumerate(index_array):
Expand All @@ -1037,10 +1057,10 @@ def next(self):
batch_x[i] = x
# optionally save augmented images to disk for debugging purposes
if self.save_to_dir:
for i in range(current_batch_size):
for i, j in enumerate(index_array):
img = array_to_img(batch_x[i], self.data_format, scale=True)
fname = '{prefix}_{index}_{hash}.{format}'.format(prefix=self.save_prefix,
index=current_index + i,
index=j,
hash=np.random.randint(1e4),
format=self.save_format)
img.save(os.path.join(self.save_to_dir, fname))
Expand All @@ -1058,3 +1078,15 @@ def next(self):
else:
return batch_x
return batch_x, batch_y

# Iterator-style access: pull the next array of sample indices from
# `self.index_generator` and return the batch that
# `_get_batches_of_transformed_samples` builds from it.
def next(self):
"""For python 2.x.
# Returns
The next batch.
"""
# Only advancing the shared index generator happens under the lock.
with self.lock:
index_array = next(self.index_generator)
# The transformation of images is not under thread lock
# so it can be done in parallel
return self._get_batches_of_transformed_samples(index_array)
13 changes: 7 additions & 6 deletions keras/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,19 +313,20 @@ class Sequence(object):
# and `y_set` are the associated classes.
class CIFAR10Sequence(Sequence):
def __init__(self, x_set, y_set, batch_size):
self.X,self.y = x_set,y_set
self.x, self.y = x_set, y_set
self.batch_size = batch_size
def __len__(self):
return len(self.X) // self.batch_size
return len(self.x) // self.batch_size
def __getitem__(self,idx):
batch_x = self.X[idx*self.batch_size:(idx+1)*self.batch_size]
batch_y = self.y[idx*self.batch_size:(idx+1)*self.batch_size]
def __getitem__(self, idx):
batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
return np.array([
resize(imread(file_name), (200,200))
resize(imread(file_name), (200, 200))
for file_name in batch_x]), np.array(batch_y)
```
"""
Expand Down
70 changes: 59 additions & 11 deletions tests/keras/preprocessing/image_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,20 @@
import os


class TestImage:
class TestImage(object):

def setup_class(cls):
img_w = img_h = 20
cls.img_w = cls.img_h = 20
rgb_images = []
gray_images = []
for n in range(8):
bias = np.random.rand(img_w, img_h, 1) * 64
variance = np.random.rand(img_w, img_h, 1) * (255 - 64)
imarray = np.random.rand(img_w, img_h, 3) * variance + bias
bias = np.random.rand(cls.img_w, cls.img_h, 1) * 64
variance = np.random.rand(cls.img_w, cls.img_h, 1) * (255 - 64)
imarray = np.random.rand(cls.img_w, cls.img_h, 3) * variance + bias
im = Image.fromarray(imarray.astype('uint8')).convert('RGB')
rgb_images.append(im)

imarray = np.random.rand(img_w, img_h, 1) * variance + bias
imarray = np.random.rand(cls.img_w, cls.img_h, 1) * variance + bias
im = Image.fromarray(imarray.astype('uint8').squeeze()).convert('L')
gray_images.append(im)

Expand Down Expand Up @@ -53,10 +53,43 @@ def test_image_data_generator(self, tmpdir):
generator.fit(images, augment=True)

for x, y in generator.flow(images, np.arange(images.shape[0]),
shuffle=True, save_to_dir=str(tmpdir)):
assert x.shape[1:] == images.shape[1:]
shuffle=False, save_to_dir=str(tmpdir),
batch_size=3):
assert x.shape == images[:3].shape
assert list(y) == [0, 1, 2]
break

# Test with `shuffle=True`
for x, y in generator.flow(images, np.arange(images.shape[0]),
shuffle=True, save_to_dir=str(tmpdir),
batch_size=3):
assert x.shape == images[:3].shape
# Check that the sequence is shuffled.
assert list(y) != [0, 1, 2]
break

# Test `flow` behavior as Sequence
seq = generator.flow(images, np.arange(images.shape[0]),
shuffle=False, save_to_dir=str(tmpdir),
batch_size=3)
assert len(seq) == images.shape[0] // 3 + 1
x, y = seq[0]
assert x.shape == images[:3].shape
assert list(y) == [0, 1, 2]

# Test with `shuffle=True`
seq = generator.flow(images, np.arange(images.shape[0]),
shuffle=True, save_to_dir=str(tmpdir),
batch_size=3, seed=123)
x, y = seq[0]
# Check that the sequence is shuffled.
assert list(y) != [0, 1, 2]

# `on_epoch_end` should reshuffle the sequence.
seq.on_epoch_end()
x2, y2 = seq[0]
assert list(y) != list(y2)

def test_image_data_generator_invalid_data(self):
generator = image.ImageDataGenerator(
featurewise_center=True,
Expand Down Expand Up @@ -140,16 +173,31 @@ def test_directory_iterator(self, tmpdir):
dir_iterator = generator.flow_from_directory(str(tmpdir))

# check number of classes and images
assert(len(dir_iterator.class_indices) == num_classes)
assert(len(dir_iterator.classes) == count)
assert(sorted(dir_iterator.filenames) == sorted(filenames))
assert len(dir_iterator.class_indices) == num_classes
assert len(dir_iterator.classes) == count
assert sorted(dir_iterator.filenames) == sorted(filenames)

# Test invalid use cases
with pytest.raises(ValueError):
generator.flow_from_directory(str(tmpdir), color_mode='cmyk')
with pytest.raises(ValueError):
generator.flow_from_directory(str(tmpdir), class_mode='output')

# Test usage as Sequence
generator = image.ImageDataGenerator()
dir_seq = generator.flow_from_directory(str(tmpdir),
target_size=(26, 26),
color_mode='rgb',
batch_size=3,
class_mode='categorical')
assert len(dir_seq) == count // 3 + 1
x1, y1 = dir_seq[1]
assert x1.shape == (3, 26, 26, 3)
assert y1.shape == (3, num_classes)
x1, y1 = dir_seq[5]
with pytest.raises(ValueError):
x1, y1 = dir_seq[9]

def test_directory_iterator_class_mode_input(self, tmpdir):
tmpdir.join('class-1').mkdir()

Expand Down

0 comments on commit ebd0f08

Please sign in to comment.