-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
2e070bc
commit bfc18f3
Showing
1 changed file
with
117 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
""" | ||
Module that is used to train the CNNs that power Indoor Scene Detector. | ||
Notes: | ||
- We ran this script in a Kaggle notebook in order to utilize the free GPU. | ||
- The paths to the data set below are absolute paths on Kaggle. | ||
- The beta-version of of Indoor Scene Detector is trained on a subset of this data set | ||
(10/67 classes). We are currently developing support for all 67 classes. | ||
- The "training" module on Kaggle is saved in a Kaggle utility scripts called | ||
"indoorscenes_training", so the import statement looks slightly different | ||
within the Kaggle notebook.. | ||
TODO: Add Kaggle notebook link | ||
""" | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from torchvision import transforms, datasets | ||
from torch.utils.data import random_split, DataLoader | ||
from training import ( | ||
SimpleCNN, | ||
train_model, | ||
get_custom_alexnet, | ||
get_custom_densenet121, | ||
get_custom_resnet18, | ||
) | ||
|
||
# random seed | ||
SEED = 42 | ||
torch.manual_seed(SEED) | ||
|
||
# img and batch sizes | ||
IMG_SIZE = (256, 256) | ||
BATCH_SIZE = 64 | ||
|
||
# train, val, test split | ||
SPLIT = [4500, 600, 561] | ||
|
||
# model training variables | ||
CRITERION = nn.CrossEntropyLoss() | ||
EPOCHS = 25 | ||
PATIENCE = 5 | ||
|
||
# set device | ||
DEVICE = "cuda" if torch.cuda.is_available() else "cpu" | ||
|
||
# kaggle absolute path | ||
# DATA_PATH = "../input/indoor-scenes-cvpr-2019/indoorCVPR_09/Images/" | ||
DATA_PATH = "../input/top10indoorscenes/reduced-indoor-scenes/images/" | ||
|
||
TRANSFORMS = transforms.Compose( | ||
[ | ||
transforms.Resize(IMG_SIZE), | ||
transforms.ToTensor(), | ||
] | ||
) | ||
|
||
# load data, split, get loaders | ||
dataset = datasets.ImageFolder(DATA_PATH, transform=TRANSFORMS) | ||
n_classes = len(dataset.classes) | ||
train_set, val_set, test_set = random_split(dataset, SPLIT) | ||
trainloader = DataLoader( | ||
train_set, BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True | ||
) | ||
validloader = DataLoader(val_set, BATCH_SIZE, num_workers=2, pin_memory=True) | ||
|
||
# models to train | ||
cnn_models = { | ||
"Simple_CNN": SimpleCNN(), | ||
"ResNet18_Tuned": get_custom_resnet18( | ||
n_classes=n_classes, pretrained=True, freeze=False | ||
), | ||
"AlexNet_Tuned": get_custom_alexnet( | ||
n_classes=n_classes, pretrained=True, freeze=False | ||
), | ||
"DenseNet121_Tuned": get_custom_densenet121( | ||
n_classes=n_classes, pretrained=True, freeze=False | ||
), | ||
} | ||
|
||
# to track all training and validation losses and accuracies | ||
train_losses, train_accs = {}, {} | ||
valid_losses, valid_accs = {}, {} | ||
|
||
# train models | ||
for name, model in cnn_models.items(): | ||
print(f"{name} training started...") | ||
print( | ||
"====================================================================================================" | ||
) | ||
|
||
model.to(DEVICE) | ||
optimizer = torch.optim.Adam(model.parameters()) | ||
|
||
tuned_model, train_loss, train_acc, valid_loss, valid_acc = train_model( | ||
model=model, | ||
device=DEVICE, | ||
criterion=CRITERION, | ||
optimizer=optimizer, | ||
trainloader=trainloader, | ||
validloader=validloader, | ||
epochs=EPOCHS, | ||
patience=PATIENCE, | ||
save=True, | ||
name=f"{name}", | ||
) | ||
|
||
cnn_models[name] = tuned_model | ||
train_losses[name] = train_loss | ||
train_accs[name] = train_acc | ||
valid_losses[name] = valid_loss | ||
valid_accs[name] = valid_acc | ||
|
||
print( | ||
"====================================================================================================\n" | ||
) |