diff --git a/test/test_engine.py b/test/test_engine.py
index 246a3762..5ae377f6 100644
--- a/test/test_engine.py
+++ b/test/test_engine.py
@@ -1,5 +1,4 @@
 # Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
-import pytest
 from pathlib import Path

 import torch
@@ -75,7 +74,6 @@ def test_train_with_vanilla_module():
     assert isinstance(out["objectness"], Tensor)


-@pytest.mark.skip("Currently it is not well supported.")
 def test_training_step():
     # Setup the DataModule
     data_path = 'data-bin'
diff --git a/test/test_models.py b/test/test_models.py
index b78ae8dc..05e513ef 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -176,10 +176,7 @@ def test_postprocessors(self):

     def test_criterion(self):
         N, H, W = 4, 640, 640
-        anchor_generator = self._init_test_anchor_generator()
-        feature_maps = self._get_feature_maps(N, H, W)
         head_outputs = self._get_head_outputs(N, H, W)
-        anchors_tuple = anchor_generator(feature_maps)

         targets = torch.tensor([
             [0.0000, 7.0000, 0.0714, 0.3749, 0.0760, 0.0654],
@@ -187,8 +184,9 @@
             [1.0000, 5.0000, 0.4720, 0.6720, 0.3280, 0.1760],
             [3.0000, 3.0000, 0.6305, 0.3290, 0.3274, 0.2270],
         ])
-        criterion = SetCriterion(iou_thresh=0.5)
-        out = criterion(targets, head_outputs, anchors_tuple)
+        criterion = SetCriterion(self.num_anchors, self.strides,
+                                 self.anchor_grids, self.num_classes)
+        out = criterion(targets, head_outputs)
         self.assertIsInstance(out, Dict)
         self.assertIsInstance(out['cls_logits'], Tensor)
         self.assertIsInstance(out['bbox_regression'], Tensor)
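The updated `test_criterion` captures the new contract: `SetCriterion` is constructed from the anchor configuration (`num_anchors`, `strides`, `anchor_grids`, `num_classes`) rather than an `iou_thresh`, and it is called with only `targets` and `head_outputs`. A minimal standalone sketch of that contract, using the stock YOLOv5 anchor configuration as an assumed stand-in for the test fixtures:

```python
import torch

from yolort.models.box_head import SetCriterion

# Stock YOLOv5 anchor configuration (assumed for illustration, not taken from the test)
num_anchors = 3
strides = [8, 16, 32]
anchor_grids = [
    [10, 13, 16, 30, 33, 23],       # P3/8
    [30, 61, 62, 45, 59, 119],      # P4/16
    [116, 90, 156, 198, 373, 326],  # P5/32
]
num_classes = 80

criterion = SetCriterion(num_anchors, strides, anchor_grids, num_classes)

# targets: one (image_idx, class, cx, cy, w, h) row per box, normalized to [0, 1]
targets = torch.tensor([[0.0000, 7.0000, 0.0714, 0.3749, 0.0760, 0.0654]])
# head_outputs: one tensor per level, shaped (N, num_anchors, H, W, 5 + num_classes)
head_outputs = [torch.rand(4, num_anchors, 640 // s, 640 // s, 5 + num_classes) for s in strides]

losses = criterion(targets, head_outputs)
print(sorted(losses))  # ['bbox_regression', 'cls_logits', 'objectness']
```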
diff --git a/test/test_models_utils.py b/test/test_models_utils.py
index d3514d3f..96292a7c 100644
--- a/test/test_models_utils.py
+++ b/test/test_models_utils.py
@@ -2,20 +2,6 @@
 import torch

 from yolort.models.transform import YOLOTransform, NestedTensor
-from yolort.models._utils import BalancedPositiveNegativeSampler
-
-
-def test_balanced_positive_negative_sampler():
-    sampler = BalancedPositiveNegativeSampler(4, 0.25)
-    # keep all 6 negatives first, then add 3 positives, last two are ignore
-    matched_idxs = [torch.tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, -1, -1])]
-    pos, neg = sampler(matched_idxs)
-    # we know the number of elements that should be sampled for the positive (1)
-    # and the negative (3), and their location. Let's make sure that they are there
-    assert pos[0].sum() == 1
-    assert pos[0][6:9].sum() == 1
-    assert neg[0].sum() == 3
-    assert neg[0][0:6].sum() == 3


 def test_yolo_transform():
diff --git a/yolort/models/_utils.py b/yolort/models/_utils.py
index e2cb9b22..2c5857a7 100644
--- a/yolort/models/_utils.py
+++ b/yolort/models/_utils.py
@@ -1,8 +1,7 @@
 import math

 import torch
-from torch import Tensor
-import torch.nn.functional as F
+from torch import nn, Tensor
 from torchvision.ops import box_convert, box_iou

 from typing import Tuple, List
@@ -18,228 +17,64 @@ def _evaluate_iou(target, pred):
     return box_iou(target["boxes"], pred["boxes"]).diag().mean()


-class BalancedPositiveNegativeSampler:
+def encode_single(reference_boxes: Tensor, anchors: Tensor) -> Tensor:
     """
-    This class samples batches, ensuring that they contain a fixed proportion of positives
-    """
-
-    def __init__(self, batch_size_per_image: int, positive_fraction: float) -> None:
-        """
-        Args:
-            batch_size_per_image (int): number of elements to be selected per image
-            positive_fraction (float): percentace of positive elements per batch
-        """
-        self.batch_size_per_image = batch_size_per_image
-        self.positive_fraction = positive_fraction
-
-    def __call__(self, matched_idxs: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
-        """
-        Args:
-            matched idxs: list of tensors containing -1, 0 or positive values.
-                Each tensor corresponds to a specific image.
-                -1 values are ignored, 0 are considered as negatives and > 0 as
-                positives.
-
-        Returns:
-            pos_idx (list[tensor])
-            neg_idx (list[tensor])
-
-        Returns two lists of binary masks for each image.
-        The first list contains the positive elements that were selected,
-        and the second list the negative example.
-        """
-        pos_idx = []
-        neg_idx = []
-        for matched_idxs_per_image in matched_idxs:
-            positive = torch.where(matched_idxs_per_image >= 1)[0]
-            negative = torch.where(matched_idxs_per_image == 0)[0]
-
-            num_pos = int(self.batch_size_per_image * self.positive_fraction)
-            # protect against not enough positive examples
-            num_pos = min(positive.numel(), num_pos)
-            num_neg = self.batch_size_per_image - num_pos
-            # protect against not enough negative examples
-            num_neg = min(negative.numel(), num_neg)
-
-            # randomly select positive and negative examples
-            perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos]
-            perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg]
-
-            pos_idx_per_image = positive[perm1]
-            neg_idx_per_image = negative[perm2]
+    Apply the YOLOv5 regression transform to a set of
+    predictions with respect to their matched anchors

-            # create binary mask from indices
-            pos_idx_per_image_mask = torch.zeros_like(
-                matched_idxs_per_image, dtype=torch.uint8
-            )
-            neg_idx_per_image_mask = torch.zeros_like(
-                matched_idxs_per_image, dtype=torch.uint8
-            )
-
-            pos_idx_per_image_mask[pos_idx_per_image] = 1
-            neg_idx_per_image_mask[neg_idx_per_image] = 1
-
-            pos_idx.append(pos_idx_per_image_mask)
-            neg_idx.append(neg_idx_per_image_mask)
-
-        return pos_idx, neg_idx
-
-
-class BoxCoder:
-    """
-    This class encodes and decodes a set of bounding boxes into
-    the representation used for training the regressors.
+    Args:
+        reference_boxes (Tensor): raw box predictions
+        anchors (Tensor): matched anchors, in grid units
     """

-    def decode_single(
-        self,
-        rel_codes: Tensor,
-        anchors_tuple: Tuple[Tensor, Tensor, Tensor],
-    ):
-        """
-        From a set of original boxes and encoded relative box offsets,
-        get the decoded boxes.
-
-        Arguments:
-            rel_codes (Tensor): encoded boxes
-            anchors_tupe (Tensor, Tensor, Tensor): reference boxes.
-        """
+    reference_boxes = torch.sigmoid(reference_boxes)

-        pred_wh = (rel_codes[..., 0:2] * 2. + anchors_tuple[0]) * anchors_tuple[1]  # wh
-        pred_xy = (rel_codes[..., 2:4] * 2) ** 2 * anchors_tuple[2]  # xy
-        pred_boxes = torch.cat([pred_wh, pred_xy], dim=1)
-        pred_boxes = box_convert(pred_boxes, in_fmt="cxcywh", out_fmt="xyxy")
+    pred_xy = reference_boxes[:, :2] * 2. - 0.5
+    pred_wh = (reference_boxes[:, 2:4] * 2) ** 2 * anchors
+    pred_boxes = torch.cat((pred_xy, pred_wh), 1)

-        return pred_boxes
+    return pred_boxes


-class Matcher:
+def decode_single(
+    rel_codes: Tensor,
+    anchors_tuple: Tuple[Tensor, Tensor, Tensor],
+) -> Tensor:
     """
-    This class assigns to each predicted "element" (e.g., a box) a ground-truth
-    element. Each predicted element will have exactly zero or one matches; each
-    ground-truth element may be assigned to zero or more predicted elements.
+    From a set of original boxes and encoded relative box offsets,
+    get the decoded boxes.

-    Matching is based on the MxN match_quality_matrix, that characterizes how well
-    each (ground-truth, predicted)-pair match. For example, if the elements are
-    boxes, the matrix may contain box IoU overlap values.
-
-    The matcher returns a tensor of size N containing the index of the ground-truth
-    element m that matches to prediction n. If there is no match, a negative value
-    is returned.
+    Args:
+        rel_codes (Tensor): encoded boxes
+        anchors_tuple (Tensor, Tensor, Tensor): reference boxes.
     """

-    def __init__(
-        self,
-        iou_threshold: float,
-        allow_low_quality_matches: bool = False,
-    ) -> None:
-        """
-        Args:
-            iou_threshold (float): quality values greater than or equal to
-                this value are candidate matches.
-            allow_low_quality_matches (bool): if True, produce additional matches
-                for predictions that have only low-quality match candidates. See
-                set_low_quality_matches_ for more details.
-        """
-        self.BELOW_LOW_THRESHOLD = -1
-        self.BETWEEN_THRESHOLDS = -2
-        self.iou_threshold = iou_threshold
-        self.allow_low_quality_matches = allow_low_quality_matches
-
-    def __call__(self, match_quality_matrix):
-        """
-        Args:
-            match_quality_matrix (Tensor[float]): an MxN tensor, containing the
-                pairwise quality between M ground-truth elements and N predicted elements.
-
-        Returns:
-            matches (Tensor[int64]): an N tensor where N[i] is a matched gt in
-                [0, M - 1] or a negative value indicating that prediction i could not
-                be matched.
-        """
-        if match_quality_matrix.numel() == 0:
-            # empty targets or proposals not supported during training
-            if match_quality_matrix.shape[0] == 0:
-                raise ValueError(
-                    "No ground-truth boxes available for one of the images "
-                    "during training")
-            else:
-                raise ValueError(
-                    "No proposal boxes available for one of the images "
-                    "during training")
-
-        # match_quality_matrix is M (gt) x N (predicted)
-        # Max over gt elements (dim 0) to find best gt candidate for each prediction
-        matched_vals, matches = match_quality_matrix.max(dim=0)
-        if self.allow_low_quality_matches:
-            all_matches = matches.clone()
-        else:
-            all_matches = None
-
-        # Assign candidate matches with low quality to negative (unassigned) values
-        below_low_threshold = matched_vals < self.iou_threshold
-        between_thresholds = (matched_vals >= self.iou_threshold) & (
-            matched_vals < self.iou_threshold
-        )
-        matches[below_low_threshold] = self.BELOW_LOW_THRESHOLD
-        matches[between_thresholds] = self.BETWEEN_THRESHOLDS
-
-        if self.allow_low_quality_matches:
-            assert all_matches is not None
-            self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)
-
-        # For each gt, find the prediction with which it has the highest quality
-        _, highest_quality_pred_foreach_gt = match_quality_matrix.max(dim=1)
-        matches[highest_quality_pred_foreach_gt] = torch.arange(highest_quality_pred_foreach_gt.size(0),
-                                                                dtype=torch.int64,
-                                                                device=highest_quality_pred_foreach_gt.device)
-        return matches
+    pred_xy = (rel_codes[..., 0:2] * 2. + anchors_tuple[0]) * anchors_tuple[1]  # xy
+    pred_wh = (rel_codes[..., 2:4] * 2) ** 2 * anchors_tuple[2]  # wh
+    pred_boxes = torch.cat([pred_xy, pred_wh], dim=1)
+    pred_boxes = box_convert(pred_boxes, in_fmt="cxcywh", out_fmt="xyxy")

-    def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix):
-        """
-        Produce additional matches for predictions that have only low-quality matches.
-        Specifically, for each ground-truth find the set of predictions that have
-        maximum overlap with it (including ties); for each prediction in that set, if
-        it is unmatched, then match it to the ground-truth with which it has the highest
-        quality value.
-        """
-        # For each gt, find the prediction with which it has highest quality
-        highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
-        # Find highest quality match available, even if it is low, including ties
-        gt_pred_pairs_of_highest_quality = torch.where(
-            match_quality_matrix == highest_quality_foreach_gt[:, None]
-        )
-        # Example gt_pred_pairs_of_highest_quality:
-        #   tensor([[    0, 39796],
-        #           [    1, 32055],
-        #           [    1, 32070],
-        #           [    2, 39190],
-        #           [    2, 40255],
-        #           [    3, 40390],
-        #           [    3, 41455],
-        #           [    4, 45470],
-        #           [    5, 45325],
-        #           [    5, 46390]])
-        # Each row is a (gt index, prediction index)
-        # Note how gt items 1, 2, 3, and 5 each have two ties
+    return pred_boxes

-        pred_inds_to_update = gt_pred_pairs_of_highest_quality[1]
-        matches[pred_inds_to_update] = all_matches[pred_inds_to_update]

-def bbox_ciou(box1, box2, eps: float = 1e-9):
-    # Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
+def bbox_iou(box1: Tensor, box2: Tensor, x1y1x2y2: bool = True, eps: float = 1e-7):
+    """
+    Returns the IoU of box1 to box2. box1 is 4, box2 is nx4
+    """
     box2 = box2.T

     # Get the coordinates of bounding boxes
-    # transform from xywh to xyxy
-    b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
-    b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
-    b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
-    b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2
+    if x1y1x2y2:  # x1, y1, x2, y2 = box1
+        b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3]
+        b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3]
+    else:  # transform from xywh to xyxy
+        b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2
+        b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2
+        b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2
+        b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2

     # Intersection area
-    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \
-        (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)
+    inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * (
+        torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0)

     # Union Area
     w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
@@ -250,23 +85,54 @@ def bbox_ciou(box1, box2, eps: float = 1e-9):
     cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1)  # convex (smallest enclosing box) width
     ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1)  # convex height
-
+
     # Complete IoU https://arxiv.org/abs/1911.08287v1
     c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
-    # center distance squared
-    rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4
+    rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 +
+            (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center distance squared
+    # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
     v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
     with torch.no_grad():
-        alpha = v / ((1 + eps) - iou + v)
-
+        alpha = v / (v - iou + (1 + eps))
     return iou - (rho2 / c2 + v * alpha)  # CIoU


-def cls_loss(inputs, targets, pos_weight):
-    loss = F.binary_cross_entropy_with_logits(inputs, targets, pos_weight=pos_weight)
-    return loss
-
-
-def obj_loss(inputs, targets, pos_weight):
-    loss = F.binary_cross_entropy_with_logits(inputs, targets, pos_weight=pos_weight)
-    return loss
+def smooth_binary_cross_entropy(eps: float = 0.1) -> Tuple[float, float]:
+    # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
+    # return positive, negative label smoothing binary cross entropy targets
+    return 1.0 - 0.5 * eps, 0.5 * eps
+
+
+class FocalLoss(nn.Module):
+    # Wraps focal loss around existing loss_fcn(),
+    # i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
+    def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
+        super().__init__()
+        # must be nn.BCEWithLogitsLoss()
+        self.loss_fcn = loss_fcn
+        self.gamma = gamma
+        self.alpha = alpha
+        self.reduction = loss_fcn.reduction
+        # required to apply FL to each element
+        self.loss_fcn.reduction = 'none'
+
+    def forward(self, pred, logit):
+        loss = self.loss_fcn(pred, logit)
+        # p_t = torch.exp(-loss)
+        # non-zero power for gradient stability
+        # loss *= self.alpha * (1.000001 - p_t) ** self.gamma
+
+        # TF implementation
+        # https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
+        pred_prob = torch.sigmoid(pred)  # prob from logits
+        p_t = logit * pred_prob + (1 - logit) * (1 - pred_prob)
+        alpha_factor = logit * self.alpha + (1 - logit) * (1 - self.alpha)
+        modulating_factor = (1.0 - p_t) ** self.gamma
+        loss *= alpha_factor * modulating_factor
+
+        if self.reduction == 'mean':
+            return loss.mean()
+        elif self.reduction == 'sum':
+            return loss.sum()
+        else:  # 'none'
+            return loss
diff --git a/yolort/models/box_head.py b/yolort/models/box_head.py
index 92cf2ced..57fad706 100644
--- a/yolort/models/box_head.py
+++ b/yolort/models/box_head.py
@@ -2,7 +2,7 @@
 import math
 import torch
 from torch import nn, Tensor
-
+import torch.nn.functional as F
 from torchvision.ops import boxes as box_ops

 from . import _utils as det_utils
@@ -73,20 +73,70 @@ def forward(self, x: List[Tensor]) -> List[Tensor]:
 class SetCriterion:
     """
     This class computes the loss for YOLOv5.
+
+    Args:
+        num_anchors (int): The number of anchors.
+        num_classes (int): The number of output classes of the model.
+        fl_gamma (float): focal loss gamma (EfficientDet default gamma=1.5). Default: 0.0.
+        box_gain (float): box loss gain. Default: 0.05.
+        cls_gain (float): class loss gain. Default: 0.5.
+        cls_pos (float): cls BCELoss positive_weight. Default: 1.0.
+        obj_gain (float): obj loss gain (scale with pixels). Default: 1.0.
+        obj_pos (float): obj BCELoss positive_weight. Default: 1.0.
+        anchor_thresh (float): anchor-multiple threshold. Default: 4.0.
+        label_smoothing (float): Label smoothing epsilon. Default: 0.0.
+        auto_balance (bool): Auto balance. Default: False.
     """
-    def __init__(self, iou_thresh: float = 0.5) -> None:
-        """
-        Args:
-            iou_thresh (float): minimum IoU between the anchor and the GT box so that they can be
-                considered as positive during training.
-        """
-        self.proposal_matcher = det_utils.Matcher(iou_threshold=iou_thresh)
+    def __init__(
+        self,
+        num_anchors: int,
+        strides: List[int],
+        anchor_grids: List[List[float]],
+        num_classes: int,
+        fl_gamma: float = 0.0,
+        box_gain: float = 0.05,
+        cls_gain: float = 0.5,
+        cls_pos: float = 1.0,
+        obj_gain: float = 1.0,
+        obj_pos: float = 1.0,
+        anchor_thresh: float = 4.0,
+        label_smoothing: float = 0.0,
+        auto_balance: bool = False,
+    ) -> None:
+        assert len(strides) == len(anchor_grids)
+
+        self.num_anchors = num_anchors
+        self.num_classes = num_classes
+        self.strides = strides
+        self.anchor_grids = anchor_grids
+
+        self.balance = [4.0, 1.0, 0.4]
+        self.ssi = 0  # stride 16 index
+
+        self.sort_obj_iou = False
+
+        # Define criteria
+        self.cls_pos = cls_pos
+        self.obj_pos = obj_pos
+
+        # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
+        # positive, negative BCE targets
+        smooth_bce = det_utils.smooth_binary_cross_entropy(eps=label_smoothing)
+        self.smooth_pos = smooth_bce[0]
+        self.smooth_neg = smooth_bce[1]
+
+        # Parameters for training
+        self.gr = 1.0
+        self.auto_balance = auto_balance
+        self.box_gain = box_gain
+        self.cls_gain = cls_gain
+        self.obj_gain = obj_gain
+        self.anchor_thresh = anchor_thresh

     def __call__(
         self,
         targets: Tensor,
         head_outputs: List[Tensor],
-        anchors_tuple: Tuple[Tensor, Tensor, Tensor],
     ) -> Dict[str, Tensor]:
         """
         This performs the loss computation.
@@ -96,38 +146,145 @@
                 expected keys in each dict depends on the losses applied, see each loss' doc
             head_outputs (List[Tensor]): dict of tensors, see the output specification of the model for the format
-            anchors_tuple (Tuple[Tensor, Tensor, Tensor]): Anchor tuple
         """
-        matched_idxs = []
+        device = targets.device
+        anchor_grids = torch.as_tensor(self.anchor_grids, dtype=torch.float32,
+                                       device=device).view(self.num_anchors, -1, 2)
+        strides = torch.as_tensor(self.strides, dtype=torch.float32,
+                                  device=device).view(-1, 1, 1)
+        anchor_grids /= strides

-        return self.compute_loss(targets, head_outputs, matched_idxs)
+        target_cls, target_box, indices, anchors = self.build_targets(targets, head_outputs, anchor_grids)

-    def compute_loss(
-        self,
-        targets: Tensor,
-        head_outputs: List[Tensor],
-        matched_idxs: List[Tensor],
-    ):
-        device = targets.device
+        pos_weight_cls = torch.as_tensor([self.cls_pos], device=device)
+        pos_weight_obj = torch.as_tensor([self.obj_pos], device=device)

         loss_cls = torch.zeros(1, device=device)
         loss_box = torch.zeros(1, device=device)
         loss_obj = torch.zeros(1, device=device)

+        # Computing the losses
+        for i, pred_logits in enumerate(head_outputs):  # layer index, layer predictions
+            b, a, gj, gi = indices[i]  # image, anchor, gridy, gridx
+            target_obj = torch.zeros_like(pred_logits[..., 0], device=device)  # target obj
+
+            num_targets = b.shape[0]  # number of targets
+            if num_targets > 0:
+                # prediction subset corresponding to targets
+                pred_logits_subset = pred_logits[b, a, gj, gi]
+
+                # Regression
+                pred_box = det_utils.encode_single(pred_logits_subset, anchors[i])
+                iou = det_utils.bbox_iou(pred_box.T, target_box[i], x1y1x2y2=False)
+                loss_box += (1.0 - iou).mean()  # iou loss
+
+                # Objectness
+                score_iou = iou.detach().clamp(0).to(dtype=target_obj.dtype)
+                if self.sort_obj_iou:
+                    sort_id = torch.argsort(score_iou)
+                    b, a, gj, gi = b[sort_id], a[sort_id], gj[sort_id], gi[sort_id]
+                    score_iou = score_iou[sort_id]
+                target_obj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * score_iou  # iou ratio
+
+                # Classification
+                if self.num_classes > 1:  # cls loss (only if multiple classes)
+                    t = torch.full_like(pred_logits_subset[:, 5:], self.smooth_neg, device=device)  # targets
+                    t[torch.arange(num_targets), target_cls[i]] = self.smooth_pos
+                    loss_cls += F.binary_cross_entropy_with_logits(
+                        pred_logits_subset[:, 5:], t, pos_weight=pos_weight_cls)
+
+            obji = F.binary_cross_entropy_with_logits(
+                pred_logits[..., 4], target_obj, pos_weight=pos_weight_obj)
+            loss_obj += obji * self.balance[i]  # obj loss
+            if self.auto_balance:
+                self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item()
+
+        if self.auto_balance:
+            self.balance = [x / self.balance[self.ssi] for x in self.balance]
+        loss_box *= self.box_gain
+        loss_obj *= self.obj_gain
+        loss_cls *= self.cls_gain
+
         return {
             'cls_logits': loss_cls,
             'bbox_regression': loss_box,
             'objectness': loss_obj,
         }

+    def build_targets(
+        self,
+        targets: Tensor,
+        head_outputs: List[Tensor],
+        anchor_grids: Tensor,
+    ) -> Tuple[List[Tensor], List[Tensor], List[Tuple[Tensor, Tensor, Tensor, Tensor]], List[Tensor]]:
+        device = targets.device
+        num_anchors = self.num_anchors
+
+        num_targets = targets.shape[0]
+
+        gain = torch.ones(7, device=device)  # normalized to gridspace gain
+        # same as .repeat_interleave(num_targets)
+        ai = torch.arange(num_anchors, device=device).float().view(num_anchors, 1).repeat(1, num_targets)
+        # append anchor indices
+        targets = torch.cat((targets.repeat(num_anchors, 1, 1), ai[:, :, None]), 2)
+
+        g_bias = 0.5
+        offset = torch.tensor([[0, 0],
+                               [1, 0], [0, 1], [-1, 0], [0, -1],  # j,k,l,m
+                               # [1, 1], [1, -1], [-1, 1], [-1, -1],  # jk,jm,lk,lm
+                               ], device=device).float() * g_bias  # offsets
+
+        target_cls, target_box, anch = [], [], []
+        indices: List[Tuple[Tensor, Tensor, Tensor, Tensor]] = []
+
+        for i in range(num_anchors):
+            anchors = anchor_grids[i]
+            gain[2:6] = torch.tensor(head_outputs[i].shape)[[3, 2, 3, 2]]  # xyxy gain
+
+            # Match targets to anchors
+            targets_with_gain = targets * gain
+            if num_targets > 0:
+                # Matches
+                r = targets_with_gain[:, :, 4:6] / anchors[:, None]  # wh ratio
+                j = torch.max(r, 1. / r).max(2)[0] < self.anchor_thresh  # compare
+                # j = wh_iou(anchors, targets_with_gain[:, 4:6]) > model.hyp['iou_t']
+                # iou(3, n) = wh_iou(anchors(3, 2), gwh(n, 2))
+                targets_with_gain = targets_with_gain[j]  # filter
+
+                # Offsets
+                gxy = targets_with_gain[:, 2:4]  # grid xy
+                gxi = gain[[2, 3]] - gxy  # inverse
+                idx_jk = ((gxy % 1. < g_bias) & (gxy > 1.)).T
+                idx_lm = ((gxi % 1. < g_bias) & (gxi > 1.)).T
+                j = torch.stack((torch.ones_like(idx_jk[0]), idx_jk[0], idx_jk[1], idx_lm[0], idx_lm[1]))
+                targets_with_gain = targets_with_gain.repeat((5, 1, 1))[j]
+                offsets = (torch.zeros_like(gxy)[None] + offset[:, None])[j]
+            else:
+                targets_with_gain = targets[0]
+                offsets = torch.tensor(0, device=device)
+
+            # Define
+            idx_bc = targets_with_gain[:, :2].long().T  # image, class
+            gxy = targets_with_gain[:, 2:4]  # grid xy
+            gwh = targets_with_gain[:, 4:6]  # grid wh
+            gij = (gxy - offsets).long()
+            idx_gij = gij.T  # grid xy indices
+
+            # Append
+            a = targets_with_gain[:, 6].long()  # anchor indices
+            # image, anchor, grid indices
+            indices.append((idx_bc[0], a, idx_gij[1].clamp_(0, gain[3] - 1), idx_gij[0].clamp_(0, gain[2] - 1)))
+            target_box.append(torch.cat((gxy - gij, gwh), 1))  # box
+            anch.append(anchors[a])  # anchors
+            target_cls.append(idx_bc[1])  # class
+
+        return target_cls, target_box, indices, anch
+

 class PostProcess(nn.Module):
     """
     Performs Non-Maximum Suppression (NMS) on inference results
     """
-    __annotations__ = {
-        'box_coder': det_utils.BoxCoder,
-    }
     def __init__(
         self,
         score_thresh: float,
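The heart of the new criterion is `build_targets`: every ground truth is replicated once per anchor, kept only where the box/anchor width-height ratio passes `anchor_thresh`, and then assigned to its own grid cell plus up to two neighboring cells, chosen by which half of the cell the box center falls into (`g_bias = 0.5`). A tiny self-contained demo of that neighbor selection, repeating the same arithmetic outside yolort:

```python
import torch

g_bias = 0.5
gain = torch.tensor([80.0, 80.0])    # grid W, H at stride 8 for a 640x640 input
gxy = torch.tensor([[13.3, 42.8]])   # one target center, in grid units
gxi = gain - gxy                     # the same center, measured from the far corner

idx_jk = ((gxy % 1. < g_bias) & (gxy > 1.)).T  # near the left/top cell border?
idx_lm = ((gxi % 1. < g_bias) & (gxi > 1.)).T  # near the right/bottom cell border?

print(idx_jk.flatten().tolist())  # [True, False]: frac(13.3) = 0.3 < 0.5, cell x=12 also matches
print(idx_lm.flatten().tolist())  # [False, True]: frac(42.8) = 0.8 > 0.5, cell y=43 also matches
```

Stacked with the always-true `[0, 0]` row of `offset`, this yields up to three cells per target per level, each paired with the anchor index that `build_targets` carries in column 6 of the augmented target rows.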
@@ -141,10 +298,9 @@
             detections_per_img (int): Number of best detections to keep after NMS.
         """
         super().__init__()
-        self.box_coder = det_utils.BoxCoder()
         self.score_thresh = score_thresh
         self.nms_thresh = nms_thresh
-        self.detections_per_img = detections_per_img  # maximum number of detections per image
+        self.detections_per_img = detections_per_img

     def forward(
         self,
@@ -180,7 +336,7 @@

         # box_conf x class_conf, w/ shape: num_anchors x num_classes
         scores = pred_logits[:, 5:] * pred_logits[:, 4:5]
-        boxes = self.box_coder.decode_single(pred_logits[:, :4], anchors_tuple)
+        boxes = det_utils.decode_single(pred_logits[:, :4], anchors_tuple)

         # remove low scoring boxes
         inds, labels = torch.where(scores > self.score_thresh)
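`PostProcess` itself changes very little: the `box_coder` attribute and its `__annotations__` entry disappear, and decoding now goes through the free `det_utils.decode_single`. The scoring and filtering around it, visible as context in the last hunk, are untouched. In isolation that step looks roughly like this, with dummy, already-sigmoided logits:

```python
import torch

score_thresh = 0.005
pred_logits = torch.rand(8400, 85)  # per-anchor rows: 4 box, 1 objectness, 80 classes

# box_conf x class_conf, w/ shape: num_anchors x num_classes
scores = pred_logits[:, 5:] * pred_logits[:, 4:5]

# remove low scoring boxes, keeping (anchor, class) index pairs
inds, labels = torch.where(scores > score_thresh)
scores = scores[inds, labels]  # assumed continuation; the hunk ends at torch.where
```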
diff --git a/yolort/models/yolo.py b/yolort/models/yolo.py
index e5410042..76817f57 100644
--- a/yolort/models/yolo.py
+++ b/yolort/models/yolo.py
@@ -57,7 +57,6 @@
         anchor_generator: Optional[nn.Module] = None,
         head: Optional[nn.Module] = None,
         # Training parameter
-        iou_thresh: float = 0.5,
         criterion: Optional[Callable[..., Dict[str, Tensor]]] = None,
         # Post Process parameter
         score_thresh: float = 0.005,
@@ -87,16 +86,13 @@
         self.anchor_generator = anchor_generator

         if criterion is None:
-            criterion = SetCriterion(iou_thresh)
+            criterion = SetCriterion(anchor_generator.num_anchors, anchor_generator.strides,
+                                     anchor_generator.anchor_grids, num_classes)
         self.compute_loss = criterion

         if head is None:
-            head = YOLOHead(
-                backbone.out_channels,
-                anchor_generator.num_anchors,
-                anchor_generator.strides,
-                num_classes,
-            )
+            head = YOLOHead(backbone.out_channels, anchor_generator.num_anchors,
+                            anchor_generator.strides, num_classes)
         self.head = head

         if post_process is None:
@@ -123,7 +119,7 @@
         targets: Optional[Tensor] = None,
     ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]:
         """
-        Arguments:
+        Args:
             samples (NestedTensor): Expects a NestedTensor, which consists of:
                 - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
             targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional)
@@ -148,7 +144,7 @@
         if self.training:
             assert targets is not None
             # compute the losses
-            losses = self.compute_loss(targets, head_outputs, anchors_tuple)
+            losses = self.compute_loss(targets, head_outputs)
         else:
             # compute the detections
             detections = self.post_process(head_outputs, anchors_tuple)
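Finally, `yolo.py` derives the default criterion from the anchor generator's own `num_anchors`, `strides`, and `anchor_grids`, so the loss no longer needs `anchors_tuple` at call time, while the post-processing branch still consumes it. A sketch of that default wiring; the `AnchorGenerator` import path and constructor signature are assumptions inferred from how `YOLO.__init__` reads its attributes:

```python
from yolort.models.anchor_utils import AnchorGenerator  # path assumed
from yolort.models.box_head import SetCriterion

# Same stock YOLOv5 configuration as in the first sketch above
strides = [8, 16, 32]
anchor_grids = [[10, 13, 16, 30, 33, 23],
                [30, 61, 62, 45, 59, 119],
                [116, 90, 156, 198, 373, 326]]

anchor_generator = AnchorGenerator(strides, anchor_grids)  # signature assumed
criterion = SetCriterion(anchor_generator.num_anchors, anchor_generator.strides,
                         anchor_generator.anchor_grids, num_classes=80)
```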