Skip to content

Commit

Permalink
fix broadcast shape bug in yolov3 (open-mmlab#7551)
Browse files Browse the repository at this point in the history
  • Loading branch information
hhaAndroid authored and ZwwWayne committed Jul 19, 2022
1 parent ccce2d9 commit 3d77502
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion mmdet/models/losses/cross_entropy_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,10 @@ def binary_cross_entropy(pred,
# should mask out the ignored elements
valid_mask = ((label >= 0) & (label != ignore_index)).float()
if weight is not None:
weight *= valid_mask
# The inplace writing method will have a mismatched broadcast
# shape error if the weight and valid_mask dimensions
# are inconsistent such as (B,N,1) and (B,N,C).
weight = weight * valid_mask
else:
weight = valid_mask

Expand Down

0 comments on commit 3d77502

Please sign in to comment.