Skip to content

Commit

Permalink
Add box_iou_rotated, ml_nms_rotated and nms_rotated (#625)
Browse files Browse the repository at this point in the history
* add box_iou_rotated, ml_nms_rotated and nms_rotated

* fix lint

* fix lint

* fix .py lint

* fix cpp lint

* add newline at the end

* add new line

* fix unittest

* config google style

* fix lint

* lint

* lint

* yapf

* update

* fix lint

* fix lint

* fix lint

* fix

* fix format

* fix format

* add modified from

* add docstring and update others

* update docstring

* update docstring

* update

* fix bug

* fix bug

* fix bug

Co-authored-by: Cao Yuhang <yhcao6@gmail.com>
  • Loading branch information
magicdream2222 and yhcao6 authored Nov 25, 2020
1 parent f61bb64 commit a978764
Show file tree
Hide file tree
Showing 20 changed files with 1,123 additions and 3 deletions.
5 changes: 3 additions & 2 deletions mmcv/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .bbox import bbox_overlaps
from .box_iou_rotated import box_iou_rotated
from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
from .cc_attention import CrissCrossAttention
from .corner_pool import CornerPool
Expand All @@ -16,7 +17,7 @@
from .modulated_deform_conv import (ModulatedDeformConv2d,
ModulatedDeformConv2dPack,
modulated_deform_conv2d)
from .nms import batched_nms, nms, nms_match, soft_nms
from .nms import batched_nms, nms, nms_match, nms_rotated, soft_nms
from .point_sample import (SimpleRoIAlign, point_sample,
rel_roi_point_to_rel_img_point)
from .psa_mask import PSAMask
Expand All @@ -38,5 +39,5 @@
'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d',
'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
'SAConv2d', 'TINShift', 'tin_shift'
'SAConv2d', 'TINShift', 'tin_shift', 'box_iou_rotated', 'nms_rotated'
]
30 changes: 30 additions & 0 deletions mmcv/ops/box_iou_rotated.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import torch

from ..utils import ext_loader

ext_module = ext_loader.load_ext('_ext', ['box_iou_rotated'])


def box_iou_rotated(bboxes1, bboxes2):
    """Compute the pairwise IoU (Jaccard index) of two sets of rotated boxes.

    Every box is described as
    (x_center, y_center, width, height, angle).

    Args:
        bboxes1 (Tensor): first set of rotated boxes, shape (N, 5); each
            row is (x, y, w, h, theta).
        bboxes2 (Tensor): second set of rotated boxes, shape (M, 5); each
            row is (x, y, w, h, theta).

    Returns:
        Tensor: matrix of shape (N, M) holding the IoU of every box in
            ``bboxes1`` against every box in ``bboxes2``.
    """
    if torch.__version__ != 'parrots':
        # Regular PyTorch build: the extension op returns the IoU tensor.
        return ext_module.box_iou_rotated(bboxes1, bboxes2)

    # parrots build: the op writes into an output tensor supplied by the
    # caller, so create it here on the same device as the inputs.
    ious = torch.zeros((bboxes1.shape[0], bboxes2.shape[0]),
                       dtype=torch.float32)
    ious = ious.to(bboxes1.device)
    ext_module.box_iou_rotated(bboxes1, bboxes2, ious)
    return ious
69 changes: 69 additions & 0 deletions mmcv/ops/csrc/box_iou_rotated_cuda.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
// modified from
// https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cuda.cu
#ifndef BOX_IOU_ROTATED_CUDA_CUH
#define BOX_IOU_ROTATED_CUDA_CUH

#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
#include "box_iou_rotated_utils.hpp"

// Intended thread-block shape: a 2D block with 32 * 16 = 512 threads per
// block. These constants also size the shared-memory box tiles declared in
// box_iou_rotated_cuda_kernel, whose tile copy indexes both tiles with
// threadIdx.x and therefore relies on BLOCK_DIM_X >= BLOCK_DIM_Y.
const int BLOCK_DIM_X = 32;
const int BLOCK_DIM_Y = 16;

// Integer division of x by y rounded up (for non-negative x and positive y),
// e.g. the number of size-y chunks needed to cover x items.
inline int divideUP(const int x, const int y) {
  const int padded = x + y - 1;
  return padded / y;
}

// Computes one tile of the pairwise IoU matrix between two sets of rotated
// boxes.
//
// Each box occupies 5 consecutive values in dev_boxes1 / dev_boxes2. Block
// (blockIdx.x, blockIdx.y) covers rows starting at blockIdx.x * blockDim.x
// and columns starting at blockIdx.y * blockDim.y; thread (x, y) writes the
// IoU of box1[row_start + x] and box2[col_start + y] to the row-major
// n_boxes1 x n_boxes2 output dev_ious.
//
// The shared tiles are sized with BLOCK_DIM_X / BLOCK_DIM_Y, so the kernel
// must be launched with blockDim.x <= BLOCK_DIM_X and
// blockDim.y <= BLOCK_DIM_Y.
//
// NOTE(review): the shared tiles are declared as float regardless of T and
// are passed to single_box_iou_rotated<T>; this presumably restricts T to
// float -- confirm against box_iou_rotated_utils.hpp and the launch site.
template <typename T>
__global__ void box_iou_rotated_cuda_kernel(const int n_boxes1,
                                            const int n_boxes2,
                                            const T* dev_boxes1,
                                            const T* dev_boxes2, T* dev_ious) {
  const int row_start = blockIdx.x * blockDim.x;
  const int col_start = blockIdx.y * blockDim.y;

  // The last tile along each axis may be partial; clamp to the boxes left.
  const int row_size = min(n_boxes1 - row_start, blockDim.x);
  const int col_size = min(n_boxes2 - col_start, blockDim.y);

  __shared__ float block_boxes1[BLOCK_DIM_X * 5];
  __shared__ float block_boxes2[BLOCK_DIM_Y * 5];

  // Only the threads with threadIdx.y == 0 stage the boxes of this tile into
  // shared memory. It's safe to copy both tiles using threadIdx.x since
  // BLOCK_DIM_X >= BLOCK_DIM_Y.
  if (threadIdx.y == 0) {
    if (threadIdx.x < row_size) {
      const T* src = dev_boxes1 + (row_start + threadIdx.x) * 5;
      float* dst = block_boxes1 + threadIdx.x * 5;
      for (int k = 0; k < 5; ++k) {
        dst[k] = src[k];
      }
    }
    if (threadIdx.x < col_size) {
      const T* src = dev_boxes2 + (col_start + threadIdx.x) * 5;
      float* dst = block_boxes2 + threadIdx.x * 5;
      for (int k = 0; k < 5; ++k) {
        dst[k] = src[k];
      }
    }
  }
  // All threads of the block must see the staged boxes before reading them.
  __syncthreads();

  if (threadIdx.x < row_size && threadIdx.y < col_size) {
    // Compute the flat output index in size_t: n_boxes1 * n_boxes2 can exceed
    // the 32-bit range for large inputs, which would wrap a 32-bit offset.
    const size_t offset =
        static_cast<size_t>(row_start + threadIdx.x) * n_boxes2 + col_start +
        threadIdx.y;
    dev_ious[offset] = single_box_iou_rotated<T>(
        block_boxes1 + threadIdx.x * 5, block_boxes2 + threadIdx.y * 5);
  }
}

#endif
Loading

0 comments on commit a978764

Please sign in to comment.