diff --git a/paddleseg/models/losses/cross_entropy_loss.py b/paddleseg/models/losses/cross_entropy_loss.py index 5dc94202a7..fc0fab302d 100644 --- a/paddleseg/models/losses/cross_entropy_loss.py +++ b/paddleseg/models/losses/cross_entropy_loss.py @@ -48,11 +48,8 @@ def __init__(self, self.data_format = data_format if weight is not None: self.weight = paddle.to_tensor(weight, dtype='float32') - long_weight = weight + [0] * (256 - len(weight)) - self.long_weight = paddle.to_tensor(long_weight, dtype='float32') else: self.weight = None - self.long_weight = None def forward(self, logit, label, semantic_weights=None): """ @@ -82,12 +79,13 @@ def forward(self, logit, label, semantic_weights=None): label = label.astype('int64') # In F.cross_entropy, the ignore_index is invalid, which needs to be fixed. + # When there is 255 in the label and paddle version <= 2.1.3, the cross_entropy OP will report an error, which is fixed in paddle develop version. loss = F.cross_entropy( logit, label, ignore_index=self.ignore_index, reduction='none', - weight=self.long_weight) + weight=self.weight) return self._post_process_loss(logit, label, semantic_weights, loss)