- python3.x
- pytorch
- torchvision to load the datasets, perform image transforms
- pandas for logging to csv
- bokeh for training visualization
- scikit-learn for kmeans clustering
- mlflow for logging
To install requirements run:
pip install torch torchvision bokeh pandas scikit-learn mlflow
- NVIDIA GPU with CUDA support
- To run this code you need the validation set from the ILSVRC2012 dataset.
- Configure your dataset path by providing --data "PATH_TO_ILSVRC" or copy ILSVRC dir to ~/datasets/ILSVRC2012.
- To get the ILSVRC2012 data, you should register on their site for access: http://www.image-net.org/
To improve performance, GEMMLOWP quantization was implemented in CUDA, which requires compiling the kernels.
- Create a virtual environment for python3 and activate it:
virtualenv --system-site-packages -p python3 venv3
. ./venv3/bin/activate
- build kernels
cd kernels
./build_all.sh
Post-training quantization of ResNet-50
Note that accuracy results may vary by up to 0.5% due to data shuffling.
- Experiment W4A4 naive:
python inference/inference_sim.py -a resnet50 -b 512 -pcq_w -pcq_a -sh --qtype int4 -qw int4
- Prec@1 62.154 Prec@5 84.252
- Experiment W4A4 + ACIQ + Bit Alloc(A) + Bit Alloc(W) + Bias correction:
python inference/inference_sim.py -a resnet50 -b 512 -pcq_w -pcq_a -sh --qtype int4 -qw int4 -c laplace -baa -baw -bcw
- Prec@1 73.330 Prec@5 91.334
We solve eq. 6 numerically to find the optimal clipping value α for both the Laplace and Gaussian priors.
Numerical solution source code:
optimal_alpha.ipynb
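As an illustrative sketch of the numerical solution (not the notebook's exact code), assume the expected quantization MSE model for a Laplace(0, b) tensor clipped at α: clipping noise 2b²·e^(−α/b) plus rounding noise α²/(3·2^(2M)). A simple golden-section search over α then recovers the optimal clipping value:

```python
import math

def aciq_laplace_mse(alpha, b=1.0, num_bits=4):
    """Expected MSE for quantizing a Laplace(0, b) tensor clipped at alpha:
    clipping noise 2*b^2*exp(-alpha/b) + rounding noise alpha^2 / (3 * 4^M)."""
    return 2 * b**2 * math.exp(-alpha / b) + alpha**2 / (3 * 4**num_bits)

def optimal_alpha(b=1.0, num_bits=4, lo=1e-6, hi=40.0, iters=200):
    """Golden-section search for the alpha minimizing the (unimodal) MSE model."""
    gr = (math.sqrt(5) - 1) / 2
    for _ in range(iters):
        m1 = hi - gr * (hi - lo)
        m2 = lo + gr * (hi - lo)
        if aciq_laplace_mse(m1, b, num_bits) < aciq_laplace_mse(m2, b, num_bits):
            hi = m2
        else:
            lo = m1
    return (lo + hi) / 2
```

For b = 1 this reproduces the clipping values reported for ACIQ, e.g. α* ≈ 2.83 at 2 bits and α* ≈ 5.03 at 4 bits.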
Given a quota on the total number of bits allowed to be written to memory, the optimal bit-width assignment M_i for channel i is the following:
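The paper's closed-form expression is not reproduced in this text; as an illustrative sketch (not the paper's exact formula), assume per-channel quantization MSE proportional to α_i²·2^(−2·M_i) and a total quota B over n channels. A Lagrangian argument then gives the continuous allocation M_i = B/n + log2(α_i) − mean_j log2(α_j):

```python
import math

def allocate_bits(alphas, total_bits):
    """Continuous bit allocation minimizing sum_i alpha_i^2 * 2^(-2*M_i)
    subject to sum_i M_i == total_bits (Lagrangian solution).
    In practice the result must be rounded/clamped to valid integer widths."""
    n = len(alphas)
    mean_log = sum(math.log2(a) for a in alphas) / n
    return [total_bits / n + math.log2(a) - mean_log for a in alphas]
```

Channels with a larger clipping range α_i (i.e. a larger dynamic range) receive proportionally more bits.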
We observe an inherent bias in the mean and the variance of the weight values following their quantization.
We calculate this bias using equation 12.
Then, we compensate for the bias in each channel of W as follows:
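As a sketch of this per-channel correction (hypothetical helper names; the actual implementation lives in the repo's quantization code), one can restore each output channel's original mean and standard deviation after quantization:

```python
import numpy as np

def quantize_per_channel(w, num_bits=4):
    """Naive symmetric uniform quantizer per output channel (axis 0)."""
    levels = 2 ** (num_bits - 1) - 1
    scale = np.abs(w).max(axis=(1, 2, 3), keepdims=True) / levels
    return np.round(w / scale) * scale

def bias_correct(w, w_q):
    """Rescale and shift w_q so its per-channel mean/std match those of w."""
    axes = (1, 2, 3)
    mu, mu_q = w.mean(axis=axes, keepdims=True), w_q.mean(axis=axes, keepdims=True)
    std, std_q = w.std(axis=axes, keepdims=True), w_q.std(axis=axes, keepdims=True)
    return (w_q - mu_q) * (std / (std_q + 1e-12)) + mu
```

After correction, each channel of the quantized weight tensor has the same first and second moments as the original channel.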
In order to quantize a tensor to M bits with optimal clipping, we use GEMMLOWP quantization with a small modification: we replace the dynamic range in the scale computation with 2*alpha, where alpha is the optimal clipping value.
Quantization code can be found here: int_quantizer.py
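A minimal sketch of the modification described above (assumed details; see int_quantizer.py for the actual implementation): GEMMLOWP-style affine quantization where the scale is derived from 2*alpha instead of the tensor's min/max dynamic range:

```python
import numpy as np

def quantize_with_clipping(x, alpha, num_bits=8):
    """Affine (GEMMLOWP-style) quantization of x to num_bits unsigned levels,
    using the fixed range [-alpha, alpha] instead of the tensor's min/max."""
    qmax = 2 ** num_bits - 1
    scale = 2 * alpha / qmax              # dynamic range replaced by 2*alpha
    x_clipped = np.clip(x, -alpha, alpha)
    q = np.round((x_clipped + alpha) / scale)  # zero point maps -alpha -> 0
    return q.astype(np.int64), scale

def dequantize(q, scale, alpha):
    return q * scale - alpha
```

Values beyond ±alpha are clipped before quantization, which is what makes the choice of alpha (the optimal clipping value from the numerical solution above) matter.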