From 3b486c44640a3d72b85f72a5c0197b0d489b4117 Mon Sep 17 00:00:00 2001
From: Matt Conley
Date: Tue, 16 Apr 2019 12:47:27 -0700
Subject: [PATCH] Apply changes to v2 loss scale

-Also update docstring and test
---
 .../python/keras/mixed_precision/experimental/loss_scale.py | 5 +++--
 .../keras/mixed_precision/experimental/loss_scale_test.py   | 5 +++++
 tensorflow/python/training/experimental/loss_scale.py       | 3 ++-
 tensorflow/python/training/experimental/loss_scale_test.py  | 4 ++--
 4 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale.py b/tensorflow/python/keras/mixed_precision/experimental/loss_scale.py
index e72983ee491..d8e173a07f6 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale.py
@@ -79,7 +79,8 @@ def update(self, grads):
     Args:
       grads: A list of unscaled gradients, each which is the gradient of the
         loss with respect to a weight. The gradients should have already been
-        divided by the loss scale being before passed to this function.
+        divided by the loss scale before being passed to this function. 'None'
+        gradients are accepted, and are ignored.
 
     Returns:
       update_op: In eager mode, None. In graph mode, an op to update the loss
@@ -183,7 +184,7 @@ def get_config(self):
 def _is_all_finite(grads):
   """Returns a scalar boolean tensor indicating if all gradients are finite."""
   is_finite_per_grad = [math_ops.reduce_all(math_ops.is_finite(g))
-                        for g in grads]
+                        for g in grads if g is not None]
   return math_ops.reduce_all(is_finite_per_grad)
 
 
diff --git a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_test.py b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_test.py
index f8dee5203bb..25622e2fff8 100644
--- a/tensorflow/python/keras/mixed_precision/experimental/loss_scale_test.py
+++ b/tensorflow/python/keras/mixed_precision/experimental/loss_scale_test.py
@@ -266,6 +266,11 @@ def test_serialization(self):
     self.assertEqual(loss_scale.increment_period, 2)
     self.assertEqual(loss_scale.multiplier, 3)
 
+  @test_util.run_in_graph_and_eager_modes
+  def test_update_with_none_gradients(self):
+    loss_scale = loss_scale_module.DynamicLossScale()
+    loss_scale.update([None])
+
   @test_util.run_in_graph_and_eager_modes
   def test_get(self):
     scalar = loss_scale_module.get('dynamic')
diff --git a/tensorflow/python/training/experimental/loss_scale.py b/tensorflow/python/training/experimental/loss_scale.py
index 5b6262c0edb..a67823347bb 100644
--- a/tensorflow/python/training/experimental/loss_scale.py
+++ b/tensorflow/python/training/experimental/loss_scale.py
@@ -83,7 +83,8 @@ def update(self, grads):
     Args:
       grads: A list of unscaled gradients, each which is the gradient of the
         loss with respect to a weight. The gradients should have already been
-        divided by the loss scale being before passed to this function.
+        divided by the loss scale before being passed to this function. 'None'
+        gradients are accepted, and are ignored.
 
     Returns:
       update_op: In eager mode, None. In graph mode, an op to update the loss
diff --git a/tensorflow/python/training/experimental/loss_scale_test.py b/tensorflow/python/training/experimental/loss_scale_test.py
index f832022ecb0..d7d52f0050a 100644
--- a/tensorflow/python/training/experimental/loss_scale_test.py
+++ b/tensorflow/python/training/experimental/loss_scale_test.py
@@ -254,8 +254,8 @@ def test_random_mix_good_and_bad_gradients(self, strategy_fn):
 
   @test_util.run_in_graph_and_eager_modes
   def test_update_with_none_gradients(self):
-    loss_scaler = loss_scale_module.DynamicLossScale()
-    loss_scaler.update([None])
+    loss_scale = loss_scale_module.DynamicLossScale()
+    loss_scale.update([None])
 
   @test_util.run_in_graph_and_eager_modes
   def test_get(self):
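For context, a minimal usage sketch of the behavior this patch enables, written against the internal module paths the patch touches; eager mode is assumed, and the concrete gradient value is made up for illustration:

    from tensorflow.python.framework import constant_op
    from tensorflow.python.training.experimental import loss_scale as loss_scale_module

    loss_scale = loss_scale_module.DynamicLossScale()
    # Variables that do not contribute to the loss produce None gradients;
    # update() now skips them when checking that all gradients are finite.
    grads = [constant_op.constant(4.0), None]
    loss_scale.update(grads)

The new test above exercises the all-None case; mixing None with real gradients, as here, relies on the same `if g is not None` filter added to `_is_all_finite`.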