Skip to content

Commit

Permalink
add gfl loss unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
RangiLyu committed Jun 26, 2021
1 parent 65ce4b6 commit a79f4b1
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions tests/test_models/test_loss/test_gfocal_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import pytest
import torch

from nanodet.model.loss.gfocal_loss import DistributionFocalLoss, QualityFocalLoss


def test_qfl():
with pytest.raises(AssertionError):
QualityFocalLoss(use_sigmoid=False)

label = torch.randint(low=0, high=7, size=(10,))
score = torch.rand((10,))
pred = torch.rand((10, 7))
target = (label, score)
weight = torch.zeros(10)

loss = QualityFocalLoss()(pred, target, weight)
assert loss == 0.0

loss = QualityFocalLoss()(pred, target, weight, reduction_override="sum")
assert loss == 0.0


def test_dfl():

pred = torch.rand((10, 7))
target = torch.rand((10,))
weight = torch.zeros(10)

loss = DistributionFocalLoss()(pred, target, weight)
assert loss == 0.0

loss = DistributionFocalLoss()(pred, target, weight, reduction_override="sum")
assert loss == 0.0

0 comments on commit a79f4b1

Please sign in to comment.