Add mixed precision tests for Sequential/subclassed models.
I only added the "run_with_all_model_types" annotation to one test so that the tests don't take too long to run.

PiperOrigin-RevId: 258640310
reedwm authored and tensorflower-gardener committed Jul 17, 2019
1 parent 3b1f369 commit 054aabf
Showing 2 changed files with 22 additions and 11 deletions.
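The commit message notes that the @keras_parameterized.run_with_all_model_types annotation was added to only one test to keep runtimes reasonable. As a rough, hypothetical sketch (not the actual keras_parameterized implementation), the decorator effectively re-runs the decorated test once per Keras model type, with testing_utils.get_model_type() reporting the active type inside each run:

# Hypothetical sketch of what @keras_parameterized.run_with_all_model_types
# does conceptually; the real decorator lives in keras_parameterized.py.
MODEL_TYPES = ('functional', 'subclass', 'sequential')

def run_with_all_model_types(test_method):
  """Runs the wrapped test method once per Keras model type."""
  def wrapper(self, *args, **kwargs):
    for model_type in MODEL_TYPES:
      with self.subTest(model_type=model_type):
        # In the real infrastructure, testing_utils.get_model_type() would
        # return `model_type` while this block runs.
        test_method(self, *args, **kwargs)
  return wrapper

Multiplying the model types into the existing keras-mode and strategy parameterizations enlarges the test matrix, which is also why the BUILD change below adds shard_count = 4 to split the test across parallel shards.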
1 change: 1 addition & 0 deletions tensorflow/python/keras/mixed_precision/experimental/BUILD
@@ -143,5 +143,6 @@ cuda_py_test(
         "//tensorflow/python/distribute:one_device_strategy",
         "//tensorflow/python/keras",
     ],
+    shard_count = 4,
     xla_enable_strict_auto_jit = True,
 )
32 changes: 21 additions & 11 deletions tensorflow/python/keras/mixed_precision/experimental/keras_test.py
@@ -330,15 +330,17 @@ def test_checkpointing_layer_weights(self, strategy_fn):
 class KerasModelTest(keras_parameterized.TestCase):
   """Test mixed precision with Keras models."""
 
-  def _is_strategy_supported(self, strategy_fn):
+  def _is_strategy_supported(self, strategy_fn, check_model_type=False):
     if (strategy_fn != default_strategy_fn and
-        testing_utils.should_run_eagerly()):
-      # Distribution strategies do not support running with `run_eagerly=True`
-      # in Keras Models.
+        (testing_utils.should_run_eagerly() or
+         (check_model_type and testing_utils.get_model_type() == 'subclass'))):
+      # Distribution strategies do not support subclassed models or running with
+      # `run_eagerly=True`.
       return False
     else:
       return True
 
+  @keras_parameterized.run_with_all_model_types
   @keras_parameterized.run_all_keras_modes
   @parameterized.named_parameters({
       'testcase_name': 'base',
@@ -361,17 +363,26 @@ def _is_strategy_supported(self, strategy_fn):
   })
   def test_model(self, strategy_fn, use_operator=False, use_regularizer=False,
                  cloning=True):
-    if not self._is_strategy_supported(strategy_fn):
+    if not self._is_strategy_supported(strategy_fn, check_model_type=True):
       return
     regularizer = IdentityRegularizer() if use_regularizer else None
     with strategy_fn().scope():
       with policy.policy_scope('infer_float32_vars'):
-        x = layers.Input(shape=(1,), batch_size=2, dtype=dtypes.float16)
+        layer_list = []
+        if testing_utils.get_model_type() == 'subclass':
+          # Subclassed models do not have an Input layer, so the model does not
+          # cast inputs to the Input layer's dtype. Therefore, we need to
+          # manually insert a float16 cast.
+          cast_f16_layer = layers.Lambda(lambda x: math_ops.cast(x, 'float16'),
+                                         input_shape=(1,))
+          layer_list.append(cast_f16_layer)
         layer = AddLayer(assert_type=dtypes.float16, use_operator=use_operator,
-                         regularizer=regularizer)
-        y = layer(x)
-        y = math_ops.cast(y, dtypes.float32)
-        model = models.Model(inputs=x, outputs=y)
+                         regularizer=regularizer, input_shape=(1,))
+        cast_f32_layer = layers.Lambda(lambda x: math_ops.cast(x, 'float32'))
+        layer_list += [layer, cast_f32_layer]
+        model = testing_utils.get_model_from_layers(layer_list,
+                                                    input_shape=(1,),
+                                                    input_dtype=dtypes.float16)
 
     def loss_fn(y_true, y_pred):
       del y_true
@@ -388,7 +399,6 @@ def loss_fn(y_true, y_pred):
                     run_eagerly=testing_utils.should_run_eagerly(),
                     run_distributed=testing_utils.should_run_distributed())
 
-    self.assertEqual(backend.eval(layer.v), 1)
     x = np.ones((2, 1))
     y = np.ones((2, 1))
     dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(2)
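With this change, test_model builds its model through testing_utils.get_model_from_layers, so the same layer list can be exercised as a functional, sequential, or subclassed model. A simplified, hypothetical sketch (not the real testing_utils helper, and using the public tf.keras API rather than the internal imports above) of why only the subclass path needs the explicit float16 cast:

# Hypothetical illustration of assembling one layer list into different model
# types; the real helper is testing_utils.get_model_from_layers.
import tensorflow as tf

def build_model(layer_list, model_type, input_shape=(1,), input_dtype='float16'):
  if model_type == 'functional':
    # The Input layer carries an explicit dtype, so inputs are cast to
    # float16 before they reach the first layer in the list.
    inputs = tf.keras.Input(shape=input_shape, dtype=input_dtype)
    outputs = inputs
    for layer in layer_list:
      outputs = layer(outputs)
    return tf.keras.Model(inputs, outputs)

  if model_type == 'sequential':
    # Sequential models built with a known input shape also get an implicit
    # Input layer, so they behave like the functional case here.
    return tf.keras.Sequential(
        [tf.keras.layers.InputLayer(input_shape, dtype=input_dtype)] + layer_list)

  # Subclassed models have no Input layer, so nothing casts the inputs.
  # That is why the test prepends a float16-casting Lambda layer when
  # testing_utils.get_model_type() == 'subclass'.
  class SubclassedModel(tf.keras.Model):

    def __init__(self):
      super().__init__()
      self.layer_list = layer_list

    def call(self, x):
      for layer in self.layer_list:
        x = layer(x)
      return x

  return SubclassedModel()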
