
Commit e50305f

csferng authored and tensorflow-copybara committed
Fix batch stat updates for adversarial-regularized estimator in TensorFlow 1.x
PiperOrigin-RevId: 318535128
1 parent 77245bb commit e50305f

2 files changed (+28, -2 lines)

neural_structured_learning/estimator/adversarial_regularization.py (+7, -2)

@@ -130,10 +130,15 @@ def adv_model_fn(features, labels, mode, params=None, config=None):
     else:
       optimizer = optimizer_fn()
 
-    final_train_op = optimizer.minimize(
+    train_op = optimizer.minimize(
        loss=final_loss, global_step=tf.compat.v1.train.get_global_step())
 
-    return original_spec._replace(loss=final_loss, train_op=final_train_op)
+    update_ops = tf.compat.v1.get_collection(
+        tf.compat.v1.GraphKeys.UPDATE_OPS)
+    if update_ops:
+      train_op = tf.group(train_op, *update_ops)
+
+    return original_spec._replace(loss=final_loss, train_op=train_op)
 
   # Replaces the model_fn while keeps other fields/methods in the estimator.
   estimator._model_fn = adv_model_fn  # pylint: disable=protected-access
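For context: in TF 1.x graph mode, batch-normalization layers (such as those created by a canned estimator with batch_norm=True) register their moving-mean/moving-variance update ops in the tf.compat.v1.GraphKeys.UPDATE_OPS collection rather than attaching them to the training op, so a bare optimizer.minimize() never runs them. Below is a minimal standalone sketch of the same pattern the change applies; it is a hypothetical toy graph (the names x, labels, hidden, logits are made up here), not NSL code.

import tensorflow as tf

tf.compat.v1.disable_v2_behavior()  # graph mode when running under TF 2.x

x = tf.compat.v1.placeholder(tf.float32, shape=[None, 2])
labels = tf.compat.v1.placeholder(tf.int32, shape=[None])
hidden = tf.compat.v1.layers.dense(x, 4)
# The batch_normalization layer adds its moving-statistics updates to the
# UPDATE_OPS collection instead of wiring them into the train op.
hidden = tf.compat.v1.layers.batch_normalization(hidden, training=True)
logits = tf.compat.v1.layers.dense(hidden, 2)
loss = tf.compat.v1.losses.sparse_softmax_cross_entropy(labels, logits)

optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.005)
train_op = optimizer.minimize(loss)

# Group the pending UPDATE_OPS with the train op, mirroring the fix above;
# without this, moving_mean/moving_variance would never change.
update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
if update_ops:
  train_op = tf.group(train_op, *update_ops)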

neural_structured_learning/estimator/adversarial_regularization_test.py (+21)

@@ -126,6 +126,27 @@ def test_adversarial_wrapper_adds_regularization(self, adv_step_size,
     self.assertAllClose(new_bias, adv_est.get_variable_value(BIAS_VARIABLE))
     self.assertAllClose(new_weight, adv_est.get_variable_value(WEIGHT_VARIABLE))
 
+  @test_util.run_v1_only('Requires tf.train.GradientDescentOptimizer')
+  def test_adversarial_wrapper_saving_batch_statistics(self):
+    x0, y0 = np.array([[0.9, 0.1], [0.2, 0.8]]), np.array([1, 0])
+    input_fn = single_batch_input_fn({FEATURE_NAME: x0}, y0)
+    fc = tf.feature_column.numeric_column(FEATURE_NAME, shape=[2])
+    base_est = tf.estimator.DNNClassifier(
+        hidden_units=[4],
+        feature_columns=[fc],
+        model_dir=self.model_dir,
+        batch_norm=True)
+    adv_est = nsl_estimator.add_adversarial_regularization(
+        base_est,
+        optimizer_fn=lambda: tf.train.GradientDescentOptimizer(0.005))
+
+    adv_est.train(input_fn=input_fn, steps=1)
+    moving_mean = adv_est.get_variable_value(
+        'dnn/hiddenlayer_0/batchnorm_0/moving_mean')
+    moving_variance = adv_est.get_variable_value(
+        'dnn/hiddenlayer_0/batchnorm_0/moving_variance')
+    self.assertNotAllClose(moving_mean, np.zeros(moving_mean.shape))
+    self.assertNotAllClose(moving_variance, np.ones(moving_variance.shape))
 
 if __name__ == '__main__':
   tf.test.main()
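For reference, the assertions in the new test rely on TF's batch-normalization initialization: moving_mean starts at zeros and moving_variance at ones, so both only move away from those values if the grouped update ops actually run during the single training step. A small standalone sketch of that assumption (toy graph with a hypothetical layer name 'bn', not part of the test):

import numpy as np
import tensorflow as tf

tf.compat.v1.disable_v2_behavior()  # graph mode when running under TF 2.x

inputs = tf.zeros([1, 4])
tf.compat.v1.layers.batch_normalization(inputs, name='bn', training=False)
variables = {v.name: v for v in tf.compat.v1.global_variables()}

with tf.compat.v1.Session() as sess:
  sess.run(tf.compat.v1.global_variables_initializer())
  # Freshly initialized statistics: moving_mean is all zeros and
  # moving_variance is all ones, so assertNotAllClose against those
  # values detects whether any update op ever ran.
  print(np.allclose(sess.run(variables['bn/moving_mean:0']), 0.0))      # True
  print(np.allclose(sess.run(variables['bn/moving_variance:0']), 1.0))  # True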
