Implemented Chong's vectorized cropping and refactored a bunch of the postprocessing code to accommodate it better. As a side effect, not only is the model 3+ ms faster, but it's also more accurate! (I now crop at the original prototype resolution instead of the upsampled version, which saves time and improves accuracy.)

dbolya committed Mar 19, 2019
1 parent 6cdf562 commit 89fabd0
Showing 5 changed files with 65 additions and 53 deletions.
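
For context, a minimal sketch of the reordering the commit message describes: masks are now cropped at prototype resolution with the new vectorized crop() and only then upsampled. All sizes below are illustrative assumptions, not values taken from the diff.

```python
import torch
import torch.nn.functional as F
from layers.box_utils import crop  # added to box_utils.py in this commit

# Illustrative sizes only; the real values come from the config.
proto_h, proto_w, n = 138, 138, 10       # prototype resolution, number of detections
img_h, img_w = 550, 550                  # full image resolution

masks = torch.rand(proto_h, proto_w, n)  # [h, w, n] prototype-space masks
boxes, _ = torch.rand(n, 4).sort(dim=1)  # fake relative boxes with x1 < x2, y1 < y2

# New order: crop cheaply at prototype resolution, then upsample once.
masks = crop(masks, boxes)                                  # [h, w, n]
masks = masks.permute(2, 0, 1).contiguous().unsqueeze(0)    # [1, n, h, w]
masks = F.interpolate(masks, (img_h, img_w), mode='bilinear',
                      align_corners=False).squeeze(0)       # [n, img_h, img_w]
```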
2 changes: 1 addition & 1 deletion eval.py
@@ -4,7 +4,7 @@
from utils.functions import MovingAverage, ProgressBar
from layers.box_utils import jaccard, center_size
from utils import timer
from utils.functions import sanitize_coordinates, SavePath
from utils.functions import SavePath
from layers.output_utils import postprocess, undo_image_transformation
import pycocotools

49 changes: 49 additions & 0 deletions layers/box_utils.py
@@ -338,3 +338,52 @@ def nms(boxes, scores, overlap=0.5, top_k=200, force_cpu=True):

    timer.stop()
    return keep, count


def sanitize_coordinates(_x1, _x2, img_size, cast=True):
    """
    Sanitizes the input coordinates so that x1 < x2, x1 != x2, x1 >= 0, and x2 <= image_size.
    Also converts from relative to absolute coordinates and casts the results to long tensors.
    If cast is false, the result won't be cast to longs.
    Warning: this does things in-place behind the scenes so copy if necessary.
    """
    _x1 *= img_size
    _x2 *= img_size
    if cast:
        _x1 = _x1.long()
        _x2 = _x2.long()
    x1 = torch.min(_x1, _x2)
    x2 = torch.max(_x1, _x2)
    x1 = torch.clamp(x1-1, min=0)
    x2 = torch.clamp(x2+1, max=img_size)

    return x1, x2
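
Not part of the diff: a minimal usage sketch of the function above, with made-up coordinates, showing the relative-to-absolute conversion, reordering, 1 px padding, and clamping.

```python
import torch
from layers.box_utils import sanitize_coordinates

x1 = torch.tensor([0.30, 0.90])   # second pair is reversed on purpose (x1 > x2)
x2 = torch.tensor([0.60, 0.10])
a, b = sanitize_coordinates(x1, x2, img_size=138, cast=True)
# a -> tensor([40, 12]), b -> tensor([83, 125])
# Note: x1 and x2 themselves are scaled in place, per the docstring's warning.
```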

def crop(masks, boxes):
    """
    "Crop" predicted masks by zeroing out everything not in the predicted bbox.
    Vectorized by Chong (thanks Chong).
    Args:
        - masks should be a size [h, w, n] tensor of masks
        - boxes should be a size [n, 4] tensor of bbox coords in relative point form
    """
    with torch.no_grad():
        h, w, n = masks.size()
        boxes = boxes.clone() # Some in-place stuff goes on here
        x1, x2 = sanitize_coordinates(boxes[:, 0], boxes[:, 2], w, cast=True)
        y1, y2 = sanitize_coordinates(boxes[:, 1], boxes[:, 3], h, cast=True)

        rows = torch.arange(w, device=masks.device)[None, :, None].expand(h, w, n)
        cols = torch.arange(h, device=masks.device)[:, None, None].expand(h, w, n)

        masks_left = rows >= x1[None, None, :]
        masks_right = rows < x2[None, None, :]
        masks_up = cols >= y1[None, None, :]
        masks_down = cols < y2[None, None, :]

        crop_mask = masks_left * masks_right * masks_up * masks_down

    return masks * crop_mask.float()
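
Not part of the diff: a quick sanity check of the broadcasting trick above. The [h, w, n] row/column index grids are compared against each box's sanitized bounds, and the product of the four boolean masks reproduces what the old per-detection loop computed. Sizes are arbitrary.

```python
import torch
from layers.box_utils import crop, sanitize_coordinates

h, w, n = 17, 23, 5
masks = torch.rand(h, w, n)
boxes, _ = torch.rand(n, 4).sort(dim=1)   # relative point form with x1 < x2, y1 < y2

vectorized = crop(masks, boxes)           # crop() clones boxes internally

# Reference: the loop-based crop this commit replaces.
x1, x2 = sanitize_coordinates(boxes[:, 0].clone(), boxes[:, 2].clone(), w)
y1, y2 = sanitize_coordinates(boxes[:, 1].clone(), boxes[:, 3].clone(), h)
reference = torch.zeros_like(masks)
for j in range(n):
    reference[y1[j]:y2[j], x1[j]:x2[j], j] = masks[y1[j]:y2[j], x1[j]:x2[j], j]

assert torch.allclose(vectorized, reference)
```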

22 changes: 4 additions & 18 deletions layers/modules/multibox_loss.py
@@ -3,8 +3,7 @@
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from ..box_utils import match, log_sum_exp, decode, center_size
from utils.functions import sanitize_coordinates
from ..box_utils import match, log_sum_exp, decode, center_size, crop

from data import cfg, mask_type, activation_func

@@ -520,20 +519,7 @@ def lincomb_mask_loss(self, pos, idx_t, loc_data, mask_data, priors, proto_data,
    loss_m += cfg.mask_proto_double_loss_alpha * pre_loss

if cfg.mask_proto_crop:
    # Shortening the variable name here
    expnd = cfg.mask_proto_crop_expand

    # Take care of all the bad behavior that can be caused by out of bounds coordinates
    # Note I tried to put expand here but the loss exploded for some reason
    x1, x2 = sanitize_coordinates(pos_gt_box_t[:, 0], pos_gt_box_t[:, 2], mask_w)
    y1, y2 = sanitize_coordinates(pos_gt_box_t[:, 1], pos_gt_box_t[:, 3], mask_h)

    # "Crop" predicted masks by zeroing out everything not in the predicted bbox
    # TODO: Write a cuda implementation of this to get rid of the loop
    crop_mask = torch.zeros(mask_h, mask_w, num_pos)
    for jdx in range(num_pos):
        crop_mask[y1[jdx]*(1-expnd):y2[jdx]*(1+expnd), x1[jdx]*(1-expnd):x2[jdx]*(1+expnd), jdx] = 1
    pred_masks = pred_masks * crop_mask
    pred_masks = crop(pred_masks, pos_gt_box_t)

if cfg.mask_proto_mask_activation == activation_func.sigmoid:
    pre_loss = F.binary_cross_entropy(pred_masks, mask_t, reduction='none')
@@ -550,8 +536,8 @@ def lincomb_mask_loss(self, pos, idx_t, loc_data, mask_data, priors, proto_data,
if cfg.mask_proto_normalize_emulate_roi_pooling:
    weight = mask_h * mask_w if cfg.mask_proto_crop else 1
    pos_get_csize = center_size(pos_gt_box_t)
    gt_box_width = pos_get_csize[:, 2]
    gt_box_height = pos_get_csize[:, 3]
    gt_box_width = pos_get_csize[:, 2] * mask_w
    gt_box_height = pos_get_csize[:, 3] * mask_h
    pre_loss = pre_loss.sum(dim=(0, 1)) / gt_box_width / gt_box_height * weight
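
Not part of the diff: a note on the two changed lines above. The old sanitize_coordinates call scaled pos_gt_box_t to prototype pixels in place, so center_size() used to return absolute widths and heights; the new crop() clones its boxes, so the ground-truth boxes now stay relative and the normalization has to rescale them explicitly. A minimal sketch of that normalization with made-up shapes (an illustration of the intent, not code from the repo):

```python
import torch
from layers.box_utils import center_size

mask_h, mask_w, num_pos = 138, 138, 3
pre_loss = torch.rand(mask_h, mask_w, num_pos)          # per-pixel mask loss
pos_gt_box_t, _ = torch.rand(num_pos, 4).sort(dim=1)    # relative point-form gt boxes

csize = center_size(pos_gt_box_t)                       # (cx, cy, w, h), still relative
gt_w = csize[:, 2] * mask_w                             # width in prototype pixels
gt_h = csize[:, 3] * mask_h                             # height in prototype pixels
weight = mask_h * mask_w                                # the cfg.mask_proto_crop case
per_instance = pre_loss.sum(dim=(0, 1)) / gt_w / gt_h * weight   # shape [num_pos]
```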


26 changes: 11 additions & 15 deletions layers/output_utils.py
@@ -9,8 +9,8 @@

from data import cfg, mask_type, MEANS, STD, activation_func
from utils.augmentations import Resize
from utils.functions import sanitize_coordinates
from utils import timer
from .box_utils import crop, sanitize_coordinates

def postprocess(det_output, w, h, batch_idx=0, interpolation_mode='bilinear',
                visualize_lincomb=False, crop_masks=True):
@@ -69,9 +69,6 @@ def postprocess(det_output, w, h, batch_idx=0, interpolation_mode='bilinear',
# Actually extract everything from dets now
classes = dets[:, 0].int()
boxes = dets[:, 2:6]
x1, x2 = sanitize_coordinates(boxes[:, 0], boxes[:, 2], b_w, cast=True)
y1, y2 = sanitize_coordinates(boxes[:, 1], boxes[:, 3], b_h, cast=True)
boxes = torch.stack((x1, y1, x2, y2), dim=1)
scores = dets[:, 1]
masks = dets[:, 6:]

@@ -87,26 +84,21 @@ def postprocess(det_output, w, h, batch_idx=0, interpolation_mode='bilinear',
    display_lincomb(proto_data, masks)

masks = torch.matmul(proto_data, masks.t())

masks = cfg.mask_proto_mask_activation(masks)

# Crop masks before upsampling because you know why
if crop_masks:
    masks = crop(masks, boxes)

# Permute into the correct output shape [num_dets, proto_h, proto_w]
masks = masks.permute(2, 0, 1).contiguous()
masks = cfg.mask_proto_mask_activation(masks)

# Scale masks up to the full image
if cfg.preserve_aspect_ratio:
    # Undo padding
    masks = masks[:, :int(r_h/cfg.max_size*proto_data.size(1)), :int(r_w/cfg.max_size*proto_data.size(2))]
masks = F.interpolate(masks.unsqueeze(0), (h, w), mode=interpolation_mode, align_corners=False).squeeze(0)

# "Crop" predicted masks by zeroing out everything not in the predicted bbox
# TODO: Write a cuda implementation of this to get rid of the loop
if crop_masks:
    num_dets = boxes.size(0)
    crop_mask = torch.zeros(num_dets, h, w, device=masks.device)
    for jdx in range(num_dets):
        crop_mask[jdx, y1[jdx]:y2[jdx], x1[jdx]:x2[jdx]] = 1
    masks = masks * crop_mask

# Binarize the masks
masks = masks.gt(0.5).float()

@@ -131,6 +123,10 @@ def postprocess(det_output, w, h, batch_idx=0, interpolation_mode='bilinear',
        full_masks[jdx, y1:y2, x1:x2] = mask

    masks = full_masks

boxes[:, 0], boxes[:, 2] = sanitize_coordinates(boxes[:, 0], boxes[:, 2], b_w, cast=False)
boxes[:, 1], boxes[:, 3] = sanitize_coordinates(boxes[:, 1], boxes[:, 3], b_h, cast=False)
boxes = boxes.long()

return classes, scores, boxes, masks

19 changes: 0 additions & 19 deletions utils/functions.py
@@ -80,25 +80,6 @@ def __repr__(self):
    def __str__(self):
        return self.string

def sanitize_coordinates(_x1, _x2, img_size, cast=True):
    """
    Sanitizes the input coordinates so that x1 < x2, x1 != x2, x1 >= 0, and x2 <= image_size.
    Also converts from relative to absolute coordinates and casts the results to long tensors.
    If cast is false, the result won't be cast to longs.
    """
    _x1 *= img_size
    _x2 *= img_size
    if cast:
        _x1 = _x1.long()
        _x2 = _x2.long()
    x1 = torch.min(_x1, _x2)
    x2 = torch.max(_x1, _x2)
    x1 = torch.clamp(x1-1, min=0)
    x2 = torch.clamp(x2+1, max=img_size)

    return x1, x2


def init_console():
    """
