This repository contains a software library for Data Augmentation Services (DAS).
Requirements:
Python 3.x
NumPy > 1.13
SciPy > 0.19
Pillow > 5.2
TensorFlow (GPU) > 1.3
MATLAB (the ensemble classifier uses MATLAB R2018a)
Stage-1 contains two parts: (i) training the GAN and (ii) training the ensemble classifier.
To train the GAN, run the Python script main.py located at DAS/Stage-1/Train_GAN/ by typing the following command in a terminal:
python main.py
We provide a modified version of DCGAN taken from https://github.com/carpedm20/DCGAN-tensorflow. Both the discriminator and the generator are conditioned on the image labels and have fewer layers than the DCGAN model available at that link. The code, DCGAN_Modified.py, is located at DAS/Stage-1/Train_GAN/.
Running main.py automatically downloads the CIFAR-10 dataset to the specified path.
With the default parameters specified in main.py, the GAN is trained for category 0 against each of the other categories (1 to 9).
The synthetically generated images and their labels are saved in the DAS/Stage-1/Train_GAN/Geneated_Data folder.
Example files, Images_10_1_2.mat and Labels_10_0_1.mat, already exist in this folder.
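If you want to inspect these files from Python rather than MATLAB, the minimal sketch below lists the variables stored in a .mat file and their shapes using scipy.io.loadmat (SciPy is already a requirement). The variable names inside the files are not documented here, so the helper, whose name list_mat_variables is our own, simply reports whatever it finds. It assumes the files were saved in MATLAB v7 format or earlier; files saved with -v7.3 are HDF5 and have to be read with h5py instead.

```python
from scipy.io import loadmat


def list_mat_variables(path):
    """Return {variable_name: shape} for each user variable in a .mat file.

    Hypothetical helper; assumes a non-v7.3 .mat file that loadmat can read.
    """
    contents = loadmat(path)
    # loadmat also returns metadata keys (__header__, __version__, __globals__),
    # which are skipped here.
    return {
        name: getattr(value, "shape", None)
        for name, value in contents.items()
        if not name.startswith("__")
    }
```

For example, list_mat_variables("DAS/Stage-1/Train_GAN/Geneated_Data/Images_10_1_2.mat") returns a dictionary mapping each stored variable to its array shape.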
The files follow the naming convention Images_alpha_CategoryA_CategoryB and, likewise, Labels_alpha_CategoryA_CategoryB.
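The naming convention can also be handled in scripts. Below is a minimal sketch with hypothetical helper names (parse_das_filename, make_das_filename) that splits a file name into its prefix, alpha, and two category fields, and builds a name from them. It treats every field as a plain string and does not interpret how alpha is encoded in the name; the default ".mat" extension in make_das_filename is an assumption. Splitting from the right keeps prefixes that contain an underscore, such as Pred_Labels, intact.

```python
from pathlib import Path


def parse_das_filename(path):
    """Split e.g. 'Images_10_1_2.mat' into ('Images', '10', '1', '2').

    Hypothetical helper. Raises ValueError if the name has fewer than
    three underscore-separated fields after the prefix.
    """
    stem = Path(path).stem  # drops the directory and the extension
    prefix, alpha, category_a, category_b = stem.rsplit("_", 3)
    return prefix, alpha, category_a, category_b


def make_das_filename(prefix, alpha, category_a, category_b, ext=".mat"):
    """Build a name following prefix_alpha_CategoryA_CategoryB (hypothetical)."""
    return f"{prefix}_{alpha}_{category_a}_{category_b}{ext}"
```

Because the split starts from the right, the same function also handles names such as Pred_Labels_alpha_CategoryA_CategoryB.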
To change the categories on which the GAN is trained, edit the following lines of DAS/Stage-1/Train_GAN/DCGAN_Modified.py:
Line 162, fixed_label = 0
Line 163, iter_labels = np.arange((fixed_label + 1), 10)
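As we read lines 162 and 163, they pair one fixed category with every higher-numbered CIFAR-10 category; since np.arange excludes its stop value, the defaults give the pairs (0, 1) through (0, 9). The snippet below only evaluates those two lines and, for comparison, counts every unordered pair of the 10 classes. It is an illustration, not code from the repository.

```python
import numpy as np

# Lines 162-163 of DCGAN_Modified.py with their default values.
fixed_label = 0
iter_labels = np.arange((fixed_label + 1), 10)  # stop value 10 is excluded

# Category pairs implied by the defaults: (0, 1), (0, 2), ..., (0, 9).
pairs = [(fixed_label, int(label)) for label in iter_labels]
print(pairs[0], pairs[-1], len(pairs))  # (0, 1) (0, 9) 9

# For comparison: all unordered pairs of the 10 CIFAR-10 classes.
all_pairs = [(a, b) for a in range(10) for b in range(a + 1, 10)]
print(len(all_pairs))  # 45
```

Setting fixed_label = 3, for instance, would give the pairs (3, 4) through (3, 9).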
To change the split ratio of the training dataset used while training the GAN, edit:
Line 159, alpha = 0.1
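What alpha controls inside DCGAN_Modified.py is not spelled out here. Purely as a hypothetical illustration of a split ratio, the sketch below draws a random fraction alpha of each class's training indices and returns the selected and remaining indices. The function split_by_ratio and its seed parameter are our own and need not match the repository's logic. It uses np.random.RandomState so that it also runs on the older NumPy versions listed in the requirements.

```python
import numpy as np


def split_by_ratio(labels, alpha, seed=0):
    """Hypothetical per-class split: select round(alpha * n) indices per class.

    Returns (selected, remaining): two sorted, disjoint index arrays that
    together cover every index of `labels`. Not the repository's code.
    """
    labels = np.asarray(labels)
    rng = np.random.RandomState(seed)  # available in old and new NumPy alike
    selected = []
    for cls in np.unique(labels):
        idx = np.flatnonzero(labels == cls)
        n_keep = int(round(alpha * idx.size))
        selected.append(rng.choice(idx, size=n_keep, replace=False))
    selected = np.sort(np.concatenate(selected))
    remaining = np.setdiff1d(np.arange(labels.size), selected)
    return selected, remaining
```

With alpha = 0.1 and 5000 training images per class, this sketch would select 500 images from each class.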
To train the ensemble classifier, run the main MATLAB script TrainEnsemble.m located at DAS/Stage-1/TrainEnsemClass/.
We train SVM, k-NN, and naive Bayes classifiers using the implementations available in MATLAB R2018a. The code and the necessary functions are in the folder DAS/Stage-1/Train_EnsemClass/.
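This section does not say how the three classifiers' outputs are combined. One common option is a hard majority vote; the NumPy sketch below (the function name majority_vote is hypothetical) shows that scheme for non-negative integer class labels, with ties resolved in favour of the smallest label. It is an illustration in Python, not a port of the MATLAB code.

```python
import numpy as np


def majority_vote(predictions):
    """Combine per-classifier label predictions by hard majority vote.

    predictions: non-negative integer array of shape (n_classifiers, n_samples),
    e.g. one row each for SVM, k-NN and naive Bayes. Ties go to the smallest
    label because argmax returns the first maximum.
    """
    predictions = np.asarray(predictions, dtype=int)
    n_samples = predictions.shape[1]
    votes = np.zeros((predictions.max() + 1, n_samples), dtype=int)
    for row in predictions:
        # One vote per sample from this classifier for the label it predicted.
        votes[row, np.arange(n_samples)] += 1
    return votes.argmax(axis=0)
```

For three classifiers and two classes, a tie can only occur if a classifier abstains, so the tie rule matters mainly for multi-class predictions.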
The trained parameters of the classifiers are saved in a .mat file named MODEL_X_Y.mat in DAS/Stage-1/TrainEnsemClass/; an example file already exists in that folder. In the file name:
X : the first category
Y : the second category
To specify the labels on which the ensemble classifier is trained, edit the following lines of TrainEnsemble.m:
Line 14, fixed_label = 1
Line 15, selected_labels = [(fixed_label+1):10]
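Note that the MATLAB range on line 15 includes its end point, unlike np.arange in DCGAN_Modified.py, so with fixed_label = 1 the default selects labels 2 to 10. The snippet below reproduces that range in NumPy. We assume, without having confirmed it in the code, that the MATLAB scripts number the CIFAR-10 classes 1 to 10 while the Python scripts use 0 to 9; if so, the two defaults refer to the same fixed category.

```python
import numpy as np

# MATLAB (TrainEnsemble.m, lines 14-15): selected_labels = [(fixed_label+1):10]
# The colon operator a:b includes b, so NumPy's exclusive stop must be 10 + 1.
fixed_label = 1
selected_labels = np.arange(fixed_label + 1, 10 + 1)
print(selected_labels.tolist())  # [2, 3, 4, 5, 6, 7, 8, 9, 10]

# Hypothetical mapping to 0-based labels, valid only if the MATLAB labels are
# 1-based and the Python labels 0-based (assumption, not verified in the code).
print((selected_labels - 1).tolist())  # [1, 2, 3, 4, 5, 6, 7, 8, 9]
```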
Once the GAN and the ensemble classifier have been trained and their outputs saved in the corresponding locations, move to DAS/Stage-2/ to filter the synthetic images and obtain performance measures of a CNN trained on the augmented datasets.
Run the main MATLAB script Filter_Images.m, located at DAS/Stage-2/Filter_Unbiased_Images/, to filter the synthetic images generated by the GAN.
The script requires the paths to the training data, the saved model, and the generated data. These are already set in Filter_Images.m, but you can change them to any location by editing the following lines:
Line 16, path to the trained ensemble classifier model
Line 17, path to the training data
Line 78, path to the generated data
When the script finishes, an output file named Batches_alpha_CategoryA_CategoryB is saved in the folder DAS/Stage-2/Filter_Unbiased_Images/Filtered_Images/
This file contains the test data and its labels, together with the batches of training images and filtered images used for 3-fold cross-validation.
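For readers unfamiliar with the setup, 3-fold cross-validation divides the data into three disjoint batches and holds out each batch once. The sketch below, with hypothetical function names three_fold_batches and cv_rounds, shuffles sample indices, splits them into three nearly equal batches with np.array_split, and yields the standard train/held-out combinations. How Filter_Images.m forms its batches, and how the CNN script uses each batch per round, may differ.

```python
import numpy as np


def three_fold_batches(n_samples, n_folds=3, seed=0):
    """Shuffle indices 0..n_samples-1 and split them into n_folds disjoint batches."""
    rng = np.random.RandomState(seed)
    order = rng.permutation(n_samples)
    # array_split tolerates sizes that are not divisible by n_folds.
    return np.array_split(order, n_folds)


def cv_rounds(batches):
    """Yield (train_indices, held_out_indices), holding out each batch once."""
    for k in range(len(batches)):
        train = np.concatenate([b for j, b in enumerate(batches) if j != k])
        yield train, batches[k]
```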
Run python main.py in a terminal to train a VGG-style CNN adapted from https://github.com/soumith/DeepLearningFrameworks/blob/master/Tensorflow_CIFAR.ipynb. The script is located at DAS/Stage-2/Train_CNN/.
This trains the CNN on the augmented dataset obtained from the filtering step of Stage-2 of Data Augmentation Services.
The outputs are saved in the DAS/Stage-2/Train_CNN/results folder as Accuracy_alpha_CategoryA_CategoryB and Pred_Labels_alpha_CategoryA_CategoryB.
To train the CNN on the original CIFAR training data (without synthetic images) instead, edit the script main.py:
Line 5, change "from VGG_CNN_CIFAR import VGG" to "from VGG_CNN_Baseline import VGG"
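As an alternative to editing line 5 by hand, main.py could pick the module at run time. The helper below is hypothetical and not part of the repository; it uses importlib.import_module to return the VGG class from either VGG_CNN_CIFAR or VGG_CNN_Baseline, and assumes it is run from DAS/Stage-2/Train_CNN/ so that both modules are importable.

```python
import importlib


def load_vgg(baseline=False):
    """Return the VGG class from the baseline or the augmented-data module.

    Hypothetical helper: equivalent to switching the import on line 5 of
    main.py, but chosen at run time.
    """
    module_name = "VGG_CNN_Baseline" if baseline else "VGG_CNN_CIFAR"
    return importlib.import_module(module_name).VGG
```

With this helper, VGG = load_vgg(baseline=True) would take the place of the import statement on line 5.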
Run the main MATLAB script DAS/Stage-2/Calculate_Bias_Variance/Calculate_Performace.m to compute the bias, the variance, and the accuracy of the model trained on the augmented dataset.
The script requires the path to the directory where the CNN results are saved and the path to the directory where the augmented data is stored. Both parameters are set on the following lines of the script:
Line 5, path to the CNN results directory
Line 6, path to the augmented dataset
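The exact definitions used in Calculate_Performace.m are not given here. One common decomposition for the 0-1 loss (Domingos, 2000) takes, for each test sample, the main prediction as the most frequent label across models, such as the three models from 3-fold cross-validation. With the true label taken as the optimal prediction (noise-free labels assumed), bias is whether the main prediction is wrong, and variance is the fraction of models that disagree with the main prediction. The NumPy sketch below (hypothetical function bias_variance_01) computes these averages along with mean accuracy; treat it as an illustration rather than a reimplementation of the MATLAB script.

```python
import numpy as np


def bias_variance_01(pred_labels, true_labels):
    """Average 0-1 bias, variance and accuracy over a test set (illustrative).

    pred_labels: (n_models, n_samples) non-negative ints, e.g. one row per fold.
    true_labels: (n_samples,) non-negative ints.
    """
    pred = np.asarray(pred_labels, dtype=int)
    true = np.asarray(true_labels, dtype=int)
    n_samples = pred.shape[1]
    # Main prediction: most frequent label per sample (ties -> smallest label).
    counts = np.zeros((max(pred.max(), true.max()) + 1, n_samples), dtype=int)
    for row in pred:
        counts[row, np.arange(n_samples)] += 1
    main = counts.argmax(axis=0)
    bias = np.mean(main != true)      # fraction of samples whose main prediction is wrong
    variance = np.mean(pred != main)  # average disagreement with the main prediction
    accuracy = np.mean(pred == true)  # averaged over all models and samples
    return bias, variance, accuracy
```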