forked from GantMan/nsfw_model
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request GantMan#14 from GantMan/cleanup_structure
Cleanup structure - Allows for more model types
- Loading branch information
Showing
9 changed files
with
276 additions
and
183 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,3 +2,4 @@ | |
*.h5 | ||
logs/ | ||
.vscode/ | ||
__pycache__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from keras.callbacks import ModelCheckpoint, TensorBoard, LearningRateScheduler | ||
from time import time | ||
|
||
# Slow down training deeper into dataset
def schedule(epoch):
    """Return the learning rate to use for the given epoch index.

    The rate starts very small for a short warm-up, jumps to its peak and
    is then lowered step by step as training progresses.

    Args:
        epoch: Epoch index; compared against the thresholds below.

    Returns:
        float: The learning rate for that epoch.
    """
    # (exclusive upper epoch bound, learning rate), in ascending order.
    steps = (
        (6, .0000032),   # Warmup model first
        (12, .01),
        (20, .002),
        (40, .0004),
        (60, .00008),
        (80, .000016),
        (95, .0000032),
    )
    for upper_bound, rate in steps:
        if epoch < upper_bound:
            return rate
    # Anything at or past the last threshold trains at the floor rate.
    return .0000009
|
||
|
||
def make_callbacks(weights_file):
    """Build the list of Keras callbacks used during training.

    Args:
        weights_file: Path handed to ModelCheckpoint as the file to write.

    Returns:
        list: [LearningRateScheduler, ModelCheckpoint, TensorBoard], in
        that order.
    """
    # checkpoint: keep only the best weights, judged by the highest 'val_acc'
    # NOTE(review): newer Keras releases may report this metric as
    # 'val_accuracy' -- confirm against the installed version.
    checkpoint = ModelCheckpoint(
        weights_file, monitor='val_acc', verbose=1, save_best_only=True, mode='max')

    # Update info: each run logs to its own directory named by the current time
    tensorboard = TensorBoard(log_dir="logs/{}".format(time()))

    # learning rate schedule
    lr_scheduler = LearningRateScheduler(schedule)

    # all the goodies
    return [lr_scheduler, checkpoint, tensorboard]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# Config
# Square input edge length (in pixels) per model flavour.
SIZES = {
    'basic': 299
}

NUM_CHANNELS = 3
NUM_CLASSES = 5
GENERATOR_BATCH_SIZE = 32
TOTAL_EPOCHS = 100
STEPS_PER_EPOCH = 500
VALIDATION_STEPS = 50
# Root folder of the dataset; the generators module joins 'train' and 'test'
# onto it. This is a machine-specific Windows path -- adjust for your setup.
BASE_DIR = 'D:\\nswf_model_training_data\\data'
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import os | ||
from keras.preprocessing.image import ImageDataGenerator | ||
import constants | ||
|
||
# Training images are rescaled to [0, 1] and randomly augmented
# (rotation, shifts, shear, zoom, channel shift, horizontal flip).
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=30,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    channel_shift_range=20,
    horizontal_flip=True,
    fill_mode='nearest'
)

# Validation data should not be modified (rescaling only)
validation_datagen = ImageDataGenerator(
    rescale=1./255
)

# Dataset folders live under constants.BASE_DIR
train_dir = os.path.join(constants.BASE_DIR, 'train')
test_dir = os.path.join(constants.BASE_DIR, 'test')
|
||
def create_generators(height, width):
    """Create the training and validation directory iterators.

    Args:
        height: Target image height passed to `flow_from_directory`.
        width: Target image width passed to `flow_from_directory`.

    Returns:
        list: [train_generator, validation_generator]; the first reads
        from `train_dir` with augmentation, the second from `test_dir`
        without.
    """
    # Both generators share everything except the source directory.
    flow_options = dict(
        target_size=(height, width),
        class_mode='categorical',
        batch_size=constants.GENERATOR_BATCH_SIZE,
    )

    train_generator = train_datagen.flow_from_directory(
        train_dir, **flow_options)
    validation_generator = validation_datagen.flow_from_directory(
        test_dir, **flow_options)

    return [train_generator, validation_generator]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import os | ||
from keras.preprocessing.image import ImageDataGenerator | ||
from keras.backend import clear_session | ||
from keras.optimizers import SGD | ||
from pathlib import Path | ||
from keras.models import Sequential, Model, load_model | ||
|
||
# reusable stuff | ||
import constants | ||
import callbacks | ||
import generators | ||
|
||
# No cruft plz
clear_session()

# Config
height = constants.SIZES['basic']
width = height
weights_file = "weights.best_inception" + str(height) + ".hdf5"
# The full model is loaded from, and saved back to, this same file.
model_file = "nsfw." + str(width) + "x" + str(height) + ".h5"

print('Starting from last full model run')
model = load_model(model_file)

# Unlock a few layers deep in Inception v3.
# The model itself must stay trainable: a container with trainable=False
# exposes no trainable weights, which would cancel the per-layer unlock below.
model.trainable = True
set_trainable = False
for layer in model.layers:
    # Everything from 'conv2d_56' onwards is trained; earlier layers stay frozen.
    if layer.name == 'conv2d_56':
        set_trainable = True
    layer.trainable = set_trainable

# Let's see it
print('Summary')
# summary() prints the table itself; wrapping it in print() adds a stray "None".
model.summary()

# Load checkpoint if one is found
if os.path.exists(weights_file):
    print("loading ", weights_file)
    model.load_weights(weights_file)

# Get all model callbacks
callbacks_list = callbacks.make_callbacks(weights_file)

# Compile after changing the trainable flags so the change takes effect.
print('Compile model')
opt = SGD(momentum=.9)
model.compile(
    loss='categorical_crossentropy',
    optimizer=opt,
    metrics=['accuracy']
)

# Get training/validation data via generators
train_generator, validation_generator = generators.create_generators(height, width)

print('Start training!')
history = model.fit_generator(
    train_generator,
    callbacks=callbacks_list,
    epochs=constants.TOTAL_EPOCHS,
    steps_per_epoch=constants.STEPS_PER_EPOCH,
    shuffle=True,
    # having crazy threading issues
    # set workers to zero if you see an error like:
    # `freeze_support()`
    # NOTE(review): with workers=0, use_multiprocessing presumably has no
    # effect -- confirm before relying on it.
    workers=0,
    use_multiprocessing=True,
    validation_data=validation_generator,
    validation_steps=constants.VALIDATION_STEPS
)

# Save it for later
print('Saving Model')
model.save(model_file)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import os | ||
from keras.preprocessing.image import ImageDataGenerator | ||
from keras.backend import clear_session | ||
from keras.optimizers import SGD | ||
from pathlib import Path | ||
from keras.applications import InceptionV3 | ||
from keras.models import Sequential, Model, load_model | ||
from keras.layers import Dense, Dropout, Flatten, AveragePooling2D | ||
from keras import initializers, regularizers | ||
|
||
# reusable stuff | ||
import constants | ||
import callbacks | ||
import generators | ||
|
||
# No cruft plz
clear_session()

# Config
height = constants.SIZES['basic']
width = height
weights_file = "weights.best_inception" + str(height) + ".hdf5"

conv_base = InceptionV3(
    weights='imagenet',
    include_top=False,
    input_shape=(height, width, constants.NUM_CHANNELS)
)

# First time run, no unlocking
conv_base.trainable = False
# The base's layers are reused directly in the Model built below, so freeze
# each one explicitly as well.
# NOTE(review): some Keras releases may not propagate the container flag to
# its layers -- confirm for the installed version.
for layer in conv_base.layers:
    layer.trainable = False

# Let's see it
print('Summary')
# summary() prints the table itself; wrapping it in print() adds a stray "None".
conv_base.summary()

# Let's construct that top layer replacement
x = conv_base.output
x = AveragePooling2D(pool_size=(8, 8))(x)
# Was `x - Dropout(0.4)(x)`: a discarded subtraction, so this dropout layer
# never became part of the model.
x = Dropout(0.4)(x)
x = Flatten()(x)
x = Dense(
    256,
    activation='relu',
    kernel_initializer=initializers.he_normal(seed=None),
    kernel_regularizer=regularizers.l2(.0005)
)(x)
x = Dropout(0.5)(x)
# Essential to have another layer for better accuracy
x = Dense(128, activation='relu', kernel_initializer=initializers.he_normal(seed=None))(x)
x = Dropout(0.25)(x)
predictions = Dense(constants.NUM_CLASSES, kernel_initializer="glorot_uniform", activation='softmax')(x)

print('Stacking New Layers')
model = Model(inputs=conv_base.input, outputs=predictions)

# Load checkpoint if one is found
if os.path.exists(weights_file):
    print("loading ", weights_file)
    model.load_weights(weights_file)

# Get all model callbacks
callbacks_list = callbacks.make_callbacks(weights_file)

print('Compile model')
# originally adam, but research says SGD with scheduler
opt = SGD(momentum=.9)
model.compile(
    loss='categorical_crossentropy',
    optimizer=opt,
    metrics=['accuracy']
)

# Get training/validation data via generators
train_generator, validation_generator = generators.create_generators(height, width)

print('Start training!')
history = model.fit_generator(
    train_generator,
    callbacks=callbacks_list,
    epochs=constants.TOTAL_EPOCHS,
    steps_per_epoch=constants.STEPS_PER_EPOCH,
    shuffle=True,
    # having crazy threading issues
    # set workers to zero if you see an error like:
    # `freeze_support()`
    # NOTE(review): with workers=0, use_multiprocessing presumably has no
    # effect -- confirm before relying on it.
    workers=0,
    use_multiprocessing=True,
    validation_data=validation_generator,
    validation_steps=constants.VALIDATION_STEPS
)

# Save it for later
print('Saving Model')
model.save("nsfw." + str(width) + "x" + str(height) + ".h5")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.