Commit fec0e5c: add code and model
cxmscb committed Jan 17, 2019
Showing 32 changed files with 1,826 additions and 0 deletions.
README.md (71 additions, 0 deletions)
# ATDA
This repository contains code to reproduce results from the paper:

**Improving the Generalization of Adversarial Training with Domain Adaptation (ICLR 2019)**

OpenReview: https://openreview.net/forum?id=SyfIfnC5Ym

###### REQUIREMENTS

The code was tested with Python 3.6.5, TensorFlow 1.8.0, Keras 2.1.2, Keras_contrib 0.0.2, Torchvision 0.2.1, and NumPy 1.14.3.

###### EXPERIMENTS

We use Adversarial Training (on FGSM) with Domain Adaptation to train a main model (modelZ) for CIFAR-10 (default). The `--type` flag selects the architecture defined in `cifar10.py`: 0 = modelZ, 1 = modelA, 2 = modelB, 3 = modelC.

```
python -m train_atda models/modelZ_atda --type=0
```

In addition, we use Adversarial Training (on the noisy PGD) with Domain Adaptation to train a main model (modelZ) for CIFAR-10 (default).

```
python -m train_atda_npgd models/modelZ_atda --type=0
```

Then, we use Normal Training to train a model (modelC) for CIFAR-10 (default).

```
python -m train models/modelC --type=3
```

To use original (standard) Adversarial Training to train a main model:

```
python -m train_adv models/modelZ_adv --type=0
```

To use Ensemble Adversarial Training to train a main model:

```
# First, train the pre-trained models:
python -m train models/modelA --type=1
python -m train models/modelB --type=2
# Then run Ensemble Adversarial Training with the pre-trained models:
python -m train_adv models/modelZ_ens models/modelA models/modelB --type=0
```

The accuracy of the models on the CIFAR-10 test set (default) can be computed using:

```
python -m simple_eval test [model(s)]
```
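For example, evaluating the ATDA-trained main model and the normally trained model from the commands above:

```
python -m simple_eval test models/modelZ_atda models/modelC
```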

To evaluate robustness to various attacks, we use:

```
python -m simple_eval [attack] [source_model] [target_model(s)] [--parameters (opt)]
```

The attack can be:

| Attack | Description | Parameters |
| ------ | ------------------------- | ------------------------------------------------------------ |
| fgs | Standard FGSM | *eps* (the norm of the perturbation) |
| rfgs | RAND+FGSM | *eps* (the norm of the total perturbation); *alpha* (the norm of the random perturbation) |
| pgd    | Iterative FGSM (PGD)      | *eps* (the norm of the perturbation); *steps* (the number of iterative FGSM steps); the step size is fixed at *alpha* = *eps*/10.0 |
| mim    | Momentum Iterative Method | Parameters are fixed in the function *momentum_fgs* in *fgs.py*. |
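For example, a hypothetical invocation (the flag names are inferred from the Parameters column; check `simple_eval.py` for the exact interface, and the values here are illustrative):

```
python -m simple_eval pgd models/modelZ_atda models/modelC --eps=0.03 --steps=10
```

As a reference for the *pgd* row, a minimal NumPy sketch of the iterative FGSM update with step size *eps*/10 (`grad_fn` is a hypothetical callable that returns the loss gradient w.r.t. the input batch; it is not part of this repository):

```
import numpy as np

def iterative_fgsm(x, grad_fn, eps, steps):
    """Iterative FGSM: step along the gradient sign, then project back
    into the eps-ball around the clean input x (pixels in [0, 1])."""
    alpha = eps / 10.0                               # per-step size, as in the table
    x_adv = x.copy()
    for _ in range(steps):
        x_adv = x_adv + alpha * np.sign(grad_fn(x_adv))
        x_adv = np.clip(x_adv, x - eps, x + eps)     # stay within the eps-ball
        x_adv = np.clip(x_adv, 0.0, 1.0)             # keep a valid pixel range
    return x_adv
```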

###### Acknowledgments

The code borrows heavily from: [Ensemble Adversarial Training](https://github.com/cxmscb/ensemble-adv-training)
attack_utils.py (62 additions, 0 deletions)
```
import numpy as np
import keras.backend as K
import tensorflow as tf
from tensorflow.python.platform import flags
FLAGS = flags.FLAGS


def linf_loss(X1, X2):
    # L-infinity distance between two batches of images, one value per sample
    return np.max(np.abs(X1 - X2), axis=(1, 2, 3))


def gen_adv_loss(logits, y, loss='logloss', mean=False):
    """
    Generate the loss function.
    """

    if loss == 'training':
        # use the model's output instead of the true labels to avoid
        # label leaking at training time
        y = K.cast(K.equal(logits, K.max(logits, 1, keepdims=True)), "float32")
        y = y / K.sum(y, 1, keepdims=True)
        out = K.categorical_crossentropy(y, logits, from_logits=True)
    elif loss == 'min_training':
        # target the least-likely class instead of the most-likely one
        y = K.cast(K.equal(logits, K.min(logits, 1, keepdims=True)), "float32")
        y = y / K.sum(y, 1, keepdims=True)
        out = K.categorical_crossentropy(y, logits, from_logits=True)
    elif loss == 'logloss':
        out = K.categorical_crossentropy(y, logits, from_logits=True)
    else:
        raise ValueError("Unknown loss: {}".format(loss))

    if mean:
        out = K.mean(out)
    else:
        out = K.sum(out)
    return out


def gen_grad(x, logits, y, loss='logloss'):
    """
    Generate the gradient of the loss function.
    """

    adv_loss = gen_adv_loss(logits, y, loss)

    # Define gradient of loss wrt input
    grad = K.gradients(adv_loss, [x])[0]
    return grad


def get_grad_L1(x, logits):
    # L1 norm of the loss gradient w.r.t. the input, one value per sample
    x_shape = x.get_shape().as_list()
    dims = x_shape[1] * x_shape[2] * x_shape[3]

    adv_loss = gen_adv_loss(logits, None, loss='training')
    grad = K.gradients(adv_loss, [x])[0]

    flatten_grad = K.reshape(grad, shape=[-1, dims])
    L1_grad = K.sum(K.abs(flatten_grad), axis=-1)
    print(L1_grad)  # debug: prints the symbolic tensor, not its values
    return L1_grad
```
cifar10.py (194 additions, 0 deletions)
```
from keras import backend as K
import keras
from keras.datasets import cifar10
from keras.models import Sequential, model_from_json
from keras.layers import Dense, Dropout, Activation, Flatten, Input, GlobalAveragePooling2D, Lambda
from keras.layers import Convolution2D, MaxPooling2D, ZeroPadding2D, Conv2D
from keras.preprocessing.image import ImageDataGenerator
from keras.utils import np_utils
from keras_contrib.layers.normalization import GroupNormalization
import numpy as np
K.set_image_data_format('channels_first')


from tensorflow.python.platform import flags
FLAGS = flags.FLAGS


def set_flags(batch_size):
    flags.DEFINE_integer('BATCH_SIZE', batch_size, 'Size of training batches')

    flags.DEFINE_integer('NUM_CLASSES', 10, 'Number of classification classes')
    flags.DEFINE_integer('IMAGE_ROWS', 32, 'Input row dimension')
    flags.DEFINE_integer('IMAGE_COLS', 32, 'Input column dimension')
    flags.DEFINE_integer('NUM_CHANNELS', 3, 'Input depth dimension')


def load_data(one_hot=True):
    # the data, shuffled and split between train and test sets
    (X_train, y_train), (X_test, y_test) = cifar10.load_data()
    X_train = X_train.astype('float32')
    X_test = X_test.astype('float32')
    X_train /= 255.0
    X_test /= 255.0
    print('X_train shape:', X_train.shape)
    print(X_train.shape[0], 'train samples')
    print(X_test.shape[0], 'test samples')

    print("Loaded CIFAR-10 dataset.")

    if one_hot:
        # convert class vectors to binary class matrices
        y_train = np_utils.to_categorical(y_train, FLAGS.NUM_CLASSES).astype(np.float32)
        y_test = np_utils.to_categorical(y_test, FLAGS.NUM_CLASSES).astype(np.float32)

    return X_train, y_train, X_test, y_test


def modelZ():
    # all-convolutional main model with GroupNormalization and ELU activations
    model = Sequential()
    model.add(Conv2D(96, (3, 3), padding='same', input_shape=(FLAGS.NUM_CHANNELS, FLAGS.IMAGE_ROWS, FLAGS.IMAGE_COLS)))
    model.add(GroupNormalization(axis=1))
    model.add(Activation('elu'))

    model.add(Conv2D(96, (3, 3), padding='same'))
    model.add(GroupNormalization(axis=1))
    model.add(Activation('elu'))
    model.add(Conv2D(96, (3, 3), padding='same', strides=2))
    model.add(GroupNormalization(axis=1))
    model.add(Activation('elu'))
    model.add(Dropout(0.5))

    model.add(Conv2D(192, (3, 3), padding='same'))
    model.add(GroupNormalization(axis=1))
    model.add(Activation('elu'))
    model.add(Conv2D(192, (3, 3), padding='same'))
    model.add(GroupNormalization(axis=1))
    model.add(Activation('elu'))
    model.add(Conv2D(192, (3, 3), padding='same', strides=2))
    model.add(GroupNormalization(axis=1))
    model.add(Activation('elu'))
    model.add(Dropout(0.5))

    model.add(Conv2D(192, (3, 3), padding='same'))
    model.add(GroupNormalization(axis=1))
    model.add(Activation('elu'))
    model.add(Conv2D(192, (1, 1), padding='valid'))
    model.add(GroupNormalization(axis=1))
    model.add(Activation('elu'))
    model.add(Conv2D(10, (1, 1), padding='valid'))

    model.add(GlobalAveragePooling2D())
    return model


def modelA():
    weight_decay = 0.0001
    model = Sequential()

    model.add(Conv2D(96, (5, 5), padding='same', kernel_regularizer=keras.regularizers.l2(weight_decay), kernel_initializer="he_normal", input_shape=(FLAGS.NUM_CHANNELS, FLAGS.IMAGE_ROWS, FLAGS.IMAGE_COLS)))
    model.add(Activation('elu'))
    model.add(Conv2D(96, (1, 1), padding='same', kernel_regularizer=keras.regularizers.l2(weight_decay), kernel_initializer="he_normal"))
    model.add(Activation('elu'))
    model.add(MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding='same'))

    model.add(Dropout(0.5))

    model.add(Conv2D(192, (5, 5), padding='same', kernel_regularizer=keras.regularizers.l2(weight_decay), kernel_initializer="he_normal"))
    model.add(Activation('elu'))
    model.add(Conv2D(192, (1, 1), padding='same', kernel_regularizer=keras.regularizers.l2(weight_decay), kernel_initializer="he_normal"))
    model.add(Activation('elu'))
    model.add(MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding='same'))

    model.add(Dropout(0.5))

    model.add(Conv2D(256, (3, 3), padding='same', kernel_regularizer=keras.regularizers.l2(weight_decay), kernel_initializer="he_normal"))
    model.add(Activation('elu'))
    model.add(Conv2D(256, (1, 1), padding='same', kernel_regularizer=keras.regularizers.l2(weight_decay), kernel_initializer="he_normal"))
    model.add(Activation('elu'))
    model.add(Conv2D(FLAGS.NUM_CLASSES, (1, 1), padding='same', kernel_regularizer=keras.regularizers.l2(weight_decay), kernel_initializer="he_normal"))

    model.add(GlobalAveragePooling2D())
    return model


def modelB():
    weight_decay = 0.0001
    model = Sequential()

    model.add(Conv2D(192, (5, 5), padding='same', kernel_regularizer=keras.regularizers.l2(weight_decay), kernel_initializer="he_normal", input_shape=(FLAGS.NUM_CHANNELS, FLAGS.IMAGE_ROWS, FLAGS.IMAGE_COLS)))
    model.add(Activation('elu'))
    model.add(Conv2D(96, (1, 1), padding='same', kernel_regularizer=keras.regularizers.l2(weight_decay), kernel_initializer="he_normal"))
    model.add(Activation('elu'))
    model.add(MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding='same'))

    model.add(Dropout(0.5))

    model.add(Conv2D(192, (5, 5), padding='same', kernel_regularizer=keras.regularizers.l2(weight_decay), kernel_initializer="he_normal"))
    model.add(Activation('elu'))
    model.add(Conv2D(192, (1, 1), padding='same', kernel_regularizer=keras.regularizers.l2(weight_decay), kernel_initializer="he_normal"))
    model.add(Activation('elu'))
    model.add(MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding='same'))

    model.add(Dropout(0.5))

    model.add(Conv2D(256, (3, 3), padding='same', kernel_regularizer=keras.regularizers.l2(weight_decay), kernel_initializer="he_normal"))
    model.add(Activation('elu'))
    model.add(Conv2D(256, (1, 1), padding='same', kernel_regularizer=keras.regularizers.l2(weight_decay), kernel_initializer="he_normal"))
    model.add(Activation('elu'))
    model.add(Conv2D(FLAGS.NUM_CLASSES, (1, 1), padding='same', kernel_regularizer=keras.regularizers.l2(weight_decay), kernel_initializer="he_normal"))
    model.add(GlobalAveragePooling2D())
    return model


def modelC():
    model = Sequential()
    model.add(Conv2D(96, (3, 3), activation='elu', padding='same', input_shape=(FLAGS.NUM_CHANNELS, FLAGS.IMAGE_ROWS, FLAGS.IMAGE_COLS)))
    model.add(Dropout(0.2))

    model.add(Conv2D(96, (3, 3), activation='elu', padding='same'))
    model.add(Conv2D(96, (3, 3), activation='elu', padding='same', strides=2))
    model.add(Dropout(0.5))

    model.add(Conv2D(192, (3, 3), activation='elu', padding='same'))
    model.add(Conv2D(192, (3, 3), activation='elu', padding='same', strides=2))
    model.add(Dropout(0.5))

    model.add(Conv2D(256, (3, 3), padding='same'))
    model.add(Activation('elu'))
    model.add(Conv2D(256, (1, 1), padding='valid'))
    model.add(Activation('elu'))
    model.add(Conv2D(FLAGS.NUM_CLASSES, (1, 1), padding='valid'))

    model.add(GlobalAveragePooling2D())

    return model


def model_select(type=0):
    # 0 = modelZ, 1 = modelA, 2 = modelB, 3 = modelC
    models = [modelZ, modelA, modelB, modelC]

    return models[type]()


def data_flow(X_train):
    datagen = ImageDataGenerator()

    datagen.fit(X_train)
    return datagen


def load_model(model_path, type=0):
    # rebuild the architecture from JSON if available; otherwise from code
    try:
        with open(model_path + '.json', 'r') as f:
            json_string = f.read()
        model = model_from_json(json_string)
    except IOError:
        model = model_select(type=type)

    model.load_weights(model_path)
    return model
```
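A minimal usage sketch for the helpers above (the batch size and `__main__` guard are illustrative, not part of the repository):

```
# hypothetical driver script for the helpers in cifar10.py
from cifar10 import set_flags, load_data, model_select

if __name__ == '__main__':
    set_flags(batch_size=128)         # register the FLAGS the helpers read
    X_train, y_train, X_test, y_test = load_data()
    model = model_select(type=0)      # 0 = modelZ
    model.summary()
```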