@@ -126,6 +126,27 @@ def test_adversarial_wrapper_adds_regularization(self, adv_step_size,
126
126
self .assertAllClose (new_bias , adv_est .get_variable_value (BIAS_VARIABLE ))
127
127
self .assertAllClose (new_weight , adv_est .get_variable_value (WEIGHT_VARIABLE ))
128
128
129
+ @test_util .run_v1_only ('Requires tf.train.GradientDescentOptimizer' )
130
+ def test_adversarial_wrapper_saving_batch_statistics (self ):
131
+ x0 , y0 = np .array ([[0.9 , 0.1 ], [0.2 , 0.8 ]]), np .array ([1 , 0 ])
132
+ input_fn = single_batch_input_fn ({FEATURE_NAME : x0 }, y0 )
133
+ fc = tf .feature_column .numeric_column (FEATURE_NAME , shape = [2 ])
134
+ base_est = tf .estimator .DNNClassifier (
135
+ hidden_units = [4 ],
136
+ feature_columns = [fc ],
137
+ model_dir = self .model_dir ,
138
+ batch_norm = True )
139
+ adv_est = nsl_estimator .add_adversarial_regularization (
140
+ base_est ,
141
+ optimizer_fn = lambda : tf .train .GradientDescentOptimizer (0.005 ))
142
+
143
+ adv_est .train (input_fn = input_fn , steps = 1 )
144
+ moving_mean = adv_est .get_variable_value (
145
+ 'dnn/hiddenlayer_0/batchnorm_0/moving_mean' )
146
+ moving_variance = adv_est .get_variable_value (
147
+ 'dnn/hiddenlayer_0/batchnorm_0/moving_variance' )
148
+ self .assertNotAllClose (moving_mean , np .zeros (moving_mean .shape ))
149
+ self .assertNotAllClose (moving_variance , np .ones (moving_variance .shape ))
129
150
130
151
if __name__ == '__main__':
  # Run all test cases in this module via the TensorFlow test runner.
  tf.test.main()
0 commit comments