Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
nikolaJovisic committed Aug 22, 2023
1 parent 0b534e6 commit 15b5160
Showing 1 changed file with 19 additions and 17 deletions.
36 changes: 19 additions & 17 deletions src/transformers/models/segformer/modeling_tf_segformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,32 +762,17 @@ def hf_compute_loss(self, logits, labels):

if self.config.num_labels > 1:
loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")
loss = loss_fct(upsampled_logits, labels)
elif self.config.num_labels == 1:
valid_mask = tf.cast(
(labels >= 0) & (labels != self.config.semantic_loss_ignore_index),
dtype=tf.float32
)
loss_fct = tf.keras.losses.BinaryCrossentropy(
from_logits=True,
reduction=tf.keras.losses.Reduction.NONE
)
loss = loss_fct(labels, upsampled_logits[:, 0]) # Assuming channel dimension is last
loss = tf.reduce_mean(loss * valid_mask)
else:
raise ValueError(f"Number of labels should be >=0: {self.config.num_labels}")


def masked_loss(real, pred):
unmasked_loss = loss_fct(real, pred)
mask = tf.cast(real != self.config.semantic_loss_ignore_index, dtype=unmasked_loss.dtype)
masked_loss = unmasked_loss * mask
# Reduction strategy in the similar spirit with
# https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_tf_utils.py#L210
reduced_masked_loss = tf.reduce_sum(masked_loss) / tf.reduce_sum(mask)
return tf.reshape(reduced_masked_loss, (1,))

return masked_loss(labels, upsampled_logits)
return loss_fct(labels, upsampled_logits)

@unpack_inputs
@add_start_docstrings_to_model_forward(SEGFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
Expand Down Expand Up @@ -846,7 +831,24 @@ def call(

loss = None
if labels is not None:
loss = self.hf_compute_loss(logits=logits, labels=labels)
# upsample logits to the images' original size
# `labels` is of shape (batch_size, height, width)
label_interp_shape = shape_list(labels)[1:]

upsampled_logits = tf.image.resize(logits, size=label_interp_shape, method="bilinear")
# compute weighted loss

if self.config.num_labels > 1:
loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")
loss = loss_fct(upsampled_logits, labels)
elif self.config.num_labels == 1:
loss_fct = tf.keras.losses.BinaryCrossentropy(
from_logits=True,
reduction=tf.keras.losses.Reduction.NONE
)
loss = loss_fct(labels, upsampled_logits[:, 0]) # Assuming channel dimension is last
else:
raise ValueError(f"Number of labels should be >=0: {self.config.num_labels}")

# make logits of shape (batch_size, num_labels, height, width) to
# keep them consistent across APIs
Expand Down

0 comments on commit 15b5160

Please sign in to comment.