Skip to content

Commit

Permalink
updated training
Browse files Browse the repository at this point in the history
  • Loading branch information
GantMan committed Feb 5, 2019
1 parent 447c23d commit 2ff414b
Showing 1 changed file with 38 additions and 24 deletions.
62 changes: 38 additions & 24 deletions train_inception_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from time import time
from keras.preprocessing.image import ImageDataGenerator
from keras.backend import clear_session
from keras.optimizers import Adam
from keras.optimizers import SGD
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
Expand All @@ -11,7 +11,7 @@
from keras.applications import InceptionV3
from keras.models import Sequential, Model, load_model
from keras.layers import Dense, Dropout, Flatten, AveragePooling2D
from keras.callbacks import ModelCheckpoint, TensorBoard
from keras.callbacks import ModelCheckpoint, TensorBoard, LearningRateScheduler
from keras import initializers, regularizers

# No kruft plz
Expand All @@ -37,22 +37,9 @@
input_shape=(height, width, num_channels)
)


# base_model = InceptionV3(weights='imagenet', include_top=False, input_tensor=Input(shape=(299, 299, 3)))
# x = base_model.output
# x = AveragePooling2D(pool_size=(8, 8))(x)
# x = Dropout(.4)(x)
# x = Flatten()(x)
# predictions = Dense(n_classes, init='glorot_uniform', W_regularizer=l2(.0005), activation='softmax')(x)

# model = Model(input=base_model.input, output=predictions)


# First time run, no unlocking
conv_base.trainable = False
# Let's unlock trainable layers in conv_base
# conv_base.trainable = True

#conv_base.trainable = False
# Let's unlock trainable layers in conv_base by name
# set_trainable = False
# for layer in conv_base.layers:
# if layer.name == 'block14_sepconv1':
Expand All @@ -61,6 +48,12 @@
# layer.trainable = True
# else:
# layer.trainable = False
# Let's unlock by layer level
for layer in conv_base.layers[:172]:
layer.trainable = False
for layer in conv_base.layers[172:]:
layer.trainable = True


# Let's see it
print('Summary')
Expand All @@ -73,6 +66,7 @@
x = Flatten()(x)
x = Dense(256, activation='relu', kernel_initializer=initializers.he_normal(seed=None), kernel_regularizer=regularizers.l2(.0005))(x)
x = Dropout(0.5)(x)
# I considered this since it will be hard to overfit a huge dataset, but simpler is better
# x = Dense(128,activation='relu', kernel_initializer=initializers.he_normal(seed=None))(x)
# x = Dropout(0.25)(x)
predictions = Dense(num_classes, kernel_initializer="glorot_uniform", activation='softmax')(x)
Expand All @@ -93,13 +87,32 @@
# Update info
tensorboard = TensorBoard(log_dir="logs/{}".format(time()))

callbacks_list = [checkpoint, tensorboard]
# Slow down training deeper into dataset
def schedule(epoch):
if epoch < 15:
return .01
elif epoch < 28:
return .002
elif epoch < 68:
return .0004
if epoch < 78:
return .00008
elif epoch < 88:
return .000016
else:
return .0000032
lr_scheduler = LearningRateScheduler(schedule)


callbacks_list = [lr_scheduler, checkpoint, tensorboard]

print('Compile model')
adam = Adam(lr=0.001, amsgrad=True)
# originally adam, but research says SGD with scheduler
# opt = Adam(lr=0.001, amsgrad=True)
opt = SGD(lr=.01, momentum=.9)
model.compile(
loss='categorical_crossentropy',
optimizer=adam,
optimizer=opt,
metrics=['accuracy']
)

Expand All @@ -110,7 +123,7 @@
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
channel_shift_range=30,
channel_shift_range=20,
horizontal_flip=True,
fill_mode='nearest'
)
Expand Down Expand Up @@ -141,22 +154,23 @@

# Comment in this line if you're looking to reload the last model for training
# Essentially, not taking the best validation weights but to add more epochs
# print ('Starting from last full model run')
# model = load_model("nsfw." + str(width) + "x" + str(height) + ".h5")

print('Start training!')
history = model.fit_generator(
train_generator,
callbacks=callbacks_list,
epochs=50,
steps_per_epoch=100,
epochs=100,
steps_per_epoch=500,
shuffle=True,
# having crazy threading issues
# set workers to zero if you see an error like:
# `freeze_support()`
workers=0,
use_multiprocessing=True,
validation_data=validation_generator,
validation_steps=50
validation_steps=100
)

acc = history.history['acc']
Expand Down

0 comments on commit 2ff414b

Please sign in to comment.