Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make ImageDataGenerator behaviour fully seedable/repeatable #3751

Merged
merged 2 commits into from
Sep 22, 2016
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion docs/templates/preprocessing/image.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,14 @@ Generate batches of tensor image data with real-time data augmentation. The data
- __X__: sample data.
- __augment__: Boolean (default: False). Whether to fit on randomly augmented samples.
- __rounds__: int (default: 1). If augment, how many augmentation passes over the data to use.
- __seed__: int (default: None). Random seed.
- __flow(X, y)__: Takes numpy data & label arrays, and generates batches of augmented/normalized data. Yields batches indefinitely, in an infinite loop.
- __Arguments__:
- __X__: data.
- __y__: labels.
- __batch_size__: int (default: 32).
- __shuffle__: boolean (defaut: True).
- __seed__: int (default: None).
- __save_to_dir__: None or str (default: None). This allows you to optimally specify a directory to which to save the augmented pictures being generated (useful for visualizing what you are doing).
- __save_prefix__: str (default: `''`). Prefix to use for filenames of saved pictures (only relevant if `save_to_dir` is set).
- __save_format__: one of "png", "jpeg" (only relevant if `save_to_dir` is set). Default: "jpeg".
Expand All @@ -77,7 +79,7 @@ Generate batches of tensor image data with real-time data augmentation. The data
- __class_mode__: one of "categorical", "binary", "sparse" or None. Default: "categorical". Determines the type of label arrays that are returned: "categorical" will be 2D one-hot encoded labels, "binary" will be 1D binary labels, "sparse" will be 1D integer labels. If None, no labels are returned (the generator will only yield batches of image data, which is useful to use `model.predict_generator()`, `model.evaluate_generator()`, etc.).
- __batch_size__: size of the batches of data (default: 32).
- __shuffle__: whether to shuffle the data (default: True)
- __seed__: optional random seed for shuffling.
- __seed__: optional random seed for shuffling and transformations.
- __save_to_dir__: None or str (default: None). This allows you to optimally specify a directory to which to save the augmented pictures being generated (useful for visualizing what you are doing).
- __save_prefix__: str. Prefix to use for filenames of saved pictures (only relevant if `save_to_dir` is set).
- __save_format__: one of "png", "jpeg" (only relevant if `save_to_dir` is set). Default: "jpeg".
Expand Down Expand Up @@ -151,3 +153,38 @@ model.fit_generator(
validation_data=validation_generator,
nb_val_samples=800)
```

Example of using it to transform images and masks together.

```python

# we create two instances with the same arguments
data_gen_args = dict(
featurewise_center=True,
featurewise_std_normalization=True,
rotation_range=90.,
width_shift_range=0.1,
height_shift_range=0.1,
zoom_range=0.2
)
generator_i = ImageDataGenerator(**data_gen_args)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use explicit variable names in examples (e.g. image_generator)

generator_m = ImageDataGenerator(**data_gen_args)

# We need to provide the same seed to the fit and flow methods
seed = 1
fake_y = np.zeroes(images.shape)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you don't need labels, you should configure the generator not to use labels (class_mode=None).

generator_i.fit(images, augment=True, seed=seed)
generator_m.fit(masks, augment=True, seed=seed)
flow_images = generator_i.flow(images, fake_y, batch_size=2, seed=seed)
flow_masks = generator_m.flow(masks, fake_y, batch_size=2, seed=seed)

# combine the two generators into one which gives the first from each
def dual_gen(flow_images,flow_masks):
for [X,_],[y,_] in zip(flow_images,flow_masks):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, would be much more elegant without the fake labels.

yield X,y
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PEP8


model.fit_generator(
dual_gen(flow_images,flow_masks),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PEP8

samples_per_epoch=2000,
nb_epoch=50)
```
7 changes: 5 additions & 2 deletions keras/preprocessing/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,9 @@ def fit(self, X,
how many augmentation passes to do over the data
seed: random seed.
'''
if seed is not None:
np.random.seed(seed)

X = np.copy(X)
if augment:
aX = np.zeros(tuple([rounds * X.shape[0]] + list(X.shape)[1:]))
Expand Down Expand Up @@ -431,11 +434,11 @@ def _flow_index(self, N, batch_size=32, shuffle=False, seed=None):
# ensure self.batch_index is 0
self.reset()
while 1:
if seed is not None:
np.random.seed(seed + self.total_batches_seen)
if self.batch_index == 0:
index_array = np.arange(N)
if shuffle:
if seed is not None:
np.random.seed(seed + self.total_batches_seen)
index_array = np.random.permutation(N)

current_index = (self.batch_index * batch_size) % N
Expand Down
49 changes: 49 additions & 0 deletions tests/keras/preprocessing/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,55 @@ def test_img_flip(self):
assert ((potentially_flipped_x == x).all() or
(potentially_flipped_x == flip_axis(x, col_index)).all())

def test_dual_generators(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this test is useful; please remove.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My logic is that this test will break if any future PR break's repeatability, if you have time could you expand on why you don't think that's worth a test? In the mean time I'll remove it.

'''
Should be able to run ImageDataGenerator on multiple arrays of images
and get identical transforms (as long as we provide the same seed).
This lets us transform images and masks together or other use cases.
'''
for test_images in self.all_test_images:
img_list = []
for im in test_images:
img_list.append(img_to_array(im)[None, ...])

X = np.vstack(img_list)
y = np.arange(X.shape[0])

seed = 1
batch_size = 2
data_gen_args = dict(
featurewise_center=True,
samplewise_center=True,
featurewise_std_normalization=True,
samplewise_std_normalization=True,
zca_whitening=True,
rotation_range=90.,
width_shift_range=0.1,
height_shift_range=0.1,
shear_range=0.5,
zoom_range=0.2,
channel_shift_range=0.01,
fill_mode='nearest',
cval=0.5,
horizontal_flip=True,
vertical_flip=True)

generator1 = ImageDataGenerator(**data_gen_args)
generator2 = ImageDataGenerator(**data_gen_args)

generator1.fit(X, augment=True, seed=seed)
generator2.fit(X, augment=True, seed=seed)

flow1 = generator1.flow(X, y, batch_size=batch_size, seed=seed)
flow2 = generator2.flow(X, y, batch_size=batch_size, seed=seed)

for i in range(10):
X1, y1 = next(flow1)
X2, y2 = next(flow2)
for b in range(batch_size):
assert (X1 == X2).all()
assert (y1 == y2).all()


if __name__ == '__main__':
pytest.main([__file__])