Code release for Test-Time Training with Self-Supervision for Generalization under Distribution Shifts.
This code produces our results on CIFAR-10-C and CIFAR-10.1.
The ImageNet results are produced by this repository.
- Our code requires PyTorch 1.0 or higher and at least one modern GPU with adequate memory.
- We ran our code with Python 3.7. Python 2 may also work with some modifications.
- Most of the packages used should be included with Anaconda, except possibly two small utilities.
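Before running anything, the PyTorch and GPU requirements above can be checked quickly. The sketch below is not part of the repository: `parse_version` and `meets_minimum` are hypothetical helper names, and the check compares only the major and minor version numbers, ignoring build suffixes such as `+cu117`.

```python
import re
import sys


def parse_version(version: str) -> tuple:
    """Return the leading (major, minor) integers of a version string.

    Only the leading digits of each component are kept, so strings such as
    '1.13.1+cu117' or '2.1.0a0' are handled.
    """
    parts = []
    for piece in version.split(".")[:2]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    while len(parts) < 2:  # e.g. '1' -> (1, 0)
        parts.append(0)
    return tuple(parts)


def meets_minimum(version: str, minimum=(1, 0)) -> bool:
    """True if `version` is at least `minimum`, compared as (major, minor)."""
    return parse_version(version) >= tuple(minimum)


if __name__ == "__main__":
    print("Python", sys.version.split()[0])
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed")
    else:
        status = "ok" if meets_minimum(torch.__version__) else "too old (< 1.0)"
        print("PyTorch", torch.__version__, status)
        print("CUDA GPU available:", torch.cuda.is_available())
```

Running the file directly prints the Python and PyTorch versions and whether a CUDA GPU is visible; it only reports and does not exit on failure.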
- Download the two datasets into the same folder:
  - CIFAR-10-C (Hendrycks and Dietterich) from this repository, which links to this shared storage.
  - CIFAR-10.1 (Recht et al.) from this repository.
- Clone our repository with `git clone https://github.com/yueatsprograms/ttt_cifar_release`.
- Inside the repository, set the data folder to where the datasets are stored by editing:
  - the `--dataroot` argument in `main.py`;
  - the `--dataroot` argument in `baseline.py`;
  - the `dataroot` variable in `script_test_c10.py`.
- Run `script.sh` for the main results and `script_baseline.sh` for the baseline results.
- The results are stored in the respective folders in `results/`.
- Once everything is finished, the results can be compiled and visualized with the following utilities:
  - `show_table.py` parses the results into tables and prints them.
  - `show_plot.py` makes bar plots like those in our paper and prints the tables in LaTeX format; it requires running `show_table.py` first.
  - `show_grad.py` makes the gradient correlation plot in our paper.
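For orientation, the datasets placed in the data folder (the one passed as `--dataroot`) can also be inspected directly. The sketch below assumes the standard CIFAR-10-C release layout: a `CIFAR-10-C/` subfolder with one `<corruption>.npy` file per corruption, each holding 50,000 images with the five severity levels stacked in consecutive blocks of 10,000, plus a shared `labels.npy`. The helper names `severity_slice` and `load_cifar10c` are hypothetical; this is an illustration, not the repository's own data loader.

```python
import os


def severity_slice(data, level, per_level=10000):
    """Return the block of `data` for one severity level (1 = mildest, 5 = worst).

    Works on any sliceable sequence (NumPy array, list, ...).
    """
    if not 1 <= level <= 5:
        raise ValueError("level must be between 1 and 5")
    start = (level - 1) * per_level
    return data[start:start + per_level]


def load_cifar10c(dataroot, corruption, level):
    """Load one corruption at one severity level from <dataroot>/CIFAR-10-C/.

    Assumes <corruption>.npy holds all five levels stacked in blocks of 10,000
    and labels.npy holds the matching labels in the same order.
    """
    import numpy as np  # imported here so severity_slice works without NumPy

    folder = os.path.join(dataroot, "CIFAR-10-C")
    images = np.load(os.path.join(folder, corruption + ".npy"))
    labels = np.load(os.path.join(folder, "labels.npy"))
    return severity_slice(images, level), severity_slice(labels, level)
```

For example, `load_cifar10c("/path/to/data", "gaussian_noise", 5)` would return the 10,000 most severely noised test images and their labels, provided the files follow the layout assumed above.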