
Check differentiability of custom loss function before training #17753

Closed
123 changes: 123 additions & 0 deletions keras/engine/compile_utils.py
@@ -864,3 +864,126 @@ def get_custom_object_name(obj):
return generic_utils.to_snake_case(obj.__class__.__name__)
else: # Unrecognized object.
return None


def verify_object_differentiability(
custom_obj, expected_shapes, is_layer=False
):
"""Verifies if a given object is differentiable.

Args:
custom_obj: Can be a plain function or a class instance.
        expected_shapes: A tuple with the shape of the inputs to the
            instance or function.
        is_layer: Boolean indicating whether the custom object is a layer.
Raises:
ValueError: If the custom object is not differentiable.
"""

if not _verify_object_differentiability(
custom_obj, expected_shapes, is_layer=is_layer
):
        raise ValueError(
            f"The provided loss or layer ({custom_obj}) is not "
            "differentiable. Training requires differentiable objects. "
            "Please review your custom object or consider using standard "
            "differentiable objects. You can disable this check by setting "
            "'experimental_check_object_differentiability=False' in "
            "'model.compile()'."
        )


def _verify_object_differentiability(
custom_obj, expected_shapes=None, is_layer=False
):
"""Verifies if the loss is differentiable.

Args:
custom_obj: Can be a plain function or a class instance.
        expected_shapes: A tuple with the shape of the inputs to the loss.
is_layer: Boolean indicating whether the custom object is a layer.
Returns:
A boolean indicating whether the custom object is differentiable.
"""

def generate_shape_tuples(dim, num_dims):
for i in range(num_dims - 1):
yield (1,) * (i + 1) + (dim,)

if expected_shapes is None:
continue_checking = True
        # Some losses/layers use indexing internally, so when expected
        # shapes are not provided, we probe differentiability over a few
        # candidate shapes, adding one dimension at a time:
        # (1, 1) -- (1, 1, 1) -- (1, 1, 1, 1) -- ...
        for num_dims in range(1, 8):
            if not continue_checking:
                break
            shape_generator = generate_shape_tuples(1, num_dims)
            for shapes in shape_generator:
                try:
                    differentiable = _check_object_with_shapes(
                        custom_obj, shapes, is_layer=is_layer
                    )
                    if differentiable:
                        return True
                    else:
                        continue_checking = False
                        break  # Stop the inner loop once the check fails.
                except Exception:
                    # If the object cannot handle this shape, continue to
                    # the next shape.
                    continue
return False
else:
        return _check_object_with_shapes(
            custom_obj, expected_shapes, is_layer=is_layer
        )


def _check_object_with_shapes(custom_obj, expected_shape, is_layer=False):
"""Evaluates the custom_obj for the given shapes using `tf.GradientTape`."""

# Replace None batch dimension with 1.
expected_shape = tuple(1 if dim is None else dim for dim in expected_shape)

predictions = tf.random.uniform(
expected_shape, minval=0, maxval=1, dtype=tf.float32
)
targets = tf.random.uniform(
expected_shape, minval=0, maxval=1, dtype=tf.float32
)

with tf.GradientTape() as tape:
tape.watch(predictions)
if is_layer:
try:
# Layer can have a single input or a list of inputs.
                # A broad `except` is used here because the layer can
                # raise any exception when the input is not valid, and we
                # cannot know in advance which exception to catch.
output_value = custom_obj(predictions)
except Exception:
for inp_multiply in range(2, 7):
try:
                        # In case the layer takes multiple inputs, retry
                        # with a list of 2 to 6 copies of the input tensor.
output_value = custom_obj([predictions] * inp_multiply)
except Exception:
continue
else:
break
else:
output_value = custom_obj(targets, predictions)

gradients = tape.gradient(output_value, predictions)

if gradients is None:
    # If `gradients` is None, the object is not differentiable.
return False

return True
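
For reference, here is a minimal sketch (not part of the PR) of what the new helper does when called directly. `rounded_loss` is an illustrative non-differentiable loss, mirroring the `tf.round` pattern used in the tests below:

import tensorflow as tf

from keras.engine import compile_utils

def rounded_loss(y_true, y_pred):
    # The gradient of `tf.round` is None, so `tape.gradient` returns None
    # and the check treats this loss as non-differentiable.
    return tf.round(tf.square(y_true - y_pred))

try:
    compile_utils.verify_object_differentiability(rounded_loss, (None, 1))
except ValueError as err:
    print(err)  # Suggests fixing the loss or disabling the check.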
144 changes: 144 additions & 0 deletions keras/engine/compile_utils_test.py
@@ -17,6 +17,7 @@
import tensorflow.compat.v2 as tf

from keras import backend
from keras import layers as layers_mod
from keras import losses as losses_mod
from keras import metrics as metrics_mod
from keras.engine import compile_utils
@@ -883,6 +884,149 @@ def test_duplicated_metric_instance(self):
)


class TestObjectDifferentiabilityFunctions(test_combinations.TestCase):
    def test_verify_object_differentiability(self):
# Test case 1: Differentiable loss function
def differentiable_loss(y_true, y_pred):
return tf.reduce_mean(tf.square(y_true - y_pred))

class DifferentiableLossClass(losses_mod.Loss):
def call(self, y_true, y_pred):
return tf.reduce_mean(tf.square(y_true - y_pred))

expected_shapes = (None, 1)
compile_utils.verify_object_differentiability(
differentiable_loss, expected_shapes
)

compile_utils.verify_object_differentiability(
DifferentiableLossClass(), expected_shapes
)

# Test case 2: Non-differentiable loss function
def non_differentiable_loss(y_true, y_pred):
return tf.round(tf.square(y_true - y_pred))

class NonDifferentiableLossClass(losses_mod.Loss):
def call(self, y_true, y_pred):
return tf.round(tf.square(y_true - y_pred))

with self.assertRaises(ValueError):
compile_utils.verify_object_differentiability(
non_differentiable_loss, expected_shapes
)

with self.assertRaises(ValueError):
compile_utils.verify_object_differentiability(
NonDifferentiableLossClass(), expected_shapes
)

        # Test case 3: Non-differentiable custom layer.
class NonDifferentiableLayer(layers_mod.Layer):
def call(self, inputs):
return inputs.numpy()

with self.assertRaises(ValueError):
compile_utils.verify_object_differentiability(
NonDifferentiableLayer(), None, is_layer=True
)

    def test__verify_object_differentiability(self):
# Test case 1: Differentiable loss function
def differentiable_loss(y_true, y_pred):
return tf.reduce_mean(tf.square(y_true - y_pred))

class DifferentiableLossClass(losses_mod.Loss):
def call(self, y_true, y_pred):
return tf.abs(tf.square(y_true - y_pred))

expected_shapes = (1, 1)
self.assertTrue(
compile_utils._verify_object_differentiability(
differentiable_loss, expected_shapes
)
)

self.assertTrue(
compile_utils._verify_object_differentiability(
DifferentiableLossClass(), expected_shapes
)
)

# Test case 2: Non-differentiable loss function
def non_differentiable_loss(y_true, y_pred):
return tf.round(tf.square(y_true - y_pred))

class NonDifferentiableLossClass(losses_mod.Loss):
def call(self, y_true, y_pred):
return tf.argmax(tf.square(y_true - y_pred))

self.assertFalse(
compile_utils._verify_object_differentiability(
non_differentiable_loss, expected_shapes
)
)

self.assertFalse(
compile_utils._verify_object_differentiability(
NonDifferentiableLossClass(), expected_shapes
)
)

        # Test case 3: Non-differentiable custom layer.
class NonDifferentiableLayer(layers_mod.Layer):
def call(self, inputs):
return inputs.numpy()

self.assertFalse(
compile_utils._verify_object_differentiability(
NonDifferentiableLayer(), None, is_layer=True
)
)

    def test__check_object_with_shapes(self):
# Test case 1: Differentiable loss function
def differentiable_loss(y_true, y_pred):
return tf.reduce_sum(tf.square(y_true - y_pred))

class DifferentiableLossClass(losses_mod.Loss):
def call(self, y_true, y_pred):
return tf.reduce_sum(tf.square(y_true - y_pred))

expected_shape = (1, 1)
self.assertTrue(
compile_utils._check_object_with_shapes(
differentiable_loss, expected_shape
)
)

self.assertTrue(
compile_utils._check_object_with_shapes(
DifferentiableLossClass(), expected_shape
)
)

# Test case 2: Non-differentiable loss function
def non_differentiable_loss(y_true, y_pred):
return tf.round(tf.square(y_true - y_pred))

class NonDifferentiableLossClass(losses_mod.Loss):
def call(self, y_true, y_pred):
return tf.argmax(tf.square(y_true - y_pred))

self.assertFalse(
compile_utils._check_object_with_shapes(
non_differentiable_loss, expected_shape
)
)

self.assertFalse(
compile_utils._check_object_with_shapes(
NonDifferentiableLossClass(), expected_shape
)
)


if __name__ == "__main__":
tf.compat.v1.enable_eager_execution()
tf.test.main()
51 changes: 51 additions & 0 deletions keras/engine/training.py
@@ -33,6 +33,7 @@
from keras.engine import data_adapter
from keras.engine import input_layer as input_layer_module
from keras.engine import training_utils
from keras.layers import Layer
from keras.metrics import base_metric
from keras.mixed_precision import loss_scale_optimizer as lso
from keras.optimizers import optimizer
@@ -51,6 +52,7 @@
from keras.utils import tf_utils
from keras.utils import traceback_utils
from keras.utils import version_utils
from keras.utils.losses_utils import get_keras_losses
from keras.utils.mode_keys import ModeKeys

# isort: off
@@ -754,6 +756,54 @@ def compile(
self.compiled_loss = compile_utils.LossesContainer(
loss, loss_weights, output_names=self.output_names
)

experimental_check_object_differentiability = kwargs.pop(
"experimental_check_object_differentiability", True
)

user_loss = self.compiled_loss._user_losses
builtin_losses = set(get_keras_losses().keys())

if experimental_check_object_differentiability:
if (
not isinstance(user_loss, str)
and user_loss not in builtin_losses
):
# users may pass "mse" for MeanSquaredError,
# which is an alias for a built-in loss.
input_shape_arg = (
self.input_shape
if hasattr(self, "input_shape")
else None
)

compile_utils.verify_object_differentiability(
custom_obj=user_loss, expected_shapes=input_shape_arg
)

            # Check whether the user has custom layers. We clone them so
            # that the check does not accidentally build or set shapes on
            # the original layers. Nested `Model` instances are checked
            # first, since a `Model` is also a `Layer`.
            for layer in self.layers:
                if isinstance(layer, Model):
                    model_copy = tf.keras.models.clone_model(layer)
                    for sublayer in model_copy.layers:
                        if layer_utils.is_not_from_keras_layers(sublayer):
                            compile_utils.verify_object_differentiability(
                                custom_obj=sublayer,
                                expected_shapes=None,
                                is_layer=True,
                            )
                elif isinstance(layer, Layer):
                    if layer_utils.is_not_from_keras_layers(layer):
                        layer_copy = tf.keras.models.clone_model(layer)
                        compile_utils.verify_object_differentiability(
                            custom_obj=layer_copy,
                            expected_shapes=None,
                            is_layer=True,
                        )

self.compiled_metrics = compile_utils.MetricsContainer(
metrics,
weighted_metrics,
@@ -3759,6 +3809,7 @@ def _validate_compile(self, optimizer, metrics, **kwargs):

kwargs.pop("cloning", None) # Legacy DistStrat argument, never used.
kwargs.pop("experimental_run_tf_function", None) # Always `True`.
kwargs.pop("experimental_check_object_differentiability", None)
distribute_arg = kwargs.pop("distribute", None)
if distribute_arg is not None:
raise ValueError(
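
End to end, the check runs inside `model.compile()` and can be skipped through the new keyword argument. A minimal usage sketch, assuming a toy `Sequential` model (the model and loss below are illustrative, not part of the PR):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])

def rounded_mse(y_true, y_pred):
    return tf.round(tf.square(y_true - y_pred))

# With the check enabled (the default), compile() raises a ValueError
# because the custom loss is not differentiable:
#     model.compile(optimizer="sgd", loss=rounded_mse)

# Opting out of the differentiability check:
model.compile(
    optimizer="sgd",
    loss=rounded_mse,
    experimental_check_object_differentiability=False,
)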
25 changes: 25 additions & 0 deletions keras/utils/layer_utils.py
@@ -1109,3 +1109,28 @@ def convert_vocab_to_list(vocab):
" Received 0 instead."
)
return vocab_list


def is_not_from_keras_layers(layer):
"""Check if layer is not from keras.layers.

This utility will fail if users edit the keras source code and add their own
layers in the keras.layers namespace.

Args: layer: A keras.layer instance.

Returns: True if the layer is not from keras.layers, False otherwise.
"""
_module = layer.__class__.__module__
    # The module can be keras.layers, tensorflow.keras.layers, or
    # tensorflow.python.keras.layers.
is_custom_layer = True
for prefix in [
"keras.layers",
"tensorflow.keras.layers",
"tensorflow.python.keras.layers",
]:
if _module.startswith(prefix):
is_custom_layer = False
break
return is_custom_layer
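
A small illustration of how the helper classifies layers (`MyCustomLayer` is hypothetical, not part of the PR):

from keras import layers
from keras.utils import layer_utils

class MyCustomLayer(layers.Layer):
    def call(self, inputs):
        return inputs * 2.0

# Built-in layers resolve to the keras.layers namespace; custom ones do not.
print(layer_utils.is_not_from_keras_layers(layers.Dense(1)))  # False
print(layer_utils.is_not_from_keras_layers(MyCustomLayer()))  # True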