Skip to content

Commit

Permalink
Fixing build_targets in SetCriterion (#143)
Browse files Browse the repository at this point in the history
* Adding initial utilities for SetCriterion

* Adding build_targets and compute_loss for SetCriterion

* Remove unused codes

* Fixed num_targets in ComputeLoss

* Rename to pred_logits in SetCriterion

* Abstracting the encode_single in BoxCoder

* Fix Anchors in SetCriterion

* Rename variables

* Remove BalancedPositiveNegativeSampler in unit-test

* Fixing compute_loss

* Fixing test_criterion

* Fixing jit annotations

* Enable unittest for training

* Rename variables to targets_with_gain

* Fixing jit annotations

* Fixing jit annotations for PT1.7
  • Loading branch information
zhiqwang authored Sep 9, 2021
1 parent cc2bd50 commit b6d8e00
Show file tree
Hide file tree
Showing 6 changed files with 274 additions and 274 deletions.
2 changes: 0 additions & 2 deletions test/test_engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
import pytest
from pathlib import Path

import torch
Expand Down Expand Up @@ -75,7 +74,6 @@ def test_train_with_vanilla_module():
assert isinstance(out["objectness"], Tensor)


@pytest.mark.skip("Currently it is not well supported.")
def test_training_step():
# Setup the DataModule
data_path = 'data-bin'
Expand Down
8 changes: 3 additions & 5 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,19 +176,17 @@ def test_postprocessors(self):

def test_criterion(self):
N, H, W = 4, 640, 640
anchor_generator = self._init_test_anchor_generator()
feature_maps = self._get_feature_maps(N, H, W)
head_outputs = self._get_head_outputs(N, H, W)
anchors_tuple = anchor_generator(feature_maps)

targets = torch.tensor([
[0.0000, 7.0000, 0.0714, 0.3749, 0.0760, 0.0654],
[0.0000, 1.0000, 0.1027, 0.4402, 0.2053, 0.1920],
[1.0000, 5.0000, 0.4720, 0.6720, 0.3280, 0.1760],
[3.0000, 3.0000, 0.6305, 0.3290, 0.3274, 0.2270],
])
criterion = SetCriterion(iou_thresh=0.5)
out = criterion(targets, head_outputs, anchors_tuple)
criterion = SetCriterion(self.num_anchors, self.strides,
self.anchor_grids, self.num_classes)
out = criterion(targets, head_outputs)
self.assertIsInstance(out, Dict)
self.assertIsInstance(out['cls_logits'], Tensor)
self.assertIsInstance(out['bbox_regression'], Tensor)
Expand Down
14 changes: 0 additions & 14 deletions test/test_models_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,6 @@
import torch

from yolort.models.transform import YOLOTransform, NestedTensor
from yolort.models._utils import BalancedPositiveNegativeSampler


def test_balanced_positive_negative_sampler():
sampler = BalancedPositiveNegativeSampler(4, 0.25)
# keep all 6 negatives first, then add 3 positives, last two are ignore
matched_idxs = [torch.tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, -1, -1])]
pos, neg = sampler(matched_idxs)
# we know the number of elements that should be sampled for the positive (1)
# and the negative (3), and their location. Let's make sure that they are there
assert pos[0].sum() == 1
assert pos[0][6:9].sum() == 1
assert neg[0].sum() == 3
assert neg[0][0:6].sum() == 3


def test_yolo_transform():
Expand Down
Loading

0 comments on commit b6d8e00

Please sign in to comment.