Skip to content

Commit

Permalink
Randomly transform with 50% chance
Browse files Browse the repository at this point in the history
  • Loading branch information
DubiousCactus committed May 8, 2019
1 parent a0191ed commit 413c70f
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 8 deletions.
2 changes: 1 addition & 1 deletion cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def _main():
# Generate training data with no real-time augmentation
# train_datagen = utils.fit_flow_from_directory(rescale=1./255)
train_datagen = utils.DroneDataGenerator(rescale=1./255,
channel_shift_range=0.15)
channel_shift_range=0.1)

config = {
'featurewise_center': True,
Expand Down
6 changes: 3 additions & 3 deletions cnn_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ def resnet8(img_width, img_height, img_channels, output_dim,

# h_noise = GaussianNoise(0.01)(x)
h1 = Dense(500, activation='relu', kernel_regularizer=regularizers.l2(0.001))(x)
h1 = Dropout(0.3)(h1)
h2 = Dense(500, activation='relu', kernel_regularizer=regularizers.l2(0.001))(h1)
h2 = Dropout(0.2)(h2)
h1 = Dropout(0.4)(h1)
h2 = Dense(100, activation='relu', kernel_regularizer=regularizers.l2(0.001))(h1)
h2 = Dropout(0.3)(h2)
# Gate localization
localization = Dense(output_dim, activation='softmax')(h2) # Logits + Softmax

Expand Down
2 changes: 1 addition & 1 deletion img_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def load_img(path, grayscale=False):
else:
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

return np.asarray(img, dtype=np.uint8)
return np.asarray(img, dtype=np.float64)


def random_channel_shift(img, shiftFactor, channel_axis=2):
Expand Down
8 changes: 5 additions & 3 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,14 +276,16 @@ def _get_batches_of_transformed_samples(self, index_array) :
x = img_utils.load_img(os.path.join(self.directory, fname),
grayscale=grayscale)

if self.image_data_generator.channelShiftFactor > 0:
# 50% chances of transforming the image
shifting = np.random.rand() <= 0.5
if shifting and self.image_data_generator.channelShiftFactor > 0:
transformed_x = img_utils.random_channel_shift(np.copy(x),
self.image_data_generator.channelShiftFactor)
else:
transformed_x = x

if self.image_data_generator.channelShiftFactor > 0 and self.saved_transforms < 50:
Image.fromarray(x, "RGB").save(os.path.join(self.directory,
if shifting and self.image_data_generator.channelShiftFactor > 0 and self.saved_transforms < 50:
Image.fromarray(x.astype(np.uint8), "RGB").save(os.path.join(self.directory,
"img_transforms",
"original_{}.jpg".format(self.saved_transforms)))
Image.fromarray(transformed_x.astype(np.uint8), "RGB").save(os.path.join(self.directory,
Expand Down

0 comments on commit 413c70f

Please sign in to comment.