Commit
feat: update kaggle training script
nicovandenhooff committed Apr 25, 2022
1 parent 2e070bc commit bfc18f3
Showing 1 changed file with 117 additions and 0 deletions.
api/ml/main.py
@@ -0,0 +1,117 @@
"""
Module that is used to train the CNNs that power Indoor Scene Detector.
Notes:
- We ran this script in a Kaggle notebook in order to utilize the free GPU.
- The paths to the data set below are absolute paths on Kaggle.
- The beta-version of of Indoor Scene Detector is trained on a subset of this data set
(10/67 classes). We are currently developing support for all 67 classes.
- The "training" module on Kaggle is saved in a Kaggle utility scripts called
"indoorscenes_training", so the import statement looks slightly different
within the Kaggle notebook..
TODO: Add Kaggle notebook link
"""

import torch
import torch.nn as nn

from torchvision import transforms, datasets
from torch.utils.data import random_split, DataLoader
from training import (
    SimpleCNN,
    train_model,
    get_custom_alexnet,
    get_custom_densenet121,
    get_custom_resnet18,
)
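
# NOTE: inside the Kaggle notebook the same helpers come from the utility script
# mentioned in the docstring, i.e. (roughly):
# from indoorscenes_training import (
#     SimpleCNN, train_model, get_custom_alexnet,
#     get_custom_densenet121, get_custom_resnet18,
# )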

# random seed
SEED = 42
torch.manual_seed(SEED)
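# seeding the default CPU generator makes the random_split below reproducible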

# img and batch sizes
IMG_SIZE = (256, 256)
BATCH_SIZE = 64

# train, val, test split
SPLIT = [4500, 600, 561]
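# these counts are passed to random_split below and must sum to len(dataset)
# (4500 + 600 + 561 = 5661 images in the reduced data set)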

# model training variables
CRITERION = nn.CrossEntropyLoss()
EPOCHS = 25
PATIENCE = 5

# set device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# kaggle absolute paths: the full 67-class data set (commented out) and the
# reduced 10-class subset used for the beta
# DATA_PATH = "../input/indoor-scenes-cvpr-2019/indoorCVPR_09/Images/"
DATA_PATH = "../input/top10indoorscenes/reduced-indoor-scenes/images/"

TRANSFORMS = transforms.Compose(
    [
        transforms.Resize(IMG_SIZE),
        transforms.ToTensor(),
    ]
)
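# note: images are only resized and converted to tensors; no normalization or
# augmentation is applied here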

# load data, split, get loaders
dataset = datasets.ImageFolder(DATA_PATH, transform=TRANSFORMS)
n_classes = len(dataset.classes)
train_set, val_set, test_set = random_split(dataset, SPLIT)
trainloader = DataLoader(
    train_set, BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True
)
validloader = DataLoader(val_set, BATCH_SIZE, num_workers=2, pin_memory=True)
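# only the training loader is shuffled; the validation loader keeps the dataset order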

# models to train
cnn_models = {
    "Simple_CNN": SimpleCNN(),
    "ResNet18_Tuned": get_custom_resnet18(
        n_classes=n_classes, pretrained=True, freeze=False
    ),
    "AlexNet_Tuned": get_custom_alexnet(
        n_classes=n_classes, pretrained=True, freeze=False
    ),
    "DenseNet121_Tuned": get_custom_densenet121(
        n_classes=n_classes, pretrained=True, freeze=False
    ),
}
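# the three transfer-learning models load pretrained weights (ImageNet for the
# torchvision backbones) and, with freeze=False, presumably leave all layers
# trainable for full fine-tuning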

# to track all training and validation losses and accuracies
train_losses, train_accs = {}, {}
valid_losses, valid_accs = {}, {}

# train models
for name, model in cnn_models.items():
    print(f"{name} training started...")
    print(
        "===================================================================================================="
    )

    model.to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters())

    tuned_model, train_loss, train_acc, valid_loss, valid_acc = train_model(
        model=model,
        device=DEVICE,
        criterion=CRITERION,
        optimizer=optimizer,
        trainloader=trainloader,
        validloader=validloader,
        epochs=EPOCHS,
        patience=PATIENCE,
        save=True,
        name=f"{name}",
    )

    cnn_models[name] = tuned_model
    train_losses[name] = train_loss
    train_accs[name] = train_acc
    valid_losses[name] = valid_loss
    valid_accs[name] = valid_acc

    print(
        "====================================================================================================\n"
    )
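
# After training, a quick summary of the best validation accuracy per model could be
# printed like this (a small sketch; assumes each valid_accs entry is a per-epoch list):
# for model_name, accs in valid_accs.items():
#     print(f"{model_name}: best val acc = {max(accs):.4f}")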
