
[BUG] BUG in cross_entropy_loss.py #1525

Closed
@Dawn-bin

Description

Describe the bug
Hi, I copied this code into my own project and found an issue when using cross_entropy_loss.
At line 121:

if pred.size(1) == 1:
    # For binary class segmentation, the shape of pred is
    # [N, 1, H, W] and that of label is [N, H, W].
    assert label.max() <= 1, \
        'For pred with shape [N, 1, H, W], its label must have at ' \
        'most 2 classes'
    pred = pred.squeeze()

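Shouldn't the check 'label.max() <= 1' mask out ignore_index? ignore_index is often set to 255, so any label map that contains ignored pixels fails this assertion even when all remaining labels are 0 or 1.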
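A minimal reproduction, assuming PyTorch and ignore_index = 255 (the label values are made up for illustration): the unmasked check rejects a valid binary label map as soon as it contains an ignored pixel, while the masked check accepts it.

```python
import torch

ignore_index = 255  # common value for "void" pixels in segmentation datasets

# Illustrative binary label map of shape [N, H, W] = [1, 2, 3];
# the two 255 entries mark pixels that should be ignored.
label = torch.tensor([[[0, 1, 255],
                       [1, 0, 255]]])

# Original check: 255 counts toward the max, so this is False.
unmasked_ok = bool(label.max() <= 1)

# Proposed check: drop ignored pixels before taking the max, so this is True.
masked_ok = bool(label[label != ignore_index].max() <= 1)

print(unmasked_ok, masked_ok)  # False True
```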

Bug fix

if pred.size(1) == 1:
    # For binary class segmentation, the shape of pred is
    # [N, 1, H, W] and that of label is [N, H, W].
    assert label[label != ignore_index].max() <= 1, \
        'For pred with shape [N, 1, H, W], its label must have at ' \
        'most 2 classes'
    pred = pred.squeeze()
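One caveat with the masked check, sketched below assuming PyTorch: if every pixel in the batch equals ignore_index, label[label != ignore_index] is empty, and calling .max() on an empty tensor raises an error. The helper check_binary_label is hypothetical (it is not part of the library); it applies the same masked assertion but skips it when no valid pixels remain.

```python
import torch


def check_binary_label(label: torch.Tensor, ignore_index: int = 255) -> None:
    """Hypothetical helper: assert that all non-ignored labels are 0 or 1.

    Ignored pixels are masked out first. If nothing is left (every pixel
    equals ignore_index), there is nothing to validate, so the check is
    skipped instead of calling .max() on an empty tensor.
    """
    valid = label[label != ignore_index]
    if valid.numel() > 0:
        assert valid.max() <= 1, \
            'For pred with shape [N, 1, H, W], its label must have at ' \
            'most 2 classes'


# Binary labels with ignored pixels: passes.
check_binary_label(torch.tensor([[[0, 1, 255]]]))

# Every pixel ignored: passes instead of failing on an empty tensor.
check_binary_label(torch.full((1, 2, 2), 255))

# A genuine third class (label 2): fails the assertion.
try:
    check_binary_label(torch.tensor([[[0, 2, 255]]]))
except AssertionError as err:
    print('rejected:', err)
```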
