-
Notifications
You must be signed in to change notification settings - Fork 24
/
flipGradientTF.py
43 lines (33 loc) · 1.29 KB
/
flipGradientTF.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import tensorflow as tf
from keras.engine import Layer
import keras.backend as K
def reverse_gradient(X, hp_lambda):
    '''Return X unchanged, but negate and scale its gradient.

    The forward computation is a plain ``tf.identity``. On the backward
    pass the incoming gradient is negated and multiplied by ``hp_lambda``.

    # Arguments
        X: tensor to pass through.
        hp_lambda: factor the negated gradient is multiplied by.

    # Returns
        The identity of ``X``, built with the overridden gradient.
    '''
    # A function attribute doubles as a call counter, so every invocation
    # registers its gradient function under a distinct name.
    call_index = getattr(reverse_gradient, 'num_calls', 0) + 1
    reverse_gradient.num_calls = call_index
    grad_name = "GradientReversal%d" % call_index

    @tf.RegisterGradient(grad_name)
    def _flip_gradients(op, grad):
        # hp_lambda is captured from the enclosing call.
        negated = tf.negative(grad)
        return [negated * hp_lambda]

    # Identity ops created inside this scope use the gradient registered
    # above instead of the default one.
    graph = K.get_session().graph
    with graph.gradient_override_map({'Identity': grad_name}):
        return tf.identity(X)
class GradientReversal(Layer):
    '''Layer that is an identity on the forward pass and flips gradients.

    ``call`` delegates to ``reverse_gradient``, so the output equals the
    input while the gradient flowing back through the layer is negated
    and multiplied by ``hp_lambda``.

    # Arguments
        hp_lambda: factor applied to the negated gradient.
        **kwargs: forwarded unchanged to the base ``Layer`` constructor.
    '''

    def __init__(self, hp_lambda, **kwargs):
        super(GradientReversal, self).__init__(**kwargs)
        self.supports_masking = False
        self.hp_lambda = hp_lambda

    def build(self, input_shape):
        '''Declare that the layer owns no trainable weights.'''
        self.trainable_weights = []

    def call(self, x, mask=None):
        '''Apply the gradient-reversing identity to ``x`` (``mask`` is unused).'''
        return reverse_gradient(x, self.hp_lambda)

    def compute_output_shape(self, input_shape):
        '''Return ``input_shape`` unchanged (Keras 2 shape-inference hook).'''
        return input_shape

    def get_output_shape_for(self, input_shape):
        '''Keras 1 name for the shape hook; kept so older callers still work.'''
        return self.compute_output_shape(input_shape)

    def get_config(self):
        '''Return the base layer config extended with ``hp_lambda``.'''
        merged = dict(super(GradientReversal, self).get_config())
        # Keys set here take precedence over same-named keys from the base.
        merged.update({'hp_lambda': self.hp_lambda})
        return merged