From 5b097ec098654aca8fe38cd74966ad3223cf2a9c Mon Sep 17 00:00:00 2001 From: nikola-jovisic Date: Tue, 22 Aug 2023 16:32:44 +0200 Subject: [PATCH] Revert "fix #4" This reverts commit 0b534e62d03db5ef74f77b61837e0561a1fc129a. --- .../models/segformer/modeling_tf_segformer.py | 24 ++++--------------- 1 file changed, 5 insertions(+), 19 deletions(-) diff --git a/src/transformers/models/segformer/modeling_tf_segformer.py b/src/transformers/models/segformer/modeling_tf_segformer.py index 9464180c2bbfd0..632382f95ed0a7 100644 --- a/src/transformers/models/segformer/modeling_tf_segformer.py +++ b/src/transformers/models/segformer/modeling_tf_segformer.py @@ -759,24 +759,7 @@ def hf_compute_loss(self, logits, labels): upsampled_logits = tf.image.resize(logits, size=label_interp_shape, method="bilinear") # compute weighted loss - - if self.config.num_labels > 1: - loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none") - loss = loss_fct(upsampled_logits, labels) - elif self.config.num_labels == 1: - valid_mask = tf.cast( - (labels >= 0) & (labels != self.config.semantic_loss_ignore_index), - dtype=tf.float32 - ) - loss_fct = tf.keras.losses.BinaryCrossentropy( - from_logits=True, - reduction=tf.keras.losses.Reduction.NONE - ) - loss = loss_fct(labels, upsampled_logits[:, 0]) # Assuming channel dimension is last - loss = tf.reduce_mean(loss * valid_mask) - else: - raise ValueError(f"Number of labels should be >=0: {self.config.num_labels}") - + loss_fct = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none") def masked_loss(real, pred): unmasked_loss = loss_fct(real, pred) @@ -846,7 +829,10 @@ def call( loss = None if labels is not None: - loss = self.hf_compute_loss(logits=logits, labels=labels) + if not self.config.num_labels > 1: + raise ValueError("The number of labels should be greater than one") + else: + loss = self.hf_compute_loss(logits=logits, labels=labels) # make logits of shape (batch_size, num_labels, height, width) to # keep them consistent across APIs