-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
993a69a
commit 8ce5490
Showing
11 changed files
with
357 additions
and
1 deletion.
There are no files selected for viewing
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
albumentations==1.3.1 | ||
certifi==2023.5.7 | ||
charset-normalizer==3.1.0 | ||
colorama==0.4.6 | ||
contourpy==1.0.7 | ||
cycler==0.11.0 | ||
efficientnet-pytorch==0.7.1 | ||
filelock==3.12.2 | ||
fonttools==4.40.0 | ||
fsspec==2023.6.0 | ||
huggingface-hub==0.15.1 | ||
idna==3.4 | ||
imageio==2.31.1 | ||
importlib-resources==5.12.0 | ||
Jinja2==3.1.2 | ||
joblib==1.2.0 | ||
kiwisolver==1.4.4 | ||
lazy_loader==0.2 | ||
MarkupSafe==2.1.3 | ||
matplotlib==3.7.1 | ||
mpmath==1.3.0 | ||
munch==3.0.0 | ||
networkx==3.1 | ||
numpy==1.24.3 | ||
opencv-python==4.7.0.72 | ||
opencv-python-headless==4.7.0.72 | ||
packaging==23.1 | ||
patchify==0.2.3 | ||
Pillow==9.5.0 | ||
pretrainedmodels==0.7.4 | ||
pyparsing==3.0.9 | ||
python-dateutil==2.8.2 | ||
PyWavelets==1.4.1 | ||
PyYAML==6.0 | ||
qudida==0.0.4 | ||
requests==2.31.0 | ||
safetensors==0.3.1 | ||
scikit-image==0.21.0 | ||
scikit-learn==1.2.2 | ||
scipy==1.10.1 | ||
segmentation-models-pytorch==0.3.3 | ||
six==1.16.0 | ||
split-folders==0.5.1 | ||
sympy==1.12 | ||
threadpoolctl==3.1.0 | ||
tifffile==2023.4.12 | ||
timm==0.9.2 | ||
torch==2.0.1 | ||
torchvision==0.15.2 | ||
tqdm==4.65.0 | ||
typing_extensions==4.6.3 | ||
urllib3==2.0.3 | ||
zipp==3.15.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
import os | ||
import splitfolders | ||
import torch | ||
from torch.utils.data import DataLoader | ||
import segmentation_models_pytorch as smp | ||
import segmentation_models_pytorch.utils | ||
|
||
from utils.patching import patching, discard_useless_patches | ||
from utils.preprocess import get_training_augmentation, get_preprocessing | ||
from utils.dataset import SegmentationDataset | ||
|
||
########## config ########### | ||
|
||
BATCH_SIZE = 16 | ||
ENCODER = 'efficientnet-b0' | ||
ENCODER_WEIGHTS = 'imagenet' | ||
CLASSES = ['background', 'building', 'woodland', 'water'] # not training on 'road' class since it's instances in the data is too less | ||
ACTIVATION = 'softmax2d' # could be None for logits or 'softmax2d' for multiclass segmentation | ||
DEVICE = 'cpu' | ||
EPOCHS = 20 | ||
|
||
root_directory = "data/train" | ||
|
||
img_dir = os.path.join(root_directory, "images") | ||
mask_dir = os.path.join(root_directory, "masks") | ||
|
||
patch_size = 512 | ||
|
||
patches_img_dir = os.path.join(f"patches_{patch_size}", "images") | ||
patches_img_dir = os.path.join(root_directory, patches_img_dir) | ||
os.makedirs(patches_img_dir, exist_ok=True) | ||
patches_mask_dir = os.path.join(f"patches_{patch_size}", "masks") | ||
patches_mask_dir = os.path.join(root_directory, patches_mask_dir) | ||
os.makedirs(patches_mask_dir, exist_ok=True) | ||
|
||
model_dir = "models" | ||
|
||
############################# | ||
|
||
print() | ||
print("Dividing images into patches...") | ||
patching(img_dir, patches_img_dir, patch_size) | ||
print("Dividing images into patches completed successfull!") | ||
|
||
print() | ||
print("Dividing masks into patches...") | ||
patching(mask_dir, patches_mask_dir, patch_size) | ||
print("Dividing masks into patches completed successfull!") | ||
|
||
discard_useless_patches(patches_img_dir, patches_mask_dir) | ||
|
||
|
||
|
||
input_folder = patches_img_dir.strip("images") | ||
print(input_folder) | ||
output_folder = os.path.join(root_directory, "train_val_test") | ||
print(output_folder) | ||
|
||
os.makedirs(output_folder, exist_ok=True) | ||
|
||
# Split with a ratio. | ||
# To split into training, validation, and testing set, set a tuple to `ratio`, i.e, `(.8, .1, .1)`. | ||
splitfolders.ratio(input_folder, output=output_folder, seed=42, ratio=(.8, .2), group_prefix=None, move=False) # splitting in training and validation only | ||
|
||
train_dir = os.path.join(output_folder, "train") | ||
val_dir = os.path.join(output_folder, "val") | ||
# test_dir = os.path.join(output_folder, "test") | ||
|
||
x_train_dir = os.path.join(train_dir, "images") | ||
y_train_dir = os.path.join(train_dir, "masks") | ||
|
||
x_val_dir = os.path.join(val_dir, "images") | ||
y_val_dir = os.path.join(val_dir, "masks") | ||
|
||
|
||
|
||
|
||
|
||
|
||
# create segmentation model with pretrained encoder | ||
model = smp.Unet( | ||
encoder_name=ENCODER, | ||
encoder_weights=ENCODER_WEIGHTS, | ||
classes=len(CLASSES), | ||
activation=ACTIVATION, | ||
) | ||
|
||
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS) | ||
|
||
|
||
|
||
train_dataset = SegmentationDataset( | ||
x_train_dir, | ||
y_train_dir, | ||
augmentation=get_training_augmentation(), | ||
preprocessing=get_preprocessing(preprocessing_fn), | ||
classes=CLASSES, | ||
) | ||
|
||
val_dataset = SegmentationDataset( | ||
x_val_dir, | ||
y_val_dir, | ||
preprocessing=get_preprocessing(preprocessing_fn), | ||
classes=CLASSES, | ||
) | ||
|
||
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True) | ||
valid_loader = DataLoader(val_dataset, batch_size=1, shuffle=False) | ||
|
||
|
||
loss = smp.utils.losses.DiceLoss() | ||
metrics = [ | ||
smp.utils.metrics.IoU(threshold=0.5) | ||
] | ||
|
||
optimizer = torch.optim.Adam([ | ||
dict(params=model.parameters(), lr=0.0003), | ||
]) | ||
|
||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min') | ||
|
||
|
||
# create epoch runners | ||
# it is a simple loop of iterating over dataloader`s samples | ||
train_epoch = smp.utils.train.TrainEpoch( | ||
model, | ||
loss=loss, | ||
metrics=metrics, | ||
optimizer=optimizer, | ||
device=DEVICE, | ||
verbose=True, | ||
) | ||
|
||
valid_epoch = smp.utils.train.ValidEpoch( | ||
model, | ||
loss=loss, | ||
metrics=metrics, | ||
device=DEVICE, | ||
verbose=True, | ||
) | ||
|
||
|
||
max_score = 0 | ||
|
||
for i in range(0, EPOCHS): | ||
|
||
print('\nEpoch: {}'.format(i)) | ||
train_logs = train_epoch.run(train_loader) | ||
valid_logs = valid_epoch.run(valid_loader) | ||
|
||
# Do something (save model, change lr, etc.) | ||
if max_score < valid_logs['iou_score']: | ||
max_score = valid_logs['iou_score'] | ||
torch.save(model, f'{model_dir}/landcover_unet_{ENCODER}_epochs{i}_patch{patch_size}_batch{BATCH_SIZE}.pth') | ||
print('Model saved!') | ||
|
||
scheduler.step(valid_logs['dice_loss']) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import os | ||
import cv2 | ||
import numpy as np | ||
from torch.utils.data import Dataset | ||
|
||
class SegmentationDataset(Dataset): | ||
|
||
""" | ||
landcover.ai dataset. Read images, apply augmentation and preprocessing transformations. | ||
Args: | ||
images_dir (str): path to images folder | ||
masks_dir (str): path to segmentation masks folder | ||
class_values (list): values of classes to extract from segmentation mask | ||
augmentation (albumentations.Compose): data transfromation pipeline | ||
(e.g. flip, scale, etc.) | ||
preprocessing (albumentations.Compose): data preprocessing | ||
(e.g. noralization, shape manipulation, etc.) | ||
""" | ||
|
||
CLASSES = ['background', 'building', 'woodland', 'water', 'road'] | ||
|
||
def __init__( | ||
self, | ||
images_dir, | ||
masks_dir, | ||
classes=None, | ||
augmentation=None, | ||
preprocessing=None, | ||
): | ||
self.ids = os.listdir(images_dir) | ||
self.images = [os.path.join(images_dir, image_id) for image_id in self.ids] | ||
self.masks = [os.path.join(masks_dir, image_id) for image_id in self.ids] | ||
|
||
# convert str names to class values on masks | ||
self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes] | ||
|
||
self.augmentation = augmentation | ||
self.preprocessing = preprocessing | ||
|
||
def __getitem__(self, i): | ||
|
||
# read data | ||
image = cv2.imread(self.images[i]) | ||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) | ||
image = image / 255 | ||
mask = cv2.imread(self.masks[i], 0) | ||
|
||
# extract certain classes from mask (e.g. cars) | ||
masks = [(mask == v) for v in self.class_values] | ||
mask = np.stack(masks, axis=-1).astype('float') | ||
|
||
# apply augmentations | ||
if self.augmentation: | ||
sample = self.augmentation(image=image, mask=mask) | ||
image, mask = sample['image'], sample['mask'] | ||
|
||
# apply preprocessing | ||
if self.preprocessing: | ||
sample = self.preprocessing(image=image, mask=mask) | ||
image, mask = sample['image'], sample['mask'] | ||
|
||
return image, mask | ||
|
||
def __len__(self): | ||
return len(self.ids) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import os | ||
import cv2 | ||
import numpy as np | ||
from patchify import patchify | ||
|
||
def patching(data_dir, patches_dir, patch_size): | ||
for filename in os.listdir(data_dir): | ||
if filename.endswith('.tif'): | ||
img = cv2.imread(os.path.join(data_dir, filename), 1) | ||
# cropping to have height and width perfectly divisible by patch_size | ||
max_height = (img.shape[0] // patch_size) * patch_size | ||
max_width = (img.shape[1] // patch_size) * patch_size | ||
img = img[0:max_height, 0:max_width] | ||
# patching | ||
print(f"Patchifying {filename}...") | ||
patches = patchify(img, (patch_size, patch_size, 3), step = patch_size) # non-overlapping | ||
print("Patches shape:", patches.shape) | ||
for i in range(patches.shape[0]): | ||
for j in range(patches.shape[1]): | ||
single_patch = patches[i, j, 0, :, :] # the 0 is an extra unncessary dimension added by patchify for multiple channels scenario | ||
cv2.imwrite(os.path.join(patches_dir, filename.replace(".tif", f"_patch_{i}_{j}.tif")), single_patch) | ||
|
||
def discard_useless_patches(patches_img_dir, patches_mask_dir): | ||
for filename in os.listdir(patches_mask_dir): | ||
img_path = os.path.join(patches_img_dir, filename) | ||
mask_path = os.path.join(patches_mask_dir, filename) | ||
mask = cv2.imread(mask_path) | ||
classes, count = np.unique(mask, return_counts = True) | ||
# If background class occupies more than 95% of the image, discard the image and mask | ||
if (count[0] / count.sum()) > 0.95: | ||
os.remove(img_path) | ||
os.remove(mask_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import albumentations as album | ||
|
||
def get_training_augmentation(): | ||
train_transform = [ | ||
album.HorizontalFlip(p=0.5), | ||
album.VerticalFlip(p=0.5), | ||
# album.ShiftScaleRotate(scale_limit=1.5, rotate_limit=45, shift_limit=0.1, p=1, border_mode=0), | ||
# album.GaussNoise(p=0.2), | ||
# album.Perspective(p=0.5), | ||
# album.OneOf( | ||
# [ | ||
# album.CLAHE(p=1), | ||
# album.RandomBrightnessContrast(p=1), | ||
# album.RandomGamma(p=1), | ||
# ], | ||
# p=0.9, | ||
# ), | ||
# album.OneOf( | ||
# [ | ||
# album.Sharpen(p=1), | ||
# album.Blur(blur_limit=3, p=1), | ||
# album.MotionBlur(blur_limit=3, p=1), | ||
# ], | ||
# p=0.9, | ||
# ), | ||
] | ||
return album.Compose(train_transform) | ||
|
||
|
||
def to_tensor(x, **kwargs): | ||
return x.transpose(2, 0, 1).astype('float32') | ||
|
||
|
||
def get_preprocessing(preprocessing_fn): | ||
"""Construct preprocessing transform | ||
Args: | ||
preprocessing_fn (callbale): data normalization function | ||
(can be specific for each pretrained neural network) | ||
Return: | ||
transform: albummentations.Compose | ||
""" | ||
_transform = [ | ||
album.Lambda(image=preprocessing_fn), | ||
album.Lambda(image=to_tensor, mask=to_tensor), | ||
] | ||
return album.Compose(_transform) |