Skip to content

Commit

Permalink
Fix binary classification for TF SegFormer (huggingface#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
nikolaJovisic committed Aug 22, 2023
1 parent f7b881d commit 75e4f04
Showing 1 changed file with 25 additions and 14 deletions.
39 changes: 25 additions & 14 deletions src/transformers/models/segformer/modeling_tf_segformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,24 @@ def hf_compute_loss(self, logits, labels):

upsampled_logits = tf.image.resize(logits, size=label_interp_shape, method="bilinear")
# compute weighted loss
loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")

if self.config.num_labels > 1:
loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")
loss = loss_fct(upsampled_logits, labels)
elif self.config.num_labels == 1:
valid_mask = tf.cast(
(labels >= 0) & (labels != self.config.semantic_loss_ignore_index),
dtype=tf.float32
)
loss_fct = tf.keras.losses.BinaryCrossentropy(
from_logits=True,
reduction=tf.keras.losses.Reduction.NONE
)
loss = loss_fct(labels, upsampled_logits[:, 0]) # Assuming channel dimension is last
loss = tf.reduce_mean(loss * valid_mask)
else:
raise ValueError(f"Number of labels should be >=0: {self.config.num_labels}")


def masked_loss(real, pred):
unmasked_loss = loss_fct(real, pred)
Expand Down Expand Up @@ -829,20 +846,14 @@ def call(

loss = None
if labels is not None:
# upsample logits to the images' original size
upsampled_logits = nn.functional.interpolate(
logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
)
if self.config.num_labels > 1:
loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
loss = loss_fct(upsampled_logits, labels)
elif self.config.num_labels == 1:
valid_mask = ((labels >= 0) & (labels != self.config.semantic_loss_ignore_index)).float()
loss_fct = BCEWithLogitsLoss(reduction="none")
loss = loss_fct(upsampled_logits.squeeze(1), labels.float())
loss = (loss * valid_mask).mean()
if not self.config.num_labels > 1:
raise ValueError("The number of labels should be greater than one")
else:
raise ValueError(f"Number of labels should be >=0: {self.config.num_labels}")
loss = self.hf_compute_loss(logits=logits, labels=labels)

# make logits of shape (batch_size, num_labels, height, width) to
# keep them consistent across APIs
logits = tf.transpose(logits, perm=[0, 3, 1, 2])

if not return_dict:
if output_hidden_states:
Expand Down

0 comments on commit 75e4f04

Please sign in to comment.