
Commit 4945b77

csferng authored and tensorflow-copybara committed

Support PGD adversarial regularization in Keras and Estimator APIs.

PiperOrigin-RevId: 314612195

1 parent 704d52b · commit 4945b77

9 files changed: +153 −57 lines

neural_structured_learning/configs/configs.py

+18 −7

@@ -58,17 +58,19 @@ class AdvNeighborConfig(object):
       corresponding feature.
     clip_value_max: maximum value to clip the feature after perturbation. (See
       `clip_value_min` for the structure and shape limitations.)
-    iterations: number of iterations to run the attack for. Defaults to a single
-      step, used for the Fast Gradient Sign Method (FGSM) attack.
-    epsilon: Defines radius of the epsilon ball to project back to.
+    pgd_iterations: number of attack iterations for Projected Gradient Descent
+      (PGD) attack. Defaults to 1, which resembles the Fast Gradient Sign Method
+      (FGSM) attack.
+    pgd_epsilon: radius of the epsilon ball to project back to. Only used in
+      Projected Gradient Descent (PGD) attack.
   """
   feature_mask = attr.ib(default=None)
   adv_step_size = attr.ib(default=0.001)
   adv_grad_norm = attr.ib(converter=NormType, default='l2')
   clip_value_min = attr.ib(default=None)
   clip_value_max = attr.ib(default=None)
-  iterations = attr.ib(default=1)  # 1 is the FGSM attack.
-  epsilon = attr.ib(default=None)
+  pgd_iterations = attr.ib(default=1)  # 1 is the FGSM attack.
+  pgd_epsilon = attr.ib(default=None)


 @attr.s
@@ -91,7 +93,9 @@ def make_adv_reg_config(
     adv_step_size=attr.fields(AdvNeighborConfig).adv_step_size.default,
     adv_grad_norm=attr.fields(AdvNeighborConfig).adv_grad_norm.default,
     clip_value_min=attr.fields(AdvNeighborConfig).clip_value_min.default,
-    clip_value_max=attr.fields(AdvNeighborConfig).clip_value_max.default):
+    clip_value_max=attr.fields(AdvNeighborConfig).clip_value_max.default,
+    pgd_iterations=attr.fields(AdvNeighborConfig).pgd_iterations.default,
+    pgd_epsilon=attr.fields(AdvNeighborConfig).pgd_epsilon.default):
   """Creates an `nsl.configs.AdvRegConfig` object.

   Args:
@@ -115,6 +119,11 @@ def make_adv_reg_config(
       corresponding feature.
     clip_value_max: maximum value to clip the feature after perturbation. (See
       `clip_value_min` for the structure and shape limitations.)
+    pgd_iterations: number of attack iterations for Projected Gradient Descent
+      (PGD) attack. Defaults to 1, which resembles the Fast Gradient Sign Method
+      (FGSM) attack.
+    pgd_epsilon: radius of the epsilon ball to project back to. Only used in
+      Projected Gradient Descent (PGD) attack.

   Returns:
     An `nsl.configs.AdvRegConfig` object.
@@ -126,7 +135,9 @@ def make_adv_reg_config(
           adv_step_size=adv_step_size,
           adv_grad_norm=adv_grad_norm,
           clip_value_min=clip_value_min,
-          clip_value_max=clip_value_max))
+          clip_value_max=clip_value_max,
+          pgd_iterations=pgd_iterations,
+          pgd_epsilon=pgd_epsilon))


 class AdvTargetType(enum.Enum):
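
Note: a minimal sketch of the renamed config surface, using the `nsl` import alias from the NSL documentation; the numeric values here are illustrative, not defaults:

import neural_structured_learning as nsl

# PGD: 5 attack steps of size 0.05, projecting the accumulated
# perturbation back onto an L2 ball of radius 0.1 after each step.
pgd_config = nsl.configs.make_adv_reg_config(
    multiplier=0.2,
    adv_step_size=0.05,
    adv_grad_norm='l2',
    pgd_iterations=5,
    pgd_epsilon=0.1)

# The default pgd_iterations=1 keeps the FGSM behavior.
fgsm_config = nsl.configs.make_adv_reg_config(adv_step_size=0.05)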

neural_structured_learning/estimator/BUILD

+1 −0

@@ -52,6 +52,7 @@ py_test(
     srcs_version = "PY2AND3",
     deps = [
         ":estimator",
+        # package absl/testing:parameterized
         "//neural_structured_learning/configs",
         # package numpy
         # package tensorflow

neural_structured_learning/estimator/adversarial_regularization.py

+13 −10

@@ -91,24 +91,27 @@ def adv_model_fn(features, labels, mode, params=None, config=None):
     # If no 'params' is passed, then it is possible for base_model_fn not to
     # accept a 'params' argument. See documentation for tf.estimator.Estimator
     # for additional context.
-    # pylint: disable=g-long-lambda
-    spec_fn = ((lambda features: base_model_fn(
-        features, labels, mode, params, config)) if params else (
-            lambda features: base_model_fn(features, labels, mode, config)))
+    base_args = [mode, params, config] if params else [mode, config]
+    spec_fn = lambda feature, label: base_model_fn(feature, label, *base_args)

-    original_spec = spec_fn(features)
-
-    print("ORIGINAL_SPEC", original_spec)
+    original_spec = spec_fn(features, labels)

     # Adversarial regularization only happens in training.
     if mode != tf.estimator.ModeKeys.TRAIN:
       return original_spec

-    adv_neighbor, _ = nsl_lib.gen_adv_neighbor(features, original_spec.loss,
-                                               adv_config.adv_neighbor_config)
+    adv_neighbor, _ = nsl_lib.gen_adv_neighbor(
+        features,
+        original_spec.loss,
+        adv_config.adv_neighbor_config,
+        # The pgd_model_fn is a dummy identity function since loss is
+        # directly available from spec_fn.
+        pgd_model_fn=lambda features: features,
+        pgd_loss_fn=lambda labels, features: spec_fn(features, labels).loss,
+        pgd_labels=labels)

     # Runs the base model again to compute loss on adv_neighbor.
-    adv_spec = spec_fn(adv_neighbor)
+    adv_spec = spec_fn(adv_neighbor, labels)

     final_loss = original_spec.loss + adv_config.multiplier * adv_spec.loss
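
Note: based on the wrapper above and the test below, a hedged sketch of enabling PGD for an Estimator. The `LinearRegressor` setup and optimizer are illustrative, and the `adv_config` keyword is assumed from the NSL Estimator API rather than shown in this diff:

import neural_structured_learning as nsl
import tensorflow as tf

feature_columns = [tf.feature_column.numeric_column('x', shape=(2,))]
base_est = tf.estimator.LinearRegressor(feature_columns=feature_columns)

adv_config = nsl.configs.make_adv_reg_config(
    multiplier=1.0,    # equal weight on original and adversarial loss
    adv_step_size=0.1,
    pgd_iterations=3,  # 3 PGD steps instead of a single FGSM step
    pgd_epsilon=0.25)  # project back onto an L2 ball of radius 0.25

# The wrapper re-runs the base model_fn on the PGD neighbor and optimizes
# original_loss + multiplier * adversarial_loss.
adv_est = nsl.estimator.add_adversarial_regularization(
    base_est,
    optimizer_fn=lambda: tf.train.GradientDescentOptimizer(0.01),
    adv_config=adv_config)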

neural_structured_learning/estimator/adversarial_regularization_test.py

+21 −10

@@ -21,9 +21,9 @@
 import shutil
 import tempfile

+from absl.testing import parameterized
 import neural_structured_learning.configs as nsl_configs
 import neural_structured_learning.estimator as nsl_estimator
-
 import numpy as np
 import tensorflow as tf

@@ -43,7 +43,7 @@ def input_fn():
   return input_fn


-class AdversarialRegularizationTest(tf.test.TestCase):
+class AdversarialRegularizationTest(tf.test.TestCase, parameterized.TestCase):

   def setUp(self):
     super(AdversarialRegularizationTest, self).setUp()
@@ -79,19 +79,25 @@ def test_adversarial_wrapper_not_affecting_predictions(self):
     predicted_scores = [x['predictions'] for x in predictions]
     self.assertAllClose([[3.0], [4.0]], predicted_scores)

+  @parameterized.named_parameters([
+      ('fgsm', 0.1, 1, None),
+      ('pgd', 0.1, 3, 0.25),
+  ])
   @test_util.run_v1_only('Requires tf.GraphKeys')
-  def test_adversarial_wrapper_adds_regularization(self):
+  def test_adversarial_wrapper_adds_regularization(self, adv_step_size,
+                                                   pgd_iterations, pgd_epsilon):
     # base model: y = w*x+b = 4*x1 + 3*x2 + 2
     weight = np.array([[4.0], [3.0]], dtype=np.float32)
     bias = np.array([2.0], dtype=np.float32)
     x0, y0 = np.array([[1.0, 1.0]]), np.array([8.0])
-    adv_step_size = 0.1
     learning_rate = 0.01

     base_est = self.build_linear_regressor(weight=weight, bias=bias)
     adv_config = nsl_configs.make_adv_reg_config(
         multiplier=1.0,  # equal weight on original and adv examples
-        adv_step_size=adv_step_size)
+        adv_step_size=adv_step_size,
+        pgd_iterations=pgd_iterations,
+        pgd_epsilon=pgd_epsilon)
     adv_est = nsl_estimator.add_adversarial_regularization(
         base_est,
         optimizer_fn=lambda: tf.train.GradientDescentOptimizer(learning_rate),
@@ -104,11 +110,16 @@ def test_adversarial_wrapper_adds_regularization(self):
     orig_grad_w = 2 * (orig_pred - y0) * x0.T  # [[2.0], [2.0]]
     orig_grad_b = 2 * (orig_pred - y0).reshape((1,))  # [2.0]
     grad_x = 2 * (orig_pred - y0) * weight.T  # [[8.0, 6.0]]
-    perturbation = adv_step_size * grad_x / np.linalg.norm(grad_x)
-    x_adv = x0 + perturbation  # [[1.08, 1.06]]
-    adv_pred = np.dot(x_adv, weight) + bias  # [9.5]
-    adv_grad_w = 2 * (adv_pred - y0) * x_adv.T  # [[3.24], [3.18]]
-    adv_grad_b = 2 * (adv_pred - y0).reshape((1,))  # [3.0]
+    # Gradient direction is independent of x, so perturbing for multiple
+    # iterations is the same as scaling the perturbation.
+    perturbation_magnitude = pgd_iterations * adv_step_size
+    if pgd_epsilon is not None:
+      perturbation_magnitude = np.minimum(perturbation_magnitude, pgd_epsilon)
+    perturbation = perturbation_magnitude * grad_x / np.linalg.norm(grad_x)
+    x_adv = x0 + perturbation  # fgm: [[1.08, 1.06]]; pgd: [[1.20, 1.15]]
+    adv_pred = np.dot(x_adv, weight) + bias  # fgm: [9.5]; pgd: [10.25]
+    adv_grad_w = 2 * (adv_pred - y0) * x_adv.T  # fgm: [[3.24], [3.18]]
+    adv_grad_b = 2 * (adv_pred - y0).reshape((1,))  # fgm: [3.0]; pgd: [4.5]

     new_bias = bias - learning_rate * (orig_grad_b + adv_grad_b)
     new_weight = weight - learning_rate * (orig_grad_w + adv_grad_w)
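
Note: the fgm/pgd values in the comments above check out numerically; a standalone NumPy verification of the pgd branch (adv_step_size=0.1, pgd_iterations=3, pgd_epsilon=0.25) against the same linear model y = 4*x1 + 3*x2 + 2:

import numpy as np

weight = np.array([[4.0], [3.0]])
bias = np.array([2.0])
x0, y0 = np.array([[1.0, 1.0]]), np.array([8.0])
adv_step_size, pgd_iterations, pgd_epsilon = 0.1, 3, 0.25

orig_pred = np.dot(x0, weight) + bias     # [[9.0]]
grad_x = 2 * (orig_pred - y0) * weight.T  # [[8.0, 6.0]]
# The gradient direction is constant, so 3 L2-normalized steps of 0.1
# would travel 0.3 in total; the epsilon ball caps that at 0.25.
magnitude = min(pgd_iterations * adv_step_size, pgd_epsilon)
x_adv = x0 + magnitude * grad_x / np.linalg.norm(grad_x)
print(x_adv)                              # [[1.2  1.15]]
print(np.dot(x_adv, weight) + bias)       # [[10.25]]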

neural_structured_learning/keras/adversarial_regularization.py

+11 −2

@@ -138,7 +138,10 @@ def adversarial_loss(features,
       features,
       labeled_loss,
       config=adv_config.adv_neighbor_config,
-      gradient_tape=gradient_tape)
+      gradient_tape=gradient_tape,
+      pgd_model_fn=model,
+      pgd_loss_fn=functools.partial(loss_fn, sample_weights=sample_weights),
+      pgd_labels=labels)
   adv_output = model(adv_input)
   if sample_weights is not None:
     adv_sample_weights = tf.math.multiply(sample_weights, adv_sample_weights)
@@ -713,7 +716,13 @@ def perturb_on_batch(self, x, **config_kwargs):
     config_kwargs = {k: v for k, v in config_kwargs.items() if v is not None}
     config = attr.evolve(self.adv_config.adv_neighbor_config, **config_kwargs)
     adv_inputs, _ = nsl_lib.gen_adv_neighbor(
-        inputs, labeled_loss, config=config, gradient_tape=tape)
+        inputs,
+        labeled_loss,
+        config=config,
+        gradient_tape=tape,
+        pgd_model_fn=self._call_base_model,
+        pgd_loss_fn=self._compute_total_loss,
+        pgd_labels=labels)

     if tf.executing_eagerly():
       # Converts `Tensor` objects to NumPy arrays and keeps other objects (e.g.
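
Note: a sketch of the Keras-side flow these changes enable; the Sequential model and the numbers are illustrative, while `AdversarialRegularization`, `perturb_on_batch`, and the new config fields come from this commit:

import numpy as np
import neural_structured_learning as nsl
import tensorflow as tf

base_model = tf.keras.Sequential(
    [tf.keras.layers.Dense(1, use_bias=False, input_shape=(2,))])

adv_config = nsl.configs.make_adv_reg_config(
    multiplier=0.2,
    adv_step_size=0.01,
    adv_grad_norm='infinity',
    pgd_iterations=3,
    pgd_epsilon=0.025)

adv_model = nsl.keras.AdversarialRegularization(
    base_model, label_keys=['label'], adv_config=adv_config)
adv_model.compile(optimizer=tf.keras.optimizers.SGD(0.01), loss='MSE')

inputs = {'feature': np.array([[2.0, 3.0]]), 'label': np.array([[0.0]])}
adv_model.fit(x=inputs, batch_size=1, steps_per_epoch=1)

# perturb_on_batch now runs the same multi-step PGD attack.
adv_inputs = adv_model.perturb_on_batch(inputs)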

neural_structured_learning/keras/adversarial_regularization_test.py

+63 −0

@@ -542,6 +542,45 @@ def call(self, inputs):

     self.assertIn('label', model.seen_input_keys)

+  @parameterized.named_parameters([
+      ('sequential', build_linear_keras_sequential_model),
+      ('sequential_no_input_layer',
+       build_linear_keras_sequential_model_no_input_layer),
+      ('functional', build_linear_keras_functional_model),
+      ('subclassed', build_linear_keras_subclassed_model),
+  ])
+  def test_train_pgd(self, model_fn):
+    w = np.array([[4.0], [-3.0]])
+    x0 = np.array([[2.0, 3.0]])
+    y0 = np.array([[0.0]])
+    adv_multiplier = 0.2
+    adv_step_size = 0.01
+    learning_rate = 0.01
+    pgd_iterations = 3
+    pgd_epsilon = 2.5 * adv_step_size
+    adv_config = configs.make_adv_reg_config(
+        multiplier=adv_multiplier,
+        adv_step_size=adv_step_size,
+        adv_grad_norm='infinity',
+        pgd_iterations=pgd_iterations,
+        pgd_epsilon=pgd_epsilon)
+    y_hat = np.dot(x0, w)
+    # The adversarial perturbation is constant across PGD iterations.
+    x_adv = x0 + pgd_epsilon * np.sign((y_hat - y0) * w.T)
+    y_hat_adv = np.dot(x_adv, w)
+    grad_w_labeled_loss = 2. * (y_hat - y0) * x0.T
+    grad_w_adv_loss = adv_multiplier * 2. * (y_hat_adv - y0) * x_adv.T
+    w_new = w - learning_rate * (grad_w_labeled_loss + grad_w_adv_loss)
+
+    inputs = {'feature': tf.constant(x0), 'label': tf.constant(y0)}
+    model = model_fn(input_shape=(2,), weights=w)
+    adv_model = adversarial_regularization.AdversarialRegularization(
+        model, label_keys=['label'], adv_config=adv_config)
+    adv_model.compile(tf.keras.optimizers.SGD(learning_rate), loss='MSE')
+    adv_model.fit(x=inputs, batch_size=1, steps_per_epoch=1)
+
+    self.assertAllClose(w_new, tf.keras.backend.get_value(model.weights[0]))
+
   def test_evaluate_binary_classification_metrics(self):
     # multi-label binary classification model
     w = np.array([[4.0, 1.0, -5.0], [-3.0, 1.0, 2.0]])
@@ -633,6 +672,30 @@ def test_perturb_on_batch_custom_config(self):
     self.assertAllClose(x_adv, adv_inputs['feature'])
     self.assertAllClose(y0, adv_inputs['label'])

+  @parameterized.named_parameters([
+      ('sequential', build_linear_keras_sequential_model),
+      ('sequential_no_input_layer',
+       build_linear_keras_sequential_model_no_input_layer),
+      ('functional', build_linear_keras_functional_model),
+      ('subclassed', build_linear_keras_subclassed_model),
+  ])
+  def test_perturb_on_batch_pgd(self, model_fn):
+    w, x0, y0, lr, adv_config, _ = self._set_up_linear_regression()
+    pgd_epsilon = 4.5 * adv_config.adv_neighbor_config.adv_step_size
+    adv_config.adv_neighbor_config.pgd_iterations = 5
+    adv_config.adv_neighbor_config.pgd_epsilon = pgd_epsilon
+    inputs = {'feature': x0, 'label': y0}
+    model = model_fn(input_shape=(2,), weights=w)
+    adv_model = adversarial_regularization.AdversarialRegularization(
+        model, label_keys=['label'], adv_config=adv_config)
+    adv_model.compile(optimizer=tf.keras.optimizers.SGD(lr), loss=['MSE'])
+    adv_inputs = adv_model.perturb_on_batch(inputs)
+
+    y_hat = np.dot(x0, w)
+    x_adv = x0 + pgd_epsilon * np.sign((y_hat - y0) * w.T)
+    self.assertAllClose(x_adv, adv_inputs['feature'])
+    self.assertAllClose(y0, adv_inputs['label'])
+

 if __name__ == '__main__':
   tf.test.main()
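
Note: the closed-form x_adv in test_train_pgd and test_perturb_on_batch_pgd relies on the gradient sign being constant for a linear model: each infinity-norm step moves adv_step_size per coordinate, so 3 (or 5) steps overshoot and the projection caps the total perturbation at pgd_epsilon. A quick NumPy check with the test_train_pgd constants:

import numpy as np

w = np.array([[4.0], [-3.0]])
x0, y0 = np.array([[2.0, 3.0]]), np.array([[0.0]])
adv_step_size, pgd_iterations = 0.01, 3
pgd_epsilon = 2.5 * adv_step_size  # 0.025

# Sign of dMSE/dx for the linear model y_hat = x.w (no bias).
grad_sign = np.sign((np.dot(x0, w) - y0) * w.T)  # [[-1., 1.]]
# Three infinity-norm steps of 0.01 would travel 0.03 per coordinate;
# projection onto the L-infinity ball caps the total at pgd_epsilon.
magnitude = min(pgd_iterations * adv_step_size, pgd_epsilon)  # 0.025
print(x0 + magnitude * grad_sign)                # [[1.975 3.025]]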

neural_structured_learning/lib/adversarial_neighbor.py

+4 −4

@@ -202,7 +202,7 @@ def gen_neighbor(self, input_features, pgd_labels=None):
       logging.log_first_n(logging.WARNING,
                           'Cannot perturb non-Tensor input: %s', 1, sparse_keys)
     dense_features = dense_original_features
-    for t in range(self._adv_config.iterations):
+    for t in range(self._adv_config.pgd_iterations):
       keyed_grads = self._compute_gradient(loss, dense_features, gradient_tape)
       masked_grads = {
           key: utils.apply_feature_mask(grad, feature_masks.get(key, None))
@@ -221,8 +221,8 @@ def gen_neighbor(self, input_features, pgd_labels=None):
         # Only include features for which perturbation occurred. There is
         # nothing to project for features without perturbations.
         diff[key] = dense_features[key] + perturb - dense_original_features[key]
-      if self._adv_config.epsilon is not None:
-        bounded_diff = utils.project_to_ball(diff, self._adv_config.epsilon,
+      if self._adv_config.pgd_epsilon is not None:
+        bounded_diff = utils.project_to_ball(diff, self._adv_config.pgd_epsilon,
                                              self._adv_config.adv_grad_norm)
       else:
         bounded_diff = diff
@@ -239,7 +239,7 @@ def gen_neighbor(self, input_features, pgd_labels=None):
             feature_min.get(key, None), feature_max.get(key, None)))

       # Update for the next iteration.
-      if t < self._adv_config.iterations - 1:
+      if t < self._adv_config.pgd_iterations - 1:
         inputs_t = self._decompose_as(input_features, adv_neighbor)
         # Compute the new loss to calculate gradients with.
         features = self._compose_as_dict(inputs_t)
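
Note: a self-contained sketch of the projected-gradient loop above for a single dense feature with L2-normalized steps; `project_to_ball` here is a simplified stand-in for `utils.project_to_ball`, and the linear-model example is illustrative:

import numpy as np

def project_to_ball(diff, epsilon):
  """Simplified L2 stand-in for utils.project_to_ball."""
  norm = np.linalg.norm(diff)
  return diff if norm <= epsilon else diff * (epsilon / norm)

def pgd_neighbor(x, grad_fn, step_size, pgd_iterations, pgd_epsilon):
  """Multi-step FGM with projection, mirroring the gen_neighbor loop."""
  x_adv = x.copy()
  for _ in range(pgd_iterations):
    grad = grad_fn(x_adv)
    perturb = step_size * grad / (np.linalg.norm(grad) + 1e-12)
    diff = x_adv + perturb - x  # total displacement from the original input
    if pgd_epsilon is not None:
      diff = project_to_ball(diff, pgd_epsilon)
    x_adv = x + diff            # re-anchor at the original input
  return x_adv

# Illustrative linear model: loss = (w.x + b - y)^2.
w, b, y = np.array([4.0, 3.0]), 2.0, 8.0
grad_fn = lambda x: 2 * (np.dot(w, x) + b - y) * w
print(pgd_neighbor(np.array([1.0, 1.0]), grad_fn, 0.1, 3, 0.25))
# -> [1.2  1.15], matching the estimator test in this commit.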
