Implementation of models for my bachelor's thesis.
The following image-to-image translation models are implemented using PyTorch and PyTorch Lightning:
- Pix2Pix (Isola et al. 2018);
- Attention U-net (Oktay et al. 2018);
- Residual U-net with the following basic blocks:
  - Res18 (He et al. 2015);
  - Res50 (He et al. 2015);
  - ResV2 (He et al. 2016);
  - ResNeXt (Xie et al. 2017);
- Trans U-net (Chen et al. 2021);
- Palette (Saharia et al. 2022).
More models can easily be added by using the `UnetWrapper` class.
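The exact interface is defined in `models/wrapper.py`; the sketch below is only illustrative (the constructor signature and the backbone shown here are assumptions, not the repository's actual API):

```python
import torch.nn as nn

from models.wrapper import UnetWrapper  # wrapper class from this repository


class MyBackbone(nn.Module):
    """Hypothetical image-to-image network (1 input channel -> 1 output channel)."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 1, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return self.net(x)


# Assumed usage: wrap the backbone so it plugs into training and reporting.
# The real constructor arguments may differ; see models/wrapper.py.
model = UnetWrapper(MyBackbone())
```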
The following loss functions are implemented:
- GAN loss, using Pix2Pix's discriminator (to change the adversarial network used, modify the `Discriminator` class in `models/wrapper.py`);
- MSE loss;
- SSIM loss;
- PSNR loss;
- Combination of SSIM and PSNR loss.
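As an illustration of how SSIM and PSNR can be combined into a single minimisable objective, here is a minimal sketch using `torchmetrics` (the weighting `alpha`, the 50 dB PSNR cap, and the `[0, 1]` data range are assumptions; the repository's actual loss may differ):

```python
import torch
from torchmetrics.functional import (
    peak_signal_noise_ratio as psnr,
    structural_similarity_index_measure as ssim,
)


def ssim_psnr_loss(pred: torch.Tensor, target: torch.Tensor,
                   alpha: float = 0.5, max_psnr: float = 50.0) -> torch.Tensor:
    """Weighted combination of SSIM and PSNR as a loss to minimise.

    SSIM becomes a loss via 1 - SSIM; PSNR (in dB, unbounded) is
    normalised by an assumed 50 dB ceiling so both terms live in a
    comparable range. Both choices are illustrative, not the repo's.
    """
    ssim_term = 1.0 - ssim(pred, target, data_range=1.0)
    psnr_term = 1.0 - psnr(pred, target, data_range=1.0) / max_psnr
    return alpha * ssim_term + (1.0 - alpha) * psnr_term
```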
The organisation of your data does not matter. The only important thing is the data file: a YAML file containing a list of input/ground-truth entries. The input and ground-truth paths must be relative to the directory of the data file. For example:
```yaml
- input: input/00001.png
  ground_truth: ground_truth/00001.png
- input: input/00002.png
  ground_truth: ground_truth/00002.png
- input: input/00003.png
  ground_truth: ground_truth/00003.png
```
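Such a data file can be generated with a few lines of Python. The sketch below assumes the directory layout from the example above, matching filenames in `input/` and `ground_truth/`, and a data file named `data.yaml` (all illustrative choices):

```python
from pathlib import Path

import yaml  # PyYAML

root = Path("data")  # directory that will contain the data file
entries = [
    {
        # Paths are written relative to the data file's directory,
        # using forward slashes for portability.
        "input": inp.relative_to(root).as_posix(),
        "ground_truth": (Path("ground_truth") / inp.name).as_posix(),
    }
    for inp in sorted((root / "input").glob("*.png"))
]
(root / "data.yaml").write_text(yaml.safe_dump(entries, sort_keys=False))
```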
To train a model, run the following:
```
python main.py <run name> <options>
```
When training, the model with the highest SSIM on the validation dataset will be selected as the "best" checkpoint.
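In PyTorch Lightning, this selection is typically done with a `ModelCheckpoint` callback. A minimal sketch, assuming the validation SSIM is logged under the key `val_ssim` (check the LightningModule for the actual metric name):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep the checkpoint with the highest validation SSIM.
# "val_ssim" is an assumed key; it must match what the model logs.
best_ckpt = ModelCheckpoint(
    monitor="val_ssim",
    mode="max",
    filename="best-{epoch:02d}-{val_ssim:.4f}",
)
trainer = Trainer(callbacks=[best_ckpt])
```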
To test a trained model, run the following:
```
python report.py <report name> <options>
```
It takes a model checkpoint and a test data file as input, and outputs metrics and other information about the model. The following metrics are reported:
- SSIM per image;
- PSNR per image;
- Mean SSIM;
- Mean PSNR;
- Mean RMSE;
- FLOPs;
- Parameter count;
- SSIM over the depth (vertical axis) of the image (only relevant for PAI reconstruction);
- Outputs of the model.
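For reference, the per-image SSIM, PSNR, and RMSE can be computed as in the sketch below, which uses `torchmetrics` and assumes images normalised to `[0, 1]` (`report.py` may use different implementations or data ranges):

```python
import torch
from torchmetrics.functional import (
    peak_signal_noise_ratio as psnr,
    structural_similarity_index_measure as ssim,
)


@torch.no_grad()
def image_metrics(pred: torch.Tensor, target: torch.Tensor) -> dict:
    """Per-image metrics for a single (1, C, H, W) prediction/target pair."""
    return {
        "ssim": ssim(pred, target, data_range=1.0).item(),
        "psnr": psnr(pred, target, data_range=1.0).item(),
        "rmse": torch.sqrt(torch.mean((pred - target) ** 2)).item(),
    }
```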