Skip to content

Commit 0845aaa

Browse files
authored
[BERT/TF] bug fix in beta bias correction terms (NVIDIA#395)
1 parent ce73b32 commit 0845aaa

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

TensorFlow/LanguageModeling/BERT/optimization.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -301,12 +301,12 @@ def __init__(self,
301301
self.beta_2 = beta_2
302302
self.epsilon = epsilon
303303
self.exclude_from_weight_decay = exclude_from_weight_decay
304-
self.steps = 0
305304

306-
def apply_gradients(self, grads_and_vars, global_step=None, name=None,
305+
def apply_gradients(self, grads_and_vars, global_step, name=None,
307306
manual_fp16=False):
308307
"""See base class."""
309308
assignments = []
309+
steps = tf.cast(global_step, tf.float32)
310310
for (grad, param) in grads_and_vars:
311311
if grad is None or param is None:
312312
continue
@@ -343,9 +343,8 @@ def apply_gradients(self, grads_and_vars, global_step=None, name=None,
343343
tf.multiply(self.beta_2, v) + tf.multiply(1.0 - self.beta_2,
344344
tf.square(grad)))
345345

346-
self.steps += 1
347-
beta1_correction = (1 - self.beta_1 ** self.steps)
348-
beta2_correction = (1 - self.beta_2 ** self.steps)
346+
beta1_correction = (1 - self.beta_1 ** steps)
347+
beta2_correction = (1 - self.beta_2 ** steps)
349348

350349
next_m_unbiased = next_m / beta1_correction
351350
next_v_unbiased = next_v / beta2_correction

0 commit comments

Comments
 (0)