Closed
Describe the bug
Hi, I copied this code into my own project and found an issue when using cross_entropy_loss, at line 121:
if pred.size(1) == 1:
    # For binary class segmentation, the shape of pred is
    # [N, 1, H, W] and that of label is [N, H, W].
    assert label.max() <= 1, \
        'For pred with shape [N, 1, H, W], its label must have at ' \
        'most 2 classes'
    pred = pred.squeeze()
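Should the check label.max() <= 1 mask out ignore_index? ignore_index is often set to 255, so a binary label map that contains any ignored pixel fails this assertion even though its valid labels are only 0 and 1.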
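A minimal reproduction, assuming PyTorch and the common ignore_index=255 convention (the small label map below is made up for illustration): one ignored pixel makes label.max() return 255, while the maximum over the non-ignored pixels is 1.

```python
import torch

ignore_index = 255  # common convention, assumed here

# A 1x2x2 binary label map in which one pixel is marked as "ignore".
label = torch.tensor([[[0, 1],
                       [255, 0]]])

raw_max = label.max().item()
masked_max = label[label != ignore_index].max().item()

print(raw_max)     # 255: the original assertion label.max() <= 1 fails
print(masked_max)  # 1: the masked check passes
```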
Bug fix
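if pred.size(1) == 1:
    # For binary class segmentation, the shape of pred is
    # [N, 1, H, W] and that of label is [N, H, W].
    # Exclude ignored pixels (e.g. ignore_index=255) before checking the range.
    valid_label = label[label != ignore_index]
    # .max() raises on an empty tensor, so skip the check if every pixel is ignored.
    assert valid_label.numel() == 0 or valid_label.max() <= 1, \
        'For pred with shape [N, 1, H, W], its label must have at ' \
        'most 2 classes'
    # Squeeze only the channel dim so that a batch of size 1 keeps its N dim.
    pred = pred.squeeze(1)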
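The proposed check can also be exercised on its own. The sketch below wraps it in a hypothetical helper, check_binary_label (not part of the original code), assuming PyTorch and ignore_index=255. It also covers two edge cases: a label map in which every pixel is ignored (where .max() on the empty selection would raise), and a batch of size 1 (where pred.squeeze() with no argument would also drop the batch dimension).

```python
import torch


def check_binary_label(pred: torch.Tensor, label: torch.Tensor,
                       ignore_index: int = 255) -> torch.Tensor:
    """Hypothetical helper: validate a binary label map and squeeze pred.

    pred is assumed to have shape [N, 1, H, W]; label has shape [N, H, W].
    """
    assert pred.size(1) == 1
    # Drop ignored pixels before checking the label range.
    valid_label = label[label != ignore_index]
    # .max() raises on an empty tensor, so skip the check when nothing is valid.
    assert valid_label.numel() == 0 or valid_label.max() <= 1, \
        'For pred with shape [N, 1, H, W], its label must have at ' \
        'most 2 classes'
    # Squeeze only dim 1 so [1, 1, H, W] becomes [1, H, W], not [H, W].
    return pred.squeeze(1)


pred = torch.randn(1, 1, 2, 2)
label = torch.tensor([[[0, 1], [255, 0]]])
print(check_binary_label(pred, label).shape)  # torch.Size([1, 2, 2])

all_ignored = torch.full((1, 2, 2), 255)
print(check_binary_label(pred, all_ignored).shape)  # torch.Size([1, 2, 2])
```

A label value of 2 (outside {0, 1} and not equal to ignore_index) still triggers the assertion, so the check keeps its original purpose.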