This is an implementation of the original UNet architecture with some optimizations. The Carvana dataset is used for binary image segmentation, and multiclass segmentation on the Cityscapes dataset is being added.

Ayushman-Choudhuri/unet_segmentation_from_scratch

UNET Image Segmentation

This project implements a slightly modified UNET architecture and uses it to perform binary segmentation as well as multiclass segmentation.

Image segmentation has been performed on two datasets: the Carvana dataset and the Cityscapes dataset. The Carvana dataset was first used to test the performance of the UNET model on a simple binary segmentation task; the model was then applied to multiclass segmentation on the Cityscapes dataset.

Changes in UNET:
Compared to the original model [1], I have added a 2D Batch Normalization layer [2] between every Conv2d and ReLU layer of the UNET architecture. This is intended to accelerate learning: batch normalization smooths the optimization landscape, which helps the optimization algorithm converge faster.
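The change above can be sketched as a PyTorch "double convolution" block, the unit U-Net repeats at every level. This is a minimal sketch, not the repository's code: the class name `DoubleConv` is hypothetical, and `padding=1` ("same" padding) is an assumption, since the original U-Net paper used unpadded convolutions.

```python
import torch
from torch import nn


class DoubleConv(nn.Module):
    """Hypothetical sketch: two 3x3 convolutions, each followed by BatchNorm2d and ReLU."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.block = nn.Sequential(
            # bias=False because the following BatchNorm2d has its own learnable shift
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),  # inserted between Conv2d and ReLU
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)


if __name__ == "__main__":
    block = DoubleConv(3, 64)
    x = torch.randn(2, 3, 160, 240)  # batch of 2 RGB images
    print(block(x).shape)  # torch.Size([2, 64, 160, 240])
```

Because of the assumed `padding=1`, the block keeps the spatial size unchanged and only changes the channel count; dropping the convolution bias is a common choice when a batch normalization layer follows immediately.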

1. Branches

main: holds the implementation of binary segmentation on the Carvana dataset.

cityscapes: holds the implementation of multiclass segmentation on the Cityscapes dataset.

2. Setup - Binary Segmentation

After cloning the repository, follow these steps to set up the project.

Note: All commands are to be run from the project root folder.

2.1: Install Dependencies

Step 1: Create a Python virtual environment and activate it

python3 -m venv unetseg_venv && source unetseg_venv/bin/activate

Step 2: Install the dependencies from requirements.txt

pip3 install -r requirements.txt

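Step 3: Install PyTorch for your system configuration (the installation selector on the PyTorch website generates the right command). For example, for CUDA 11.8: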

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
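To confirm that the installed build matches your hardware, a quick check like the sketch below can be run inside the activated environment. The helper name `describe_torch_install` is hypothetical and not part of the repository.

```python
import torch


def describe_torch_install() -> dict:
    """Hypothetical helper: summarise the installed PyTorch build."""
    return {
        "torch_version": torch.__version__,
        # CUDA version PyTorch was built against; None for CPU-only builds
        "built_with_cuda": torch.version.cuda,
        # True only if a compatible GPU and driver are usable at runtime
        "cuda_available": torch.cuda.is_available(),
    }


if __name__ == "__main__":
    for key, value in describe_torch_install().items():
        print(f"{key}: {value}")
```

If `built_with_cuda` shows a version but `cuda_available` is `False`, the GPU driver or hardware likely does not match the wheel that was installed.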

2.2: Folder Structure - Binary Segmentation

Create the following top-level folder structure:

unet_segmentation_from_scratch
├── configs
├── datasets
│   └── carvana
├── images
├── logs
│   ├── checkpoints
│   └── saved_images
├── runs
├── src
├── unetseg_venv
├── .gitignore
├── README.md
└── requirements.txt
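If you prefer to script this, the sketch below creates the folders with Python's pathlib. The helper `create_layout` and the `FOLDERS` list are hypothetical (not part of the repository); run it from the project root. The virtual environment, .gitignore, README.md and requirements.txt come from cloning and Step 1, so they are not created here.

```python
from pathlib import Path

# Hypothetical list mirroring the folder layout above
FOLDERS = [
    "configs",
    "datasets/carvana",
    "images",
    "logs/checkpoints",
    "logs/saved_images",
    "runs",
    "src",
]


def create_layout(root: Path) -> list[Path]:
    """Create the project folders under root and return their paths."""
    created = []
    for folder in FOLDERS:
        path = root / folder
        # parents=True creates intermediate folders; exist_ok=True makes reruns safe
        path.mkdir(parents=True, exist_ok=True)
        created.append(path)
    return created


if __name__ == "__main__":
    for path in create_layout(Path.cwd()):
        print(path)
```

Since `exist_ok=True` is used, folders that already exist after cloning are left untouched.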

3. Dataset - Binary Segmentation

Download the Carvana dataset from the Kaggle page of the Carvana Image Masking Challenge. Only the train and train_masks folders are needed: train contains the training images, and train_masks contains the .gif mask for each corresponding training image.

The train, validation and test sets used to train, validate and test the UNET model are all derived from the train folder of the Carvana dataset. The train.zip archive contains 5088 RGB images of 1918 x 1280 pixels.

From these 5088 images, move 50 into val_images and 50 into test_images, and move the corresponding masks into val_masks and test_masks. The remaining 4988 images and masks form the training set.

After the split, the dataset folder structure should look like this:

datasets
└── carvana
    ├── test_images      # 50 images in .jpg format
    ├── test_masks       # 50 masks in .gif format
    ├── train_images     # 4988 images in .jpg format
    ├── train_masks      # 4988 masks in .gif format
    ├── val_images       # 50 images in .jpg format
    └── val_masks        # 50 masks in .gif format
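The split described above can be scripted. The sketch below is a hypothetical helper, `split_carvana`, not part of the repository. It assumes the downloaded train and train_masks folders have been placed at datasets/carvana/train_images and datasets/carvana/train_masks, and that the mask for an image `<id>.jpg` is named `<id>_mask.gif` (adjust the pattern if your files are named differently). Pairs are picked at random with a fixed seed so the split is reproducible.

```python
import random
import shutil
from pathlib import Path


def split_carvana(carvana_dir: Path, n_val: int = 50, n_test: int = 50, seed: int = 0) -> dict[str, int]:
    """Hypothetical helper: move image/mask pairs from train_* into val_* and test_*."""
    train_images = carvana_dir / "train_images"
    train_masks = carvana_dir / "train_masks"
    images = sorted(train_images.glob("*.jpg"))
    if len(images) < n_val + n_test:
        raise ValueError("not enough images to split")

    random.Random(seed).shuffle(images)  # reproducible random selection
    splits = {"val": images[:n_val], "test": images[n_val:n_val + n_test]}

    for split, split_images in splits.items():
        img_dir = carvana_dir / f"{split}_images"
        mask_dir = carvana_dir / f"{split}_masks"
        img_dir.mkdir(exist_ok=True)
        mask_dir.mkdir(exist_ok=True)
        for image in split_images:
            # Assumed naming: <id>.jpg pairs with <id>_mask.gif
            mask = train_masks / f"{image.stem}_mask.gif"
            if not mask.exists():
                raise FileNotFoundError(mask)  # check before moving anything
            shutil.move(str(image), str(img_dir / image.name))
            shutil.move(str(mask), str(mask_dir / mask.name))

    # Number of files now in each folder (train keeps the remainder)
    names = ["train_images", "train_masks", "val_images", "val_masks", "test_images", "test_masks"]
    return {name: len(list((carvana_dir / name).iterdir())) for name in names}


if __name__ == "__main__":
    for name, count in split_carvana(Path("datasets/carvana")).items():
        print(f"{name}: {count}")
```

With the full download and the default arguments, the counts should match the tree above: 4988 in each train folder and 50 in each val and test folder.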

4. Running the Project - Binary Segmentation

4.1 Setting up dataset configurations

In the config_carvana.yaml file, tune the training hyperparameters to suit your training hardware.

4.2 Running the training loop:

python3 src/train.py

5. Results

6. References

[1] O. Ronneberger, P. Fischer, T. Brox. U-Net: Convolutional Networks for Biomedical Image Segmentation. MICCAI 2015.

[2] S. Ioffe, C. Szegedy. Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. ICML 2015.
