Training a ResNet classifier using pre-trained BigGAN models
This project uses the base implementation of BigGAN provided by Andrew Brock [Brock et al., 2018].
The resulting classifier achieves ~88.5% accuracy on CIFAR-10 (when using data transformations and classifier filtering), versus ~94.3% when trained on real data for the same number of optimization steps. With five GANs, it reaches ~91% accuracy.
The final research report is available for download.
Training a classifier requires:
- pre-trained classifier weights in: ./classifier/weights/model_name.pth
- pre-trained BigGAN weights in: ./weights/weights_folder_name/
To train a classifier, run:
python3 train_classifier.py [options]
Parameters are as follows:
Input/Output
- model: Weights file to use for the GAN (of the form: ./weights/model_name/G_ema.pth if a single GAN is used, ./weights/model_name/gan_multi_n/G_ema.pth if n GANs are used)
- classifier_model: Weights file to use for the filtering classifier (of the form: ./classifiers/weights/class_model_name.pth)
- ofile: Output file name (default: trained_net)
Training
- batch_size: Size of each batch (same for generation, filtering and training; default: 64)
- num_batches: Number of batches per class to train the classifier with (default: 1)
- epochs: Number of epochs to train the classifier for (default: 10)
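As a rough sketch of how these parameters combine, assuming one epoch runs num_batches batches for every class (consistent with the fixed_dset size of batch_size*num_batches*num_classes) and one optimizer step per batch, the defaults on CIFAR-10 (10 classes) give the following budget. The helper name training_budget is hypothetical.

```python
# Hypothetical helper, assuming one epoch = num_batches batches for each class
# and one optimizer step per batch.
def training_budget(batch_size=64, num_batches=1, num_classes=10, epochs=10):
    steps_per_epoch = num_batches * num_classes          # batches seen per epoch
    images_per_epoch = batch_size * steps_per_epoch      # same as the fixed_dset size
    return {
        "images_per_epoch": images_per_epoch,
        "total_steps": steps_per_epoch * epochs,
        "total_images": images_per_epoch * epochs,
    }


if __name__ == "__main__":
    print(training_budget())
    # → {'images_per_epoch': 640, 'total_steps': 100, 'total_images': 6400}
```

Raising num_batches is the direct way to give the classifier more optimization steps per epoch under this assumption.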
Classifier filtering
- filter_samples: Enable classifier filtering of generated images (default: False)
- threshold: Threshold probability for classifier filtering (default: 0.9)
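A minimal sketch of classifier filtering, assuming a generated image is kept when the filtering classifier assigns at least threshold probability to the class the GAN was conditioned on. The function name, the use of NumPy instead of PyTorch, and the non-strict comparison (>=) are all assumptions rather than the repository's exact code.

```python
import numpy as np


def filter_by_classifier(images, labels, probs, threshold=0.9):
    """Keep generated images whose conditioning class gets probability >= threshold.

    images: array (N, ...) of generated images
    labels: int array (N,) of classes the GAN was conditioned on
    probs:  float array (N, K) of softmax outputs from the filtering classifier
    """
    # Probability the classifier gives to the label each image was generated for.
    p_target = probs[np.arange(len(labels)), labels]
    keep = p_target >= threshold
    return images[keep], labels[keep]


if __name__ == "__main__":
    images = np.zeros((3, 3, 32, 32), dtype=np.float32)
    labels = np.array([0, 1, 2])
    probs = np.array([[0.95, 0.03, 0.02],
                      [0.50, 0.40, 0.10],
                      [0.05, 0.05, 0.90]])
    kept_images, kept_labels = filter_by_classifier(images, labels, probs)
    print(kept_labels)  # [0 2]
```

Filtering shrinks each batch, so the generator has to be sampled again until enough images pass; how the repository refills batches is not shown here.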
Multi-GANs
- multi_gans: Sample using multiple GANs (default: None, integer value)
- gan_weights: If using multiple GANs, the weight of each GAN (default: sample from each GAN with equal probability)
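One way to interpret gan_weights is as sampling probabilities: each generated image picks its source GAN according to the normalized weights, falling back to a uniform choice when no weights are given. The sketch below assumes that interpretation; the helper choose_gans is hypothetical, and the repository may instead split each batch deterministically between GANs.

```python
import numpy as np


def choose_gans(num_images, num_gans, gan_weights=None, rng=None):
    """Return, for each image, the index of the GAN to sample it from.

    gan_weights: optional sequence of non-negative weights, one per GAN;
    None means every GAN is picked with equal probability.
    """
    rng = np.random.default_rng() if rng is None else rng
    if gan_weights is None:
        p = np.full(num_gans, 1.0 / num_gans)
    else:
        w = np.asarray(gan_weights, dtype=float)
        if w.shape != (num_gans,) or (w < 0).any() or w.sum() == 0:
            raise ValueError("need one non-negative weight per GAN")
        p = w / w.sum()  # normalize weights into probabilities
    return rng.choice(num_gans, size=num_images, p=p)


if __name__ == "__main__":
    idx = choose_gans(10_000, 5, rng=np.random.default_rng(0))
    # Roughly 2,000 draws per GAN under equal weighting.
    print(np.bincount(idx, minlength=5))
```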
Other
- truncate: Sample latent z from a truncated normal (default: no truncation, float value)
- fixed_dset: Use a fixed generated dataset for training (of size batch_size*num_batches*num_classes; default: False)
- transform: Apply image transformations to generated images (default: False)
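For truncate, one common way to sample from a truncated normal is rejection sampling: draw z from N(0, I) and redraw any coordinate whose magnitude exceeds the threshold. The sketch below does this in NumPy; the function name and the latent size of 128 in the example are assumptions, and the repository may follow a different convention (for example scipy.stats.truncnorm, or additionally rescaling z by the threshold).

```python
import numpy as np


def sample_truncated_z(batch_size, dim_z, truncation=None, rng=None):
    """Sample z ~ N(0, I), optionally truncated to [-truncation, truncation].

    Out-of-range entries are redrawn until every entry is within bounds.
    """
    rng = np.random.default_rng() if rng is None else rng
    z = rng.standard_normal((batch_size, dim_z))
    if truncation is None:
        return z  # no truncation: plain standard normal
    if truncation <= 0:
        raise ValueError("truncation must be positive")
    out = np.abs(z) > truncation
    while out.any():
        # Redraw only the entries that fell outside the interval.
        z[out] = rng.standard_normal(out.sum())
        out = np.abs(z) > truncation
    return z


if __name__ == "__main__":
    z = sample_truncated_z(64, 128, truncation=0.5, rng=np.random.default_rng(0))
    print(z.shape, np.abs(z).max() <= 0.5)
```

Smaller thresholds trade sample diversity for fidelity, which is the usual motivation for the truncation trick in BigGAN.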
To sample images from trained GAN weights, run:
python3 sample.py [options]
Parameters are as follows:
Input/Output
- model: Same as above
- ofile: Output file name (default: trained_net)
- torch_format: Save NPZ images as float tensors instead of uint8 (default: False)
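To illustrate the difference torch_format makes, the sketch below stores images either as float32 or as uint8 in an NPZ file. It assumes generator outputs lie in [-1, 1] (a tanh output range); the helpers and the NPZ key names x and y are hypothetical and not taken from sample.py.

```python
import numpy as np


def to_uint8(images):
    """Map float images in [-1, 1] (assumed range) to uint8 in [0, 255]."""
    x = (np.clip(images, -1.0, 1.0) + 1.0) * 127.5
    return np.round(x).astype(np.uint8)


def save_samples(path, images, labels, torch_format=False):
    """Save images and labels to an .npz file (hypothetical keys 'x' and 'y')."""
    x = images.astype(np.float32) if torch_format else to_uint8(images)
    np.savez(path, x=x, y=labels)


if __name__ == "__main__":
    print(to_uint8(np.array([-1.0, 0.0, 1.0])))  # [  0 128 255]
```

uint8 files are four times smaller, while float tensors avoid the quantization step when the images are fed straight back into PyTorch.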
Generation
- num_samples: Number of samples to generate (default: 10)
- class: Class to sample from (in [[0, K-1]] for K classes; default: sample num_samples/K images from each class, sequentially)
- random_k: Sample classes randomly (default: False)
- multi_gans: Generate samples using multiple GANs (default: None, integer value)
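The three labelling modes above (a single class, random classes, or num_samples/K images per class in order) can be sketched as a small label generator. The helper sample_labels is hypothetical; giving class priority over random_k and dropping the remainder when num_samples is not divisible by K are assumptions.

```python
import numpy as np


def sample_labels(num_samples, num_classes, cls=None, random_k=False, rng=None):
    """Return the class label of each image to generate.

    cls:      a single class in [0, num_classes - 1]; every sample uses it.
    random_k: draw each label uniformly at random.
    default:  num_samples // num_classes labels per class, class by class.
    """
    if cls is not None:
        if not 0 <= cls < num_classes:
            raise ValueError("class must be in [0, num_classes - 1]")
        return np.full(num_samples, cls, dtype=np.int64)
    if random_k:
        rng = np.random.default_rng() if rng is None else rng
        return rng.integers(0, num_classes, size=num_samples)
    per_class = num_samples // num_classes  # remainder is dropped (assumption)
    return np.repeat(np.arange(num_classes), per_class)


if __name__ == "__main__":
    print(sample_labels(10, 10))  # [0 1 2 3 4 5 6 7 8 9]
```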
Other
- transform, truncate: Same as above
The ./scripts/ folder contains bash scripts that run classifier training sessions with various parameters.
- [Brock et al., 2018] Andrew Brock, Jeff Donahue, and Karen Simonyan. Large scale GAN training for high fidelity natural image synthesis, 2018.
- [Pham et al., 2019] Thanh Dat Pham, Anuvabh Dutt, Denis Pellerin, and Georges Quénot. Classifier Training from a Generative Model. In CBMI 2019 - 17th International Conference on Content-Based Multimedia Indexing, Dublin, Ireland, September 2019.
