This repository contains code for training expert trajectories and distilling synthetic data for the paper: Dataset Distillation by Automatic Training Trajectories. The listed is the steps to run the code.
- Set up enveriments.
- Create an wandb account for monitoring distillation process (highly recommended but optional)
- Create buffers which contain expert trajectories.
- Distill and test synthetic datasets
You can install packages from requirements.txt with pip
pip install -r requirements.txt
In case of conflicts, you can check our packages details in requirements_detail.txt including versions and dependencies.
We use wandb to monitor our distillation process and system information. The tool is used by a lot of dataset distillation aythors. We strongly recommand reviewers to use this tool as well for visualize the process.
You can easily register an account at https://wandb.ai/site , then get an API key at https://wandb.ai/authorize?_gl.
When running the distillation code, you will only need to select the option declaring that you have an wandb account, and input your API key. After this, you will be able to monitor the distillation process in https://wandb.ai/site.
In case you don't want to use wandb, please just add in "mode='disabled'" into to wandb init in distill_ATT.py line 67.
Before doing distillation, we need to generate expert trajectories on the original dataset.
python buffer_ATT.py --dataset=CIFAR10 --model=ConvNetD3 --train_epochs=50 --num_experts=100 --zca --data_path=data --buffer_path=buffer
Please change --dataset=CIFAR10 to CIFAR100 for results on CIFAR100. "--zca" indicate with zca, simply remove it for cases without zca
python buffer_ATT.py --dataset=Tiny --model=ConvNetD4 --train_epochs=50 --num_experts=100 --data_path=data --buffer_path=buffer
python buffer_ATT.py --dataset=ImageNet --subset=imagenette --model=ConvNetD5 --train_epochs=50 --num_experts=100 --data_path=--data_path=data --buffer_path=buffer
Please change augment "--subset=imagenette" to imagenette, imagewoof, imagefruit, imagemeow.
After creating and storing the expert trajectories, we will distill synthetic dataset with the following parameters.
We used the following command in experiments:
IPC = 1:
python3 distill_ATT.py --dataset=CIFAR10 --model=ConvNetD3 --ipc=1 --zca --syn_steps=80 --expert_epochs=4 --max_start_epoch=2 --lr_img=1e2 --lr_lr=1e-07 --lr_teacher=1e-2 --ema_decay=0.9999 --eval_it=100 --Iteration=6000 --buffer_path=buffer --data_path=data
IPC = 10:
python3 distill_ATT.py --dataset=CIFAR10 --model=ConvNetD3 --ipc=10 --zca --syn_steps=30 --expert_epochs=2 --max_start_epoch=20 --lr_img=1e04 --lr_lr=1e-04 --lr_teacher=1e-02 --ema_decay=0.9999 --eval_it=100 --Iteration=60000 --buffer_path=buffer --data_path=data
IPC = 50
python3 distill_ATT.py --dataset=CIFAR10 --model=ConvNetD3 --ipc=50 --syn_steps=50 --expert_epochs=2 --max_start_epoch=40 --lr_img=1e02 --lr_lr=1e-05 --lr_teacher=1e-03 --ema_decay=0.999 --eval_it=1000 --Iteration=1000000 --buffer_path=buffer --data_path=data
IPC = 1:
python3 distill_ATT.py --dataset=CIFAR100 --model=ConvNetD3 --ipc=1 --zca --syn_steps=50 --expert_epochs=4 --max_start_epoch=20 --lr_img=500 --lr_lr=1e-05 --lr_teacher=1e-02 --ema_decay=0.9995 --eval_it=10 --Iteration=3000 --buffer_path=buffer --data_path=data
IPC = 10:
python3 distill_ATT.py --dataset=CIFAR100 --model=ConvNetD3 --ipc=10 --syn_steps=20 --expert_epochs=2 --max_start_epoch=40 --lr_img=1e3 --lr_lr=1e-05 --lr_teacher=1e-2 --ema_decay=0.9995 --eval_it=5000 --Iteration=200000 --buffer_path=buffer --data_path=data
IPC = 50
python3 distill_ATT.py --dataset=CIFAR100 --model=ConvNetD3 --ipc=50 --zca --syn_steps=80 --expert_epochs=2 --max_start_epoch=40 --lr_img=1e3 --lr_lr=1e-05 --lr_teacher=1e-2 --ema_decay=0.999 --eval_it=100000 --Iteration=100000 --batch_syn=1200 --buffer_path=buffer --data_path=data
python3 distill_ATT.py --dataset=CIFAR10 --model=ConvNetD3 --ipc=10 --zca --syn_steps=30 --expert_epochs=2 --max_start_epoch=20 --lr_img=1e04 --lr_lr=1e-04 --lr_teacher=1e-02 --ema_decay=0.9999 --eval_it=100 --Iteration=60000 --buffer_path=buffer --data_path=data --eval_mode='M'
python3 distill_ATT.py --dataset=CIFAR10 --model=ConvNetD3 --ipc=50 --syn_steps=30 --expert_epochs=2 --max_start_epoch=40 --lr_img=1e02 --lr_lr=1e-05 --lr_teacher=1e-03 --ema_decay=0.999 --eval_it=1000 --Iteration=1000000 --buffer_path=buffer --data_path=data --eval_mode='M'
You can find the following table of results for different dataset as a reference to reproduce.
Cross-architecture performance:
https://github.com/AngusDujw/FTD-distillation/tree/main?tab=readme-ov-file
https://github.com/GeorgeCazenavette/mtt-distillation
https://github.com/VICO-UoE/DatasetCondensation
https://github.com/kuangliu/pytorch-cifar
https://github.com/BIGBALLON/CIFAR-ZOO
https://gist.github.com/rwightman/f2d3849281624be7c0f11c85c87c1598