Commit
Added train.py with utils
souvikmajumder26 committed Jun 14, 2023
1 parent 993a69a commit 8ce5490
Showing 11 changed files with 357 additions and 1 deletion.
Binary file added assets/all_classes1.png
Binary file added assets/all_classes2.png
Binary file added assets/select_classes1.png
Binary file added assets/select_classes2.png
2 changes: 1 addition & 1 deletion config/config.yaml
@@ -14,7 +14,7 @@ vars:
test_log_name: test.log
log_level: 'DEBUG'
patch_size: 512
- model_name: landcover_unet_efficientnet-b0_epochs18_patch512_batch16.pth
+ model_name: trained_landcover_unet_efficientnet-b0_epochs18_patch512_batch16.pth
encoder: 'efficientnet-b0'
encoder_weights: 'imagenet'
train_classes: ['background', 'building', 'woodland', 'water']
53 changes: 53 additions & 0 deletions requirements.txt
@@ -0,0 +1,53 @@
albumentations==1.3.1
certifi==2023.5.7
charset-normalizer==3.1.0
colorama==0.4.6
contourpy==1.0.7
cycler==0.11.0
efficientnet-pytorch==0.7.1
filelock==3.12.2
fonttools==4.40.0
fsspec==2023.6.0
huggingface-hub==0.15.1
idna==3.4
imageio==2.31.1
importlib-resources==5.12.0
Jinja2==3.1.2
joblib==1.2.0
kiwisolver==1.4.4
lazy_loader==0.2
MarkupSafe==2.1.3
matplotlib==3.7.1
mpmath==1.3.0
munch==3.0.0
networkx==3.1
numpy==1.24.3
opencv-python==4.7.0.72
opencv-python-headless==4.7.0.72
packaging==23.1
patchify==0.2.3
Pillow==9.5.0
pretrainedmodels==0.7.4
pyparsing==3.0.9
python-dateutil==2.8.2
PyWavelets==1.4.1
PyYAML==6.0
qudida==0.0.4
requests==2.31.0
safetensors==0.3.1
scikit-image==0.21.0
scikit-learn==1.2.2
scipy==1.10.1
segmentation-models-pytorch==0.3.3
six==1.16.0
split-folders==0.5.1
sympy==1.12
threadpoolctl==3.1.0
tifffile==2023.4.12
timm==0.9.2
torch==2.0.1
torchvision==0.15.2
tqdm==4.65.0
typing_extensions==4.6.3
urllib3==2.0.3
zipp==3.15.0
157 changes: 157 additions & 0 deletions src/train.py
@@ -0,0 +1,157 @@
import os
import splitfolders
import torch
from torch.utils.data import DataLoader
import segmentation_models_pytorch as smp
import segmentation_models_pytorch.utils

from utils.patching import patching, discard_useless_patches
from utils.preprocess import get_training_augmentation, get_preprocessing
from utils.dataset import SegmentationDataset

########## config ###########

BATCH_SIZE = 16
ENCODER = 'efficientnet-b0'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['background', 'building', 'woodland', 'water'] # not training on the 'road' class since it has too few instances in the data
ACTIVATION = 'softmax2d' # could be None for logits or 'softmax2d' for multiclass segmentation
DEVICE = 'cpu'
EPOCHS = 20

root_directory = "data/train"

img_dir = os.path.join(root_directory, "images")
mask_dir = os.path.join(root_directory, "masks")

patch_size = 512

patches_img_dir = os.path.join(f"patches_{patch_size}", "images")
patches_img_dir = os.path.join(root_directory, patches_img_dir)
os.makedirs(patches_img_dir, exist_ok=True)
patches_mask_dir = os.path.join(f"patches_{patch_size}", "masks")
patches_mask_dir = os.path.join(root_directory, patches_mask_dir)
os.makedirs(patches_mask_dir, exist_ok=True)

model_dir = "models"

#############################

print()
print("Dividing images into patches...")
patching(img_dir, patches_img_dir, patch_size)
print("Dividing images into patches completed successfull!")

print()
print("Dividing masks into patches...")
patching(mask_dir, patches_mask_dir, patch_size)
print("Dividing masks into patches completed successfull!")

discard_useless_patches(patches_img_dir, patches_mask_dir)



input_folder = os.path.dirname(patches_img_dir) # parent dir holding the "images" and "masks" patch folders (str.strip("images") would be fragile here)
print(input_folder)
output_folder = os.path.join(root_directory, "train_val_test")
print(output_folder)

os.makedirs(output_folder, exist_ok=True)

# Split with a ratio.
# To split into training, validation, and testing sets, pass a 3-tuple as `ratio`, e.g. `(.8, .1, .1)`.
splitfolders.ratio(input_folder, output=output_folder, seed=42, ratio=(.8, .2), group_prefix=None, move=False) # splitting into training and validation sets only

train_dir = os.path.join(output_folder, "train")
val_dir = os.path.join(output_folder, "val")
# test_dir = os.path.join(output_folder, "test")

x_train_dir = os.path.join(train_dir, "images")
y_train_dir = os.path.join(train_dir, "masks")

x_val_dir = os.path.join(val_dir, "images")
y_val_dir = os.path.join(val_dir, "masks")






# create segmentation model with pretrained encoder
model = smp.Unet(
encoder_name=ENCODER,
encoder_weights=ENCODER_WEIGHTS,
classes=len(CLASSES),
activation=ACTIVATION,
)

preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)



train_dataset = SegmentationDataset(
x_train_dir,
y_train_dir,
augmentation=get_training_augmentation(),
preprocessing=get_preprocessing(preprocessing_fn),
classes=CLASSES,
)

val_dataset = SegmentationDataset(
x_val_dir,
y_val_dir,
preprocessing=get_preprocessing(preprocessing_fn),
classes=CLASSES,
)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)


loss = smp.utils.losses.DiceLoss()
metrics = [
smp.utils.metrics.IoU(threshold=0.5)
]

optimizer = torch.optim.Adam([
dict(params=model.parameters(), lr=0.0003),
])

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')


# create epoch runners
# it is a simple loop iterating over the dataloader's samples
train_epoch = smp.utils.train.TrainEpoch(
model,
loss=loss,
metrics=metrics,
optimizer=optimizer,
device=DEVICE,
verbose=True,
)

valid_epoch = smp.utils.train.ValidEpoch(
model,
loss=loss,
metrics=metrics,
device=DEVICE,
verbose=True,
)


max_score = 0

for i in range(0, EPOCHS):

print('\nEpoch: {}'.format(i))
train_logs = train_epoch.run(train_loader)
valid_logs = valid_epoch.run(valid_loader)

# save the model whenever the validation IoU improves
if max_score < valid_logs['iou_score']:
max_score = valid_logs['iou_score']
torch.save(model, f'{model_dir}/landcover_unet_{ENCODER}_epochs{i}_patch{patch_size}_batch{BATCH_SIZE}.pth')
print('Model saved!')

scheduler.step(valid_logs['dice_loss'])
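
For reference, a minimal inference sketch for a checkpoint produced by the loop above; the checkpoint filename and the input patch path are illustrative assumptions (the actual filename depends on the epoch at which validation IoU peaked):

import cv2
import torch
import segmentation_models_pytorch as smp

from utils.preprocess import get_preprocessing

# torch.save(model, ...) above pickles the whole module, so torch.load
# restores it without rebuilding the architecture (filename is hypothetical)
model = torch.load('models/landcover_unet_efficientnet-b0_epochs18_patch512_batch16.pth', map_location='cpu')
model.eval()

# mirror the validation preprocessing: scale to [0, 1], normalize, HWC -> CHW
preprocess = get_preprocessing(smp.encoders.get_preprocessing_fn('efficientnet-b0', 'imagenet'))
image = cv2.cvtColor(cv2.imread('some_patch.tif'), cv2.COLOR_BGR2RGB) / 255  # hypothetical 512x512 patch
x = torch.from_numpy(preprocess(image=image)['image']).unsqueeze(0)  # (1, 3, 512, 512)

with torch.no_grad():
    pred = model(x)  # (1, 4, 512, 512) softmax scores
pred_mask = pred.squeeze(0).argmax(dim=0).numpy()  # per-pixel class indices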
66 changes: 66 additions & 0 deletions src/utils/dataset.py
@@ -0,0 +1,66 @@
import os
import cv2
import numpy as np
from torch.utils.data import Dataset

class SegmentationDataset(Dataset):

"""
landcover.ai dataset. Read images, apply augmentation and preprocessing transformations.
Args:
images_dir (str): path to images folder
masks_dir (str): path to segmentation masks folder
classes (list): names of the classes to extract from the segmentation mask
augmentation (albumentations.Compose): data transformation pipeline
(e.g. flip, scale, etc.)
preprocessing (albumentations.Compose): data preprocessing
(e.g. normalization, shape manipulation, etc.)
"""

CLASSES = ['background', 'building', 'woodland', 'water', 'road']

def __init__(
self,
images_dir,
masks_dir,
classes=None,
augmentation=None,
preprocessing=None,
):
self.ids = os.listdir(images_dir)
self.images = [os.path.join(images_dir, image_id) for image_id in self.ids]
self.masks = [os.path.join(masks_dir, image_id) for image_id in self.ids]

# convert str names to class values on masks (default to all classes)
if classes is None:
    classes = self.CLASSES
self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]

self.augmentation = augmentation
self.preprocessing = preprocessing

def __getitem__(self, i):

# read data
image = cv2.imread(self.images[i])
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image = image / 255
mask = cv2.imread(self.masks[i], 0)

# extract the requested classes from the mask (e.g. buildings), one binary channel each
masks = [(mask == v) for v in self.class_values]
mask = np.stack(masks, axis=-1).astype('float')

# apply augmentations
if self.augmentation:
sample = self.augmentation(image=image, mask=mask)
image, mask = sample['image'], sample['mask']

# apply preprocessing
if self.preprocessing:
sample = self.preprocessing(image=image, mask=mask)
image, mask = sample['image'], sample['mask']

return image, mask

def __len__(self):
return len(self.ids)
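
To make the mask encoding concrete, here is a toy run of the one-hot construction from `__getitem__` above, assuming the four training class values:

import numpy as np

mask = np.array([[0, 1],
                 [2, 3]])  # toy 2x2 mask of class indices
class_values = [0, 1, 2, 3]  # background, building, woodland, water
masks = [(mask == v) for v in class_values]
onehot = np.stack(masks, axis=-1).astype('float')
print(onehot.shape)  # (2, 2, 4) -- one binary channel per class
print(onehot[0, 1])  # [0. 1. 0. 0.] -- this pixel is 'building'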
32 changes: 32 additions & 0 deletions src/utils/patching.py
@@ -0,0 +1,32 @@
import os
import cv2
import numpy as np
from patchify import patchify

def patching(data_dir, patches_dir, patch_size):
for filename in os.listdir(data_dir):
if filename.endswith('.tif'):
img = cv2.imread(os.path.join(data_dir, filename), 1)
# cropping to have height and width perfectly divisible by patch_size
max_height = (img.shape[0] // patch_size) * patch_size
max_width = (img.shape[1] // patch_size) * patch_size
img = img[0:max_height, 0:max_width]
# patching
print(f"Patchifying {filename}...")
patches = patchify(img, (patch_size, patch_size, 3), step = patch_size) # non-overlapping
print("Patches shape:", patches.shape)
for i in range(patches.shape[0]):
for j in range(patches.shape[1]):
single_patch = patches[i, j, 0, :, :] # the 0 indexes an extra singleton dimension that patchify adds for the multi-channel case
cv2.imwrite(os.path.join(patches_dir, filename.replace(".tif", f"_patch_{i}_{j}.tif")), single_patch)

def discard_useless_patches(patches_img_dir, patches_mask_dir):
for filename in os.listdir(patches_mask_dir):
img_path = os.path.join(patches_img_dir, filename)
mask_path = os.path.join(patches_mask_dir, filename)
mask = cv2.imread(mask_path)
classes, counts = np.unique(mask, return_counts=True)
# If the background class (value 0) occupies more than 95% of the patch, discard the image and mask
if counts[classes == 0].sum() / counts.sum() > 0.95:
    os.remove(img_path)
    os.remove(mask_path)
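
A quick shape check of the patch grid `patching` produces, using a dummy array in place of a real tile (sizes chosen to be divisible by the patch size):

import numpy as np
from patchify import patchify

img = np.zeros((1024, 1536, 3), dtype=np.uint8)  # stands in for a cropped RGB tile
patches = patchify(img, (512, 512, 3), step=512)  # non-overlapping
print(patches.shape)  # (2, 3, 1, 512, 512, 3) -- the singleton axis is the one indexed with 0 above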
48 changes: 48 additions & 0 deletions src/utils/preprocess.py
@@ -0,0 +1,48 @@
import albumentations as album

def get_training_augmentation():
train_transform = [
album.HorizontalFlip(p=0.5),
album.VerticalFlip(p=0.5),
# album.ShiftScaleRotate(scale_limit=1.5, rotate_limit=45, shift_limit=0.1, p=1, border_mode=0),
# album.GaussNoise(p=0.2),
# album.Perspective(p=0.5),
# album.OneOf(
# [
# album.CLAHE(p=1),
# album.RandomBrightnessContrast(p=1),
# album.RandomGamma(p=1),
# ],
# p=0.9,
# ),
# album.OneOf(
# [
# album.Sharpen(p=1),
# album.Blur(blur_limit=3, p=1),
# album.MotionBlur(blur_limit=3, p=1),
# ],
# p=0.9,
# ),
]
return album.Compose(train_transform)


def to_tensor(x, **kwargs):
return x.transpose(2, 0, 1).astype('float32')


def get_preprocessing(preprocessing_fn):
"""Construct preprocessing transform
Args:
preprocessing_fn (callable): data normalization function
(can be specific for each pretrained neural network)
Return:
transform: albumentations.Compose
"""
_transform = [
album.Lambda(image=preprocessing_fn),
album.Lambda(image=to_tensor, mask=to_tensor),
]
return album.Compose(_transform)
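
A small sanity check of the two pipelines above, with an identity lambda standing in for the encoder-specific `preprocessing_fn`:

import numpy as np

aug = get_training_augmentation()
pre = get_preprocessing(lambda x, **kwargs: x)  # identity stand-in for the encoder's normalizer

image = np.random.rand(512, 512, 3).astype('float32')
mask = np.zeros((512, 512, 4), dtype='float32')  # one channel per class

sample = aug(image=image, mask=mask)  # random flips; shapes unchanged
sample = pre(image=sample['image'], mask=sample['mask'])
print(sample['image'].shape, sample['mask'].shape)  # (3, 512, 512) (4, 512, 512)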
