This repository is the official accompaniment to A General Framework for Robust G-Invariance in G-Equivariant Networks (2023) by Sophia Sanborn and Nina Miolane, published in the Proceedings of the 37th Conference on Neural Information Processing Systems (NeurIPS).
To install the requirements and package, run:
pip install -r requirements.txt
pip install -e .
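As an optional sanity check after installation, the sketch below reports which entries of requirements.txt have no installed distribution. It is not part of the repository. The helper names requirement_names and missing_requirements are hypothetical, and the parser only handles simple lines of the form name, name==version or name[extra]>=version. Option lines such as -e and plain URL/VCS lines are skipped.

```python
import re
from importlib import metadata
from pathlib import Path

# Matches the leading project name of a requirement line, e.g. "torch" in "torch>=1.13".
_NAME_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*")


def requirement_names(path="requirements.txt"):
    """Return the project names listed in a simple requirements file.

    Assumption (hypothetical helper): lines look like "name", "name==1.0"
    or "name[extra]>=1.0". Option lines (-r, -e, ...) and bare URL/VCS
    lines are skipped.
    """
    names = []
    for raw in Path(path).read_text().splitlines():
        # Drop inline comments and environment markers, then whitespace.
        line = raw.split("#", 1)[0].split(";", 1)[0].strip()
        if not line or line.startswith("-"):
            continue
        if "://" in line and " @ " not in line:
            continue  # e.g. git+https://... lines have no plain name prefix
        match = _NAME_RE.match(line)
        if match:
            names.append(match.group(0))
    return names


def missing_requirements(path="requirements.txt"):
    """Return the listed names that have no installed distribution."""
    missing = []
    for name in requirement_names(path):
        try:
            metadata.version(name)
        except metadata.PackageNotFoundError:
            missing.append(name)
    return missing
```

Run from the repository root, missing_requirements() returns an empty list when every listed package is installed.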
To download the datasets:
- Download the zip file here.
- Place the file in the top-level directory of this repository, i.e. in gtc-invariance/.
- Run:
unzip datasets.zip
rm -r datasets.zip
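On systems without the unzip tool, the same two steps can be done from Python. This is a minimal sketch, not a script shipped with the repository. extract_datasets is a hypothetical helper that assumes datasets.zip already sits in the repository root. It extracts the archive there and then deletes it.

```python
import zipfile
from pathlib import Path


def extract_datasets(repo_root=".", archive_name="datasets.zip"):
    """Unpack the dataset archive into the repository root, then delete it.

    Python equivalent of running `unzip datasets.zip` followed by
    `rm -r datasets.zip` from the repository root (hypothetical helper).
    Returns the list of archive members that were extracted.
    """
    root = Path(repo_root)
    archive = root / archive_name
    with zipfile.ZipFile(archive) as zf:
        names = zf.namelist()
        zf.extractall(root)  # same target directory unzip uses when run from the root
    archive.unlink()  # the archive is no longer needed once extracted
    return names
```

For example, extract_datasets("gtc-invariance") unpacks gtc-invariance/datasets.zip in place and removes the archive.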
The full set of hyperparameters and training configurations is specified in the config files in the configs/ folder. To train a model for a particular experiment, run the following:
python scripts/run_data_agent.py --config [name of config]
python scripts/run_train_agent.py --config [name of config]
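The two calls can also be chained from Python so that training only starts once data generation has succeeded. The helpers experiment_commands and run_experiment below are hypothetical, not part of the repository. They assume they are run from the repository root, and they invoke both scripts with the current Python interpreter.

```python
import subprocess
import sys


def experiment_commands(config_name):
    """Build the data-generation and training commands for one experiment config.

    The config name is expected without its .py extension
    (e.g. "o2mnist_d16_maxpool"); a trailing ".py" is stripped
    here in case it was included by mistake.
    """
    if config_name.endswith(".py"):
        config_name = config_name[: -len(".py")]
    return [
        # Step 1: generate the transformed dataset.
        [sys.executable, "scripts/run_data_agent.py", "--config", config_name],
        # Step 2: train the model on that dataset.
        [sys.executable, "scripts/run_train_agent.py", "--config", config_name],
    ]


def run_experiment(config_name):
    """Run both steps in order; check=True stops at the first failing step."""
    for cmd in experiment_commands(config_name):
        subprocess.run(cmd, check=True)
```

With check=True, subprocess.run raises CalledProcessError if data generation exits with a nonzero status, so run_experiment("o2mnist_d16_maxpool") never starts training on a missing or partial dataset.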
The first call generates the transformed dataset, and the second trains the model on that dataset. The --config argument should be followed by the name of a config file from configs/experiments, e.g. o2mnist_d16_maxpool, with the .py extension excluded. Each of the configs in the configs/experiments folder combines the model, trainer, and other configs that are also specified in the configs folder. The scripts are set up to log the model with Weights & Biases, so your wandb entity and project should be specified in configs/logger.
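To see which values --config accepts, the experiment config names can be listed from the configs/experiments folder. This is a minimal sketch, assuming each experiment is a single .py file in that folder. available_experiments is a hypothetical helper, and it skips __init__.py in case the folder is a Python package.

```python
from pathlib import Path


def available_experiments(configs_dir="configs/experiments"):
    """Return valid --config values: experiment config file names without .py."""
    return sorted(
        p.stem  # file name with the .py extension removed
        for p in Path(configs_dir).glob("*.py")
        if p.name != "__init__.py"
    )
```

Any name this returns, such as o2mnist_d16_maxpool, can be passed directly to both scripts.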
This repository is licensed under the MIT License.