Skip to content

Commit

Permalink
add back 2x2 AR params, update train script, update readme
Browse files Browse the repository at this point in the history
  • Loading branch information
psandovalsegura committed Jun 18, 2022
1 parent 19412c7 commit 629f27e
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 12 deletions.
66 changes: 60 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,29 +11,83 @@ Code for the paper [Autoregressive Perturbations for Data Poisoning](http://arxi
## Clean-up in progress!
This repo will be completely updated by 6/19/22.

## Finding AR process coefficients
### Generating AR perturbations

See **notebooks/Generate-AR-Perturbations-from-Coefficients.ipynb** for an example of how to load AR coefficients and generate an AR perturbation of a given size and norm.

In summary, after loading some AR coefficients, we can call the `generate` function of `ARProcessPerturb3Channel`:
```
# Load coefficients
coefficients = torch.load(os.path.join(repo_dir, 'params-classes-10-mr-10.pt'))
# Use first set of coefficients, for example
ar = ARProcessPerturb3Channel(b=coefficients[0])
# Generate a size (3, 32, 32) perturbation, after cropping a larger (36, 36) perturbation
perturbation, _ = ar.generate(size=(36,36), eps=1.0, crop=4, p=2)
```

The resulting `perturbation` can then be additively applied directly to an image of shape (3,32,32) because the perturbation is of size 1.0 in L2.

### Finding AR process coefficients

To find a set of 10 AR processes, run:

```
python autoregressive_param_finder.py --total=10 --required_nm_response=10 --gen_norm_upper_bound=50
```

This command will save a file named `params-classes-10-mr-10.pt` using `torch.save`. The format will be identical to that of `RANDOM_3C_AR_PARAMS_RNMR_10` within **autoregressive_params.py**, a list of `torch.tensor`. Additional information can be found in Appendix A.3.
This command will save a file named `params-classes-10-mr-10.pt` using `torch.save`. The format will be identical to that of `RANDOM_3C_AR_PARAMS_RNMR_10` within **autoregressive_params.py**, a list of `torch.tensor`. Additional information can be found in [Appendix A.3](http://arxiv.org/abs/2206.03693).

### Creating a CIFAR-10 poison

Before creating a poison using our script, update `CIFAR_PATH` (and other paths, as required) in **create_ar_poisons_3channel.py** with the location of your CIFAR data. Then, you can create an AR CIFAR-10 poison by calling:

```
python create_ar_poisons_3channel.py ${YOUR_POISON_NAME} --epsilon 1.0 --p_norm 2
```

By default, the code uses params from **autoregressive_params.py**, but you can change this behavior if you like. The script also has support for SVHN, STL, and CIFAR-100.

### Training a model on a poison

## Generating AR perturbations
We provide a number of models, borrowed from the [pytorch-cifar](https://github.com/kuangliu/pytorch-cifar) repo. To train a model on clean CIFAR-10:

See **notebooks/Generate-AR-Perturbations-from-Coefficients.ipynb** for an example of how to load AR coefficients and generate an AR perturbation of a given size and norm.
```
python train.py misc.project_name=${PROJECT_NAME} misc.run_name=${RUN_NAME} train.batch_size=128 train.augmentations_key="none"
```

## Training a model on a poison
To train a model on an AR CIFAR-10 poison:
```
python train.py misc.project_name=${PROJECT_NAME} misc.run_name=${RUN_NAME} train.adversarial_poison_path=${YOUR_POISON_PATH} train.batch_size=128 train.augmentations_key=${AUG}
```
Note that in this command, we specify `train.adversarial_poison_path` to override the config within **config/base.yaml**, and load a poison.

You can set `AUG` to either "none", "cutout", "cutmix" or "mixup".

This training script uses the `WandbLogger` from [PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/latest/), so if you use [Weights and Biases](https://wandb.ai/), you can use their online portal to analyze training curves.

### Try and train a model yourself!
### Train your own network on our poisons!
We release our AR poisons as Zip files containing PNG images for easy viewing via [Google Drive](https://drive.google.com/drive/folders/1ze0cKAXNcPRkC0TMmObMj-g7Gspp1DpL?usp=sharing). This includes a
- CIFAR-10 AR Poison: ar-cifar-10.zip
- CIFAR-100 AR Poison: ar-cifar-100.zip
- SVHN AR Poison: ar-svhn.zip
- STL AR Poison: ar-stl.zip

After unzipping, these poisons can be loaded using `AdversarialPoison`, a subclass of `torch.utils.data.Dataset`. A model which trains on our AR poisons is unable to generalize to the (clean) test set.

### AR Perfect Model

The simple CNN which can perfectly classify AR perturbations is in **autoregressive_perfect_model.py**. More information can be found in [Appendix A.2](http://arxiv.org/abs/2206.03693) of the paper. This code was from an earlier version of our code where one AR process was repeated across each of the three channels (as opposed to using a different set of coefficients for each of 3 channels). Early in our work, we used terms from convergent series, and manually specified them in `ALL_2x2_AR_PARAMS`. To avoid confusion, we don't provide actionable code for `PerfectARModel`, but feel free to modify the code and use your own AR coefficients.

### Citation

If you find this work useful for your research, please cite our paper:
```
@article{sandoval2022autoregressive,
title={Autoregressive Perturbations for Data Poisoning},
author={Sandoval-Segura, Pedro and Singla, Vasu and Geiping, Jonas and Goldblum, Micah and Goldstein, Tom and Jacobs, David W},
journal={arXiv preprint arXiv:2206.03693},
year={2022}
}
```
71 changes: 70 additions & 1 deletion autoregressive_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,4 +112,73 @@
[-0.0165, 0.1160, 0.1356],
[ 0.1772, 0.0321, -0.0000]]])]

RANDOM_100CLASS_3C_AR_PARAMS_RNMR_3 = torch.load('/cfarhomes/psando/Documents/autoregressive-poisoning/params-classes-100-mr-3.pt')
RANDOM_100CLASS_3C_AR_PARAMS_RNMR_3 = torch.load('/cfarhomes/psando/Documents/autoregressive-poisoning/params-classes-100-mr-3.pt')

# The coefficients below are not meant to be used for poisoning
# but instead are used to illustrate the simplicity of PerfectARModel
# found within autoregressive_perfect_model.py

geo_a1_r12 = np.array([[(-1/2), (-1/4), (-1/8)],
[(-1/16), (-1/32), (-1/64)],
[(-1/128), (-1/256), 0]])

geo_a62_r17 = np.array([[(-6.2/7), (-6.2/49), (-6.2/343)],
[(-6.2/2401), (-6.2/16807), (-6.2/117649)],
[(-6.2/823543), (-6.2/5764801), 0]])

geo_a51_r16 = np.array([[(-5.1/6), (-5.1/36), (-5.1/216)],
[(-5.1/1296), (-5.1/7776), (-5.1/46656)],
[(-5.1/279936), (-5.1/1679616), 0]])

fibonacci = np.array([[(-1/2), (-1/3), (-1/5)],
[(-1/8), (-1/13), (-1/21)],
[(-1/34), (-1/55), 0]])

geo_a2_r13 = np.array([[(-2/3), (-2/9), (-2/27)],
[(-2/81), (-2/243), (-2/729)],
[(-2/2187), (-2/6561), 0]])

geo_a12_r12 = np.array([[(-1.2/2), (-1.2/4), (-1.2/8)],
[(-1.2/16), (-1.2/32), (-1.2/64)],
[(-1.2/128), (-1.2/256), 0]])

geo_a34_r14 = np.array([[(-3.4/4), (-3.4/16), (-3.4/64)],
[(-3.4/256), (-3.4/1024), (-3.4/4096)],
[(-3.4/16384), (-3.4/65536), 0]])

geo_a15_r12 = np.array([[(-1.5/2), (-1.5/4), (-1.5/8)],
[(-1.5/16), (-1.5/32), (-1.5/64)],
[(-1.5/128), (-1.5/256), 0]])

geo_a25_r13 = np.array([[(-2.5/3), (-2.5/9), (-2.5/27)],
[(-2.5/81), (-2.5/243), (-2.5/729)],
[(-2.5/2187), (-2.5/6561), 0]])

geo_45_r15 = np.array([[(-4.5/5), (-4.5/25), (-4.5/125)],
[(-4.5/625), (-4.5/3125), (-4.5/15625)],
[(-4.5/78125), (-4.5/390625), 0]])

ALL_2x2_AR_PARAMS = {
'geo_a1_r12': geo_a1_r12,
'geo_a62_r17': geo_a62_r17,
'geo_a51_r16': geo_a51_r16,
'fibonacci': fibonacci,
'geo_a2_r13': geo_a2_r13,
'geo_a12_r12': geo_a12_r12,
'geo_a34_r14': geo_a34_r14,
'geo_a15_r12': geo_a15_r12,
'geo_a25_r13': geo_a25_r13,
'geo_45_r15': geo_45_r15
}

ALL_2x2_AR_FILTERS = {}
for key, value in ALL_2x2_AR_PARAMS.items():
# the matching filter is almost identical to the AR
# process parameters, but has a -1 as the last entry
filter = np.copy(value)
filter[2][2] = -1

# every filter should be normalized so that the sum of
# the coefficients is 1
filter = filter / np.sum(filter)
ALL_2x2_AR_FILTERS[key] = filter
2 changes: 1 addition & 1 deletion autoregressive_perfect_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn
import numpy as np
from autoregressive_models import ALL_2x2_AR_FILTERS
from autoregressive_params import ALL_2x2_AR_FILTERS

class PerfectARModel(nn.Module):
"""A simple CNN with 10 filters of size 3x3 followed by
Expand Down
4 changes: 0 additions & 4 deletions lightning_modules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +0,0 @@
from lightning_cifar10 import LitCIFAR10Model
from lightning_cifar100 import LitCIFAR100Model
from lightning_stl10 import LitSTLModel
from lightning_svhn import LitSVHNModel

0 comments on commit 629f27e

Please sign in to comment.