-
Notifications
You must be signed in to change notification settings - Fork 9.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
YOLACT #3456
YOLACT #3456
Changes from 1 commit
9dfd81f
7599d41
2bb0915
98952e0
1a25228
335f0a5
9ef690b
bdbd584
371ceb9
9c3b569
1dc7f82
b0719de
b21ff78
7be568d
eceff7a
9ec3ae9
a1ba224
001968e
8cc50bb
19e5440
dc0ddda
11a9d01
750eff2
dca0902
2604f87
b15b24a
af5f7c2
300aca9
9deb91a
779adf0
bcc6c08
35ec920
dea1329
8331a8c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -64,11 +64,11 @@ def fast_nms(multi_bboxes, | |
score_thr, | ||
nms_cfg, | ||
max_num=-1): | ||
"""Fast NMS in YOLACT. | ||
"""Fast NMS in `YOLACT <https://arxiv.org/abs/1904.02689>`_. | ||
hellock marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Args: | ||
multi_bboxes (Tensor): shape (n, #class*4) or (n, 4) | ||
multi_scores (Tensor): shape (n, #class), where the 0th column | ||
multi_scores (Tensor): shape (n, #class), where the last column | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should the shape be (n, #class+1)? |
||
contains scores of the background class, but this will be ignored. | ||
multi_coeffs (Tensor): shape (n, #class*coeffs_dim). | ||
score_thr (float): bbox threshold, bboxes with scores lower than it | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,10 +14,21 @@ | |
|
||
@HEADS.register_module() | ||
class YolactHead(AnchorHead): | ||
"""An anchor-based head used in YOLACT [1]_. | ||
|
||
References: | ||
.. [1] https://arxiv.org/pdf/1904.02689.pdf | ||
"""YOLACT box head used in https://arxiv.org/abs/1904.02689. | ||
|
||
Note that YOLACT head is a light version of RetinaNet head. | ||
Four differences are described as follows: | ||
hellock marked this conversation as resolved.
Show resolved
Hide resolved
|
||
1. YOLACT box head has three-times fewer anchors. | ||
2. YOLACT box head shares the convs for box and cls branches. | ||
3. YOLACT box head uses OHEM instead of Focal loss. | ||
4. YOLACT box head predicts a set of mask coefficients for each box. | ||
|
||
Args: | ||
num_head_convs (int): Number of the conv layers shared by | ||
box and cls branches. | ||
num_protos (int): Number of the mask coefficients. | ||
use_OHEM (bool): If true, `loss_single_OHEM` will be used for | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. use_ohem |
||
cls loss calculation. If false, `loss_single` will be used. | ||
""" | ||
hellock marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def __init__(self, | ||
|
@@ -111,6 +122,12 @@ def loss(self, | |
gt_labels, | ||
img_metas, | ||
gt_bboxes_ignore=None): | ||
"""A combination of the func:`AnchorHead.loss` and func:`SSDHead.loss`. | ||
|
||
When self.use_OHEM == True, it functions like `SSDHead.loss`, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The rst syntax uses double backticks (``) for inline literals: When ``self.use_OHEM == True``, it functions like ``SSDHead.loss`` |
||
otherwise, it follows `AnchorHead.loss`. Besides, it additionally | ||
returns `sampling_results`. | ||
""" | ||
featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] | ||
assert len(featmap_sizes) == self.anchor_generator.num_levels | ||
|
||
|
@@ -200,32 +217,10 @@ def loss(self, | |
return dict( | ||
loss_cls=losses_cls, loss_bbox=losses_bbox), sampling_results | ||
|
||
def loss_single(self, cls_score, bbox_pred, anchors, labels, label_weights, | ||
bbox_targets, bbox_weights, num_total_samples): | ||
# classification loss | ||
labels = labels.reshape(-1) | ||
label_weights = label_weights.reshape(-1) | ||
cls_score = cls_score.permute(0, 2, 3, | ||
1).reshape(-1, self.cls_out_channels) | ||
loss_cls = self.loss_cls( | ||
cls_score, labels, label_weights, avg_factor=num_total_samples) | ||
# regression loss | ||
bbox_targets = bbox_targets.reshape(-1, 4) | ||
bbox_weights = bbox_weights.reshape(-1, 4) | ||
bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4) | ||
if self.reg_decoded_bbox: | ||
anchors = anchors.reshape(-1, 4) | ||
bbox_pred = self.bbox_coder.decode(anchors, bbox_pred) | ||
loss_bbox = self.loss_bbox( | ||
bbox_pred, | ||
bbox_targets, | ||
bbox_weights, | ||
avg_factor=num_total_samples) | ||
return loss_cls, loss_bbox | ||
|
||
def loss_single_OHEM(self, cls_score, bbox_pred, anchors, labels, | ||
label_weights, bbox_targets, bbox_weights, | ||
num_total_samples): | ||
""""See func:`SSDHead.loss`.""" | ||
loss_cls_all = F.cross_entropy( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Losses are hard-coded instead of using |
||
cls_score, labels, reduction='none') * label_weights | ||
# FG cat_id: [0, num_classes -1], BG cat_id: num_classes | ||
|
@@ -260,7 +255,8 @@ def get_bboxes(self, | |
img_metas, | ||
cfg=None, | ||
rescale=False): | ||
|
||
""""Similiar to func:`AnchorHead.get_bboxes`, but additionally | ||
processes coeff_preds.""" | ||
assert len(cls_scores) == len(bbox_preds) | ||
num_levels = len(cls_scores) | ||
|
||
|
@@ -298,7 +294,9 @@ def _get_bboxes_single(self, | |
scale_factor, | ||
cfg, | ||
rescale=False): | ||
|
||
""""Similiar to func:`AnchorHead._get_bboxes_single`, but additionally | ||
processes coeff_preds_list and uses fast NMS instead of traditional | ||
NMS.""" | ||
cfg = self.test_cfg if cfg is None else cfg | ||
assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors) | ||
mlvl_bboxes = [] | ||
|
@@ -357,6 +355,12 @@ def _get_bboxes_single(self, | |
|
||
@HEADS.register_module() | ||
class YolactSegmHead(nn.Module): | ||
"""YOLACT segmentation head used in https://arxiv.org/abs/1904.02689. | ||
|
||
Apply a semantic segmentation loss on feature space using layers that are | ||
only evaluated during training to increase performance with no speed | ||
penalty. | ||
""" | ||
|
||
def __init__(self, | ||
in_channels=256, | ||
|
@@ -414,9 +418,9 @@ def loss(self, segm_pred, gt_masks, gt_labels): | |
|
||
@HEADS.register_module() | ||
class YolactProtonet(nn.Module): | ||
"""Protonet of Yolact. | ||
"""YOLACT mask head used in https://arxiv.org/abs/1904.02689. | ||
|
||
This head outputs the prototypes for Yolact. | ||
This head outputs the mask prototypes for YOLACT. | ||
""" | ||
|
||
def __init__( | ||
|
@@ -440,10 +444,7 @@ def __init__( | |
|
||
def make_net(self, in_channels, config, include_last_relu=True): | ||
"""A helper function to take a config setting and turn it into a | ||
network. | ||
|
||
Used by protonet and extrahead. Returns (network) | ||
""" | ||
network.""" | ||
|
||
def make_layer(layer_cfg): | ||
nonlocal in_channels | ||
|
@@ -520,6 +521,7 @@ def forward(self, x, coeff_pred, bboxes, img_meta, sampling_results=None): | |
pos_inds = cur_sampling_results.pos_inds | ||
cur_coeff_pred = cur_coeff_pred[pos_inds] | ||
|
||
# Linearly combining the prototypes with the mask coefficients | ||
mask_pred = cur_prototypes @ cur_coeff_pred.t() | ||
mask_pred = torch.sigmoid(mask_pred) | ||
|
||
|
@@ -550,6 +552,9 @@ def loss(self, mask_pred, gt_masks, gt_bboxes, img_meta, sampling_results): | |
num_pos = pos_assigned_gt_inds.size(0) | ||
if num_pos * cur_gt_masks.size(0) == 0: | ||
continue | ||
# Since we're producing (near) full image masks, | ||
# it'd take too much vram to backprop on every single mask. | ||
# Thus we select only a subset. | ||
if num_pos > self.max_masks_to_train: | ||
perm = torch.randperm(num_pos) | ||
select = perm[:self.max_masks_to_train] | ||
|
@@ -607,13 +612,13 @@ def get_seg_masks(self, mask_pred, label_pred, img_meta, rescale): | |
return cls_segms | ||
|
||
def crop(self, masks, boxes, padding=1): | ||
""""Crop" predicted masks by zeroing out everything not in the | ||
predicted bbox. | ||
"""Crop predicted masks by zeroing out everything not in the predicted | ||
bbox. | ||
|
||
Args: | ||
masks should be a size [h, w, n] tensor of masks | ||
masks should be a size [h, w, n] tensor of masks. | ||
boxes should be a size [n, 4] tensor of bbox coords in | ||
relative point form | ||
relative point form. | ||
""" | ||
h, w, n = masks.size() | ||
x1, x2 = self.sanitize_coordinates( | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,7 @@ | |
|
||
@DETECTORS.register_module() | ||
class Yolact(SingleStageDetector): | ||
yhcao6 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Implementation of `YOLACT <https://arxiv.org/abs/1904.02689>`_""" | ||
|
||
def __init__(self, | ||
backbone, | ||
|
@@ -26,12 +27,14 @@ def __init__(self, | |
self.min_gt_box_wh = train_cfg.min_gt_box_wh | ||
|
||
def init_segm_mask_weights(self): | ||
"""Initialize weights of the YOLACT semg head and YOLACT mask head.""" | ||
self.segm_head.init_weights() | ||
self.mask_head.init_weights() | ||
|
||
def process_gt_single(self, gt_bboxes, gt_labels, gt_masks, min_gt_box_wh, | ||
device): | ||
# Cuda the gt_masks and discard boxes that are smaller than we'd like | ||
"""Cuda the gt_masks and discard boxes that are smaller than we'd | ||
like.""" | ||
gt_masks = torch.from_numpy(gt_masks.masks).to(device) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
w = gt_bboxes[:, 2] - gt_bboxes[:, 0] | ||
h = gt_bboxes[:, 3] - gt_bboxes[:, 1] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
MMDet v2.0 rewrites the coordinate system, so we don't need `+1` anymore.