Apply changes to v2 loss scale
- Also update docstring and test
MattConley committed Apr 16, 2019
1 parent 4896cc2 commit 3b486c4
Showing 4 changed files with 12 additions and 5 deletions.
@@ -79,7 +79,8 @@ def update(self, grads):
     Args:
       grads: A list of unscaled gradients, each of which is the gradient of the
         loss with respect to a weight. The gradients should have already been
-        divided by the loss scale before being passed to this function.
+        divided by the loss scale before being passed to this function. 'None'
+        gradients are accepted, and should be ignored.
 
     Returns:
       update_op: In eager mode, None. In graph mode, an op to update the loss
@@ -183,7 +184,7 @@ def get_config(self):
 def _is_all_finite(grads):
   """Returns a scalar boolean tensor indicating if all gradients are finite."""
   is_finite_per_grad = [math_ops.reduce_all(math_ops.is_finite(g))
-                        for g in grads]
+                        for g in grads if g is not None]
   return math_ops.reduce_all(is_finite_per_grad)


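For context: a gradient of `None` typically means a variable was not connected to the loss at all. A minimal sketch (not the TF source; assumes TF 2.x eager execution) of how such gradients arise and what the fixed `_is_all_finite` filter does:

```python
import tensorflow as tf

a = tf.Variable(1.0)
b = tf.Variable(2.0)  # unused by the loss, so its gradient comes back as None

with tf.GradientTape() as tape:
  loss = a * a
grads = tape.gradient(loss, [a, b])  # [<tf.Tensor: 2.0>, None]

# Mirrors the fixed comprehension: skip None entries, since is_finite
# cannot accept None.
finite_flags = [tf.reduce_all(tf.math.is_finite(g))
                for g in grads if g is not None]
print(tf.reduce_all(finite_flags))  # tf.Tensor(True, shape=(), dtype=bool)
```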
@@ -266,6 +266,11 @@ def test_serialization(self):
     self.assertEqual(loss_scale.increment_period, 2)
     self.assertEqual(loss_scale.multiplier, 3)
 
+  @test_util.run_in_graph_and_eager_modes
+  def test_update_with_none_gradients(self):
+    loss_scale = loss_scale_module.DynamicLossScale()
+    loss_scale.update([None])
+
   @test_util.run_in_graph_and_eager_modes
   def test_get(self):
     scalar = loss_scale_module.get('dynamic')
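The new test only needs to confirm that `update` tolerates a `None`-only gradient list instead of crashing. A hedged sketch of the same call outside the test harness (note that `tensorflow.python.*` is internal API; the import path is taken from this test file):

```python
from tensorflow.python.training.experimental import loss_scale as loss_scale_module

loss_scale = loss_scale_module.DynamicLossScale()
# Before this commit, the None entry reached math_ops.is_finite and raised;
# with the filter in _is_all_finite it is simply ignored.
loss_scale.update([None])
```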
tensorflow/python/training/experimental/loss_scale.py (3 changes: 2 additions, 1 deletion)
@@ -83,7 +83,8 @@ def update(self, grads):
     Args:
       grads: A list of unscaled gradients, each of which is the gradient of the
         loss with respect to a weight. The gradients should have already been
-        divided by the loss scale before being passed to this function.
+        divided by the loss scale before being passed to this function. 'None'
+        gradients are accepted, and should be ignored.
 
     Returns:
       update_op: In eager mode, None. In graph mode, an op to update the loss
tensorflow/python/training/experimental/loss_scale_test.py (4 changes: 2 additions, 2 deletions)
Expand Up @@ -254,8 +254,8 @@ def test_random_mix_good_and_bad_gradients(self, strategy_fn):

@test_util.run_in_graph_and_eager_modes
def test_update_with_none_gradients(self):
loss_scaler = loss_scale_module.DynamicLossScale()
loss_scaler.update([None])
loss_scale = loss_scale_module.DynamicLossScale()
loss_scale.update([None])

@test_util.run_in_graph_and_eager_modes
def test_get(self):
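Putting the docstring's contract together: gradients are computed on the scaled loss, divided by the current loss scale, and only then handed to `update`. An end-to-end sketch under the same caveats as above (internal import path; `loss_scale()` returning the current scale follows the `LossScale.__call__` convention; variable names are illustrative):

```python
import tensorflow as tf
from tensorflow.python.training.experimental import loss_scale as loss_scale_module

loss_scale = loss_scale_module.DynamicLossScale()
var = tf.Variable(1.0)

with tf.GradientTape() as tape:
  loss = var * var
  scaled_loss = loss * tf.cast(loss_scale(), loss.dtype)

scaled_grads = tape.gradient(scaled_loss, [var])
# Unscale before calling update(), as the docstring requires; None entries
# (variables unused by the loss) are passed through untouched.
grads = [g / tf.cast(loss_scale(), g.dtype) if g is not None else None
         for g in scaled_grads]
loss_scale.update(grads)  # adjusts the scale based on gradient finiteness
```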
