Return scalar losses instead of per-sample means (huggingface#18013)
* Return scalar losses instead of per-sample means

* Make loss shape (1,) instead of scalar

* Allow scalar losses in test_loss_computation

* Allow scalar losses in test_loss_computation

* Allow scalar losses in test_loss_computation

* Remove XLA loss function for RAG
Rocketknight1 authored Jul 4, 2022
1 parent 6cb1954 commit 96d833b
Showing 7 changed files with 39 additions and 63 deletions.
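
The core of the change: the TF loss functions used to return one mean loss per sample (shape (batch_size,)); they now return a single mean over all unmasked tokens, reshaped to (1,). The two reductions differ whenever samples contain different numbers of unmasked labels, as in this minimal sketch (illustrative values, not taken from the diff):

import tensorflow as tf

# Per-token losses for a batch of 2 sequences of length 3; -100 marks masked labels.
labels = tf.constant([[5, 7, -100], [-100, -100, 2]])
per_token_loss = tf.constant([[1.0, 2.0, 9.0], [9.0, 9.0, 3.0]])
mask = tf.cast(labels != -100, per_token_loss.dtype)

# Old behaviour: one mean per sample, shape (batch_size,)
per_sample = tf.reduce_sum(per_token_loss * mask, axis=1) / tf.maximum(1.0, tf.reduce_sum(mask, axis=1))
print(per_sample.numpy())  # [1.5 3. ]; averaging these gives 2.25

# New behaviour: a single mean over every unmasked token, reshaped to (1,)
scalar = tf.reduce_sum(per_token_loss * mask) / tf.reduce_sum(mask)
print(tf.reshape(scalar, (1,)).numpy())  # [2.]; each unmasked token is weighted equally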
11 changes: 4 additions & 7 deletions src/transformers/modeling_tf_utils.py
@@ -206,11 +206,9 @@ def hf_compute_loss(self, labels, logits):
  unmasked_loss = loss_fn(tf.nn.relu(labels), logits)
  # make sure only labels that are not equal to -100 affect the loss
  loss_mask = tf.cast(labels != -100, dtype=unmasked_loss.dtype)
- # Avoid division by zero later
- loss_denominator = tf.math.maximum(tf.cast(1, loss_mask.dtype), tf.reduce_sum(loss_mask, axis=1))
  masked_loss = unmasked_loss * loss_mask
- reduced_masked_loss = tf.reduce_sum(masked_loss, axis=1) / loss_denominator
- return reduced_masked_loss
+ reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(loss_mask)
+ return tf.reshape(reduced_masked_loss, (1,))


class TFQuestionAnsweringLoss:
@@ -266,11 +264,10 @@ def hf_compute_loss(self, labels, logits):
  # are taken into account as loss
  loss_mask = tf.cast(labels >= 0, dtype=unmasked_loss.dtype)
  # Avoid possible division by zero later
- loss_denominator = tf.math.maximum(tf.cast(1, loss_mask.dtype), tf.reduce_sum(loss_mask, axis=1))
  # Masked positions will have a loss of NaN because -100 and -1 are not valid labels
  masked_loss = unmasked_loss * loss_mask
- reduced_masked_loss = tf.reduce_sum(masked_loss, axis=1) / loss_denominator
- return reduced_masked_loss
+ reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(loss_mask)
+ return tf.reshape(reduced_masked_loss, (1,))


class TFSequenceClassificationLoss:
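
For reference, a self-contained sketch of the reduction that hf_compute_loss performs after this change; the function name and toy shapes below are made up for illustration, and only the body mirrors the diff:

import tensorflow as tf

def masked_lm_loss(labels, logits):
    # Per-token cross-entropy with no reduction, so masking can be applied afterwards.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE
    )
    # Clip the -100 sentinel to a valid class id; those positions are zeroed out by the mask below.
    unmasked_loss = loss_fn(tf.nn.relu(labels), logits)
    loss_mask = tf.cast(labels != -100, dtype=unmasked_loss.dtype)
    masked_loss = unmasked_loss * loss_mask
    reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(loss_mask)
    return tf.reshape(reduced_masked_loss, (1,))

labels = tf.constant([[1, -100, 0], [2, 2, -100]])
logits = tf.random.normal((2, 3, 4))  # (batch, seq_len, vocab)
print(masked_lm_loss(labels, logits).shape)  # (1,)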
8 changes: 3 additions & 5 deletions src/transformers/models/albert/modeling_tf_albert.py
@@ -118,20 +118,18 @@ def hf_compute_loss(self, labels: tf.Tensor, logits: tf.Tensor) -> tf.Tensor:
  # make sure only labels that are not equal to -100
  # are taken into account for the loss computation
  lm_loss_mask = tf.cast(labels["labels"] != -100, dtype=unmasked_lm_losses.dtype)
- # Avoid division by zero later
- lm_loss_denominator = tf.math.maximum(tf.cast(1, lm_loss_mask.dtype), tf.reduce_sum(lm_loss_mask, axis=1))
  masked_lm_losses = unmasked_lm_losses * lm_loss_mask
- reduced_masked_lm_loss = tf.reduce_sum(masked_lm_losses, axis=1) / lm_loss_denominator
+ reduced_masked_lm_loss = tf.reduce_sum(masked_lm_losses) / tf.reduce_sum(lm_loss_mask)

  sop_logits = tf.reshape(logits[1], (-1, 2))
  # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
  unmasked_sop_loss = loss_fn(y_true=tf.nn.relu(labels["sentence_order_label"]), y_pred=sop_logits)
  sop_loss_mask = tf.cast(labels["sentence_order_label"] != -100, dtype=unmasked_sop_loss.dtype)

- # No reduction because this already has shape (num_samples,)
  masked_sop_loss = unmasked_sop_loss * sop_loss_mask
+ reduced_masked_sop_loss = tf.reduce_sum(masked_sop_loss) / tf.reduce_sum(sop_loss_mask)

- return reduced_masked_lm_loss + masked_sop_loss
+ return tf.reshape(reduced_masked_lm_loss + reduced_masked_sop_loss, (1,))


class TFAlbertEmbeddings(tf.keras.layers.Layer):
9 changes: 4 additions & 5 deletions src/transformers/models/bert/modeling_tf_bert.py
@@ -130,18 +130,17 @@ def hf_compute_loss(self, labels: tf.Tensor, logits: tf.Tensor) -> tf.Tensor:
  # make sure only labels that are not equal to -100
  # are taken into account for the loss computation
  lm_loss_mask = tf.cast(labels["labels"] != -100, dtype=unmasked_lm_losses.dtype)
- # Avoid potential division by zero later
- lm_loss_denominator = tf.math.maximum(tf.cast(1, lm_loss_mask.dtype), tf.reduce_sum(lm_loss_mask, axis=1))
  masked_lm_losses = unmasked_lm_losses * lm_loss_mask
- reduced_masked_lm_loss = tf.reduce_sum(masked_lm_losses, axis=1) / lm_loss_denominator
+ reduced_masked_lm_loss = tf.reduce_sum(masked_lm_losses) / tf.reduce_sum(lm_loss_mask)

  # Clip negative labels to zero here to avoid NaNs and errors - those positions will get masked later anyway
  unmasked_ns_loss = loss_fn(y_true=tf.nn.relu(labels["next_sentence_label"]), y_pred=logits[1])
  ns_loss_mask = tf.cast(labels["next_sentence_label"] != -100, dtype=unmasked_ns_loss.dtype)
- # Just zero out samples where label is -100, no reduction
  masked_ns_loss = unmasked_ns_loss * ns_loss_mask

- return reduced_masked_lm_loss + masked_ns_loss
+ reduced_masked_ns_loss = tf.reduce_sum(masked_ns_loss) / tf.reduce_sum(ns_loss_mask)

+ return tf.reshape(reduced_masked_lm_loss + reduced_masked_ns_loss, (1,))


class TFBertEmbeddings(tf.keras.layers.Layer):
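
The BERT and ALBERT pretraining heads follow the same pattern, except that two losses (masked-LM plus NSP or SOP) are each averaged over their own mask and then summed before the final reshape. A toy illustration with invented numbers:

import tensorflow as tf

# Hypothetical already-masked per-token MLM losses and per-sample NSP losses for a batch of 2.
masked_lm_losses = tf.constant([[0.4, 0.0, 1.2], [0.0, 0.0, 0.9]])
lm_loss_mask = tf.constant([[1.0, 0.0, 1.0], [0.0, 0.0, 1.0]])
masked_ns_loss = tf.constant([0.2, 0.0])  # second sample's label was -100
ns_loss_mask = tf.constant([1.0, 0.0])

lm = tf.reduce_sum(masked_lm_losses) / tf.reduce_sum(lm_loss_mask)  # mean over unmasked tokens
ns = tf.reduce_sum(masked_ns_loss) / tf.reduce_sum(ns_loss_mask)    # mean over labelled samples
print(tf.reshape(lm + ns, (1,)).numpy())  # [1.0333334] = 2.5/3 + 0.2/1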
5 changes: 2 additions & 3 deletions src/transformers/models/led/modeling_tf_led.py
@@ -2518,7 +2518,6 @@ def hf_compute_loss(self, labels, logits):
  unmasked_loss = loss_fn(tf.nn.relu(labels), logits)
  # make sure only non-padding labels affect the loss
  loss_mask = tf.cast(labels != self.config.pad_token_id, dtype=unmasked_loss.dtype)
- loss_denominator = tf.math.maximum(tf.cast(1, loss_mask.dtype), tf.reduce_sum(loss_mask, axis=1))
  masked_loss = unmasked_loss * loss_mask
- reduced_masked_loss = tf.reduce_sum(masked_loss, axis=1) / loss_denominator
- return reduced_masked_loss
+ reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(loss_mask)
+ return tf.reshape(reduced_masked_loss, (1,))
55 changes: 19 additions & 36 deletions src/transformers/models/rag/modeling_tf_rag.py
@@ -1333,46 +1333,29 @@ def get_nll(self, seq_logits, doc_scores, target, reduce_loss=False, epsilon=0.0
  # Adopted modeling_tf_bart + add smooth_loss to match with pytorch version
  def hf_compute_loss(self, labels, y_pred, smooth_epsilon=0.0, from_logits=True, reduce_loss=False):
      """CrossEntropyLoss that ignores pad tokens"""
-     if self.config.tf_legacy_loss:
-         loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
-             from_logits=True,
-             reduction=tf.keras.losses.Reduction.SUM,
-         )

-         if from_logits is False:  # convert to logits
-             eps = 1e-9
-             y_pred = tf.clip_by_value(y_pred, clip_value_min=eps, clip_value_max=1 - eps)
-             y_pred = tf.math.log(y_pred)

-         logits = y_pred
-         melted_labels = tf.reshape(labels, (-1,))
-         active_loss = tf.not_equal(melted_labels, self.config.generator.pad_token_id)

-         reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, logits.shape[2])), active_loss)
-         labels = tf.boolean_mask(melted_labels, active_loss)
-         nll_loss = loss_fn(labels, reduced_logits)

-         smooth_loss = -tf.reduce_sum(reduced_logits, axis=-1)
-         smooth_loss = tf.reduce_sum(smooth_loss)  # sum and squeeze like torch
-         eps_i = smooth_epsilon / reduced_logits.shape[-1]

-         loss = (1.0 - smooth_epsilon) * nll_loss + eps_i * smooth_loss

-         return loss

+     # Matt: As written, this loss is not XLA-compatible, but it's doing some very weird things
+     # and I don't feel comfortable converting it.
      loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
-         from_logits=from_logits,
-         reduction=tf.keras.losses.Reduction.NONE,
+         from_logits=True,
+         reduction=tf.keras.losses.Reduction.SUM,
      )

-     unmasked_loss = loss_fn(labels, y_pred)
-     loss_mask = labels != self.config.generator.pad_token_id
-     nll_loss = tf.reduce_sum(unmasked_loss * loss_mask)
+     if from_logits is False:  # convert to logits
+         eps = 1e-9
+         y_pred = tf.clip_by_value(y_pred, clip_value_min=eps, clip_value_max=1 - eps)
+         y_pred = tf.math.log(y_pred)

+     logits = y_pred
+     melted_labels = tf.reshape(labels, (-1,))
+     active_loss = tf.not_equal(melted_labels, self.config.generator.pad_token_id)

+     reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, logits.shape[2])), active_loss)
+     labels = tf.boolean_mask(melted_labels, active_loss)
+     nll_loss = loss_fn(labels, reduced_logits)

-     # Matt: This makes no sense to me, but I'm just copying the old loss in XLA-compatible form
-     smooth_loss = -tf.reduce_sum(y_pred * tf.expand_dims(labels, -1), axis=-1)
-     smooth_loss = tf.reduce_sum(smooth_loss)
-     eps_i = smooth_epsilon / y_pred.shape[-1]
+     smooth_loss = -tf.reduce_sum(reduced_logits, axis=-1)
+     smooth_loss = tf.reduce_sum(smooth_loss)  # sum and squeeze like torch
+     eps_i = smooth_epsilon / reduced_logits.shape[-1]

      loss = (1.0 - smooth_epsilon) * nll_loss + eps_i * smooth_loss

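
The branch kept for RAG is the fairseq-style label-smoothed NLL: with smoothing epsilon and vocabulary size V, loss = (1 - epsilon) * nll_loss + (epsilon / V) * smooth_loss, where nll_loss is the summed cross-entropy over non-pad positions and smooth_loss sums -y_pred over the vocabulary at those positions. A rough, self-contained sketch of that arithmetic; the shapes, pad id and smoothing value are invented (the real ones come from self.config.generator):

import tensorflow as tf

pad_token_id = 1       # assumption for the sketch
smooth_epsilon = 0.1   # assumption for the sketch

labels = tf.constant([[4, 2, 1], [3, 1, 1]])  # 1 == pad
y_pred = tf.random.normal((2, 3, 5))          # (batch, seq_len, vocab) logits

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.SUM
)

melted_labels = tf.reshape(labels, (-1,))
active_loss = tf.not_equal(melted_labels, pad_token_id)
reduced_logits = tf.boolean_mask(tf.reshape(y_pred, (-1, y_pred.shape[-1])), active_loss)
active_labels = tf.boolean_mask(melted_labels, active_loss)

nll_loss = loss_fn(active_labels, reduced_logits)                  # summed NLL over non-pad tokens
smooth_loss = tf.reduce_sum(-tf.reduce_sum(reduced_logits, axis=-1))
eps_i = smooth_epsilon / reduced_logits.shape[-1]                  # epsilon / vocab size

loss = (1.0 - smooth_epsilon) * nll_loss + eps_i * smooth_loss
print(float(loss))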
6 changes: 3 additions & 3 deletions tests/models/xlnet/test_modeling_tf_xlnet.py
@@ -417,12 +417,12 @@ def test_loss_computation(self):
  input_ids = prepared_for_class.pop(input_name)

  loss = model(input_ids, **prepared_for_class)[0]
- self.assertEqual(loss.shape.as_list(), expected_loss_size)
+ self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])

  # Test that model correctly compute the loss with a dict
  prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
  loss = model(prepared_for_class)[0]
- self.assertEqual(loss.shape.as_list(), expected_loss_size)
+ self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])

# Test that model correctly compute the loss with a tuple
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
@@ -453,7 +453,7 @@ def test_loss_computation(self):
  # Send to model
  loss = model(tuple_input[:-1])[0]

- self.assertEqual(loss.shape.as_list(), expected_loss_size)
+ self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])


@require_tf
8 changes: 4 additions & 4 deletions tests/test_modeling_tf_common.py
@@ -1294,7 +1294,7 @@ def test_loss_computation(self):
  model_input = prepared_for_class.pop(input_name)

  loss = model(model_input, **prepared_for_class)[0]
- self.assertEqual(loss.shape.as_list(), expected_loss_size)
+ self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])

# Test that model correctly compute the loss when we mask some positions
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
@@ -1307,13 +1307,13 @@ def test_loss_computation(self):
  labels[0] = -100
  prepared_for_class["labels"] = tf.convert_to_tensor(labels)
  loss = model(model_input, **prepared_for_class)[0]
- self.assertEqual(loss.shape.as_list(), expected_loss_size)
+ self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])
  self.assertTrue(not np.any(np.isnan(loss.numpy())))

  # Test that model correctly compute the loss with a dict
  prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
  loss = model(prepared_for_class)[0]
- self.assertEqual(loss.shape.as_list(), expected_loss_size)
+ self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])

# Test that model correctly compute the loss with a tuple
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
@@ -1344,7 +1344,7 @@ def test_loss_computation(self):
  # Send to model
  loss = model(tuple_input[:-1])[0]

- self.assertEqual(loss.shape.as_list(), expected_loss_size)
+ self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])

def test_keras_fit(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
