-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add box_iou_rotated, ml_nms_rotated and nms_rotated (#625)
* add box_iou_rotated, ml_nms_rotated and nms_rotated * fix lint * fix lint * fix .py lint * fix cpp lint * add newline at the end * add new line * fix unittest * config google style * fix lint * lint * lint * yapf * update * fix lint * fix lint * fix lint * fix * fix format * fix format * add modified from * add docstring and update others * update docstring * update docstring * update * fix bug * fix bug * fix bug Co-authored-by: Cao Yuhang <yhcao6@gmail.com>
- Loading branch information
1 parent
f61bb64
commit a978764
Showing
20 changed files
with
1,123 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import torch | ||
|
||
from ..utils import ext_loader | ||
|
||
ext_module = ext_loader.load_ext('_ext', ['box_iou_rotated']) | ||
|
||
|
||
def box_iou_rotated(bboxes1, bboxes2): | ||
"""Return intersection-over-union (Jaccard index) of boxes. | ||
Both sets of boxes are expected to be in | ||
(x_center, y_center, width, height, angle) format. | ||
Arguments: | ||
boxes1 (Tensor): rotated bboxes 1. \ | ||
It has shape (N, 5), indicating (x, y, w, h, theta) for each row. | ||
boxes2 (Tensor): rotated bboxes 2. \ | ||
It has shape (N, 5), indicating (x, y, w, h, theta) for each row. | ||
Returns: | ||
iou (Tensor[N, M]): the NxM matrix containing the pairwise | ||
IoU values for every element in boxes1 and boxes2 | ||
""" | ||
if torch.__version__ == 'parrots': | ||
out = torch.zeros((bboxes1.shape[0], bboxes2.shape[0]), | ||
dtype=torch.float32).to(bboxes1.device) | ||
ext_module.box_iou_rotated(bboxes1, bboxes2, out) | ||
else: | ||
out = ext_module.box_iou_rotated(bboxes1, bboxes2) | ||
return out |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | ||
// modified from | ||
// https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cuda.cu | ||
#ifndef BOX_IOU_ROTATED_CUDA_CUH | ||
#define BOX_IOU_ROTATED_CUDA_CUH | ||
|
||
#ifdef MMCV_USE_PARROTS | ||
#include "parrots_cuda_helper.hpp" | ||
#else | ||
#include "pytorch_cuda_helper.hpp" | ||
#endif | ||
#include "box_iou_rotated_utils.hpp" | ||
|
||
// 2D block with 32 * 16 = 512 threads per block | ||
const int BLOCK_DIM_X = 32; | ||
const int BLOCK_DIM_Y = 16; | ||
|
||
inline int divideUP(const int x, const int y) { return (((x) + (y)-1) / (y)); } | ||
|
||
template <typename T> | ||
__global__ void box_iou_rotated_cuda_kernel(const int n_boxes1, | ||
const int n_boxes2, | ||
const T* dev_boxes1, | ||
const T* dev_boxes2, T* dev_ious) { | ||
const int row_start = blockIdx.x * blockDim.x; | ||
const int col_start = blockIdx.y * blockDim.y; | ||
|
||
const int row_size = min(n_boxes1 - row_start, blockDim.x); | ||
const int col_size = min(n_boxes2 - col_start, blockDim.y); | ||
|
||
__shared__ float block_boxes1[BLOCK_DIM_X * 5]; | ||
__shared__ float block_boxes2[BLOCK_DIM_Y * 5]; | ||
|
||
// It's safe to copy using threadIdx.x since BLOCK_DIM_X >= BLOCK_DIM_Y | ||
if (threadIdx.x < row_size && threadIdx.y == 0) { | ||
block_boxes1[threadIdx.x * 5 + 0] = | ||
dev_boxes1[(row_start + threadIdx.x) * 5 + 0]; | ||
block_boxes1[threadIdx.x * 5 + 1] = | ||
dev_boxes1[(row_start + threadIdx.x) * 5 + 1]; | ||
block_boxes1[threadIdx.x * 5 + 2] = | ||
dev_boxes1[(row_start + threadIdx.x) * 5 + 2]; | ||
block_boxes1[threadIdx.x * 5 + 3] = | ||
dev_boxes1[(row_start + threadIdx.x) * 5 + 3]; | ||
block_boxes1[threadIdx.x * 5 + 4] = | ||
dev_boxes1[(row_start + threadIdx.x) * 5 + 4]; | ||
} | ||
|
||
if (threadIdx.x < col_size && threadIdx.y == 0) { | ||
block_boxes2[threadIdx.x * 5 + 0] = | ||
dev_boxes2[(col_start + threadIdx.x) * 5 + 0]; | ||
block_boxes2[threadIdx.x * 5 + 1] = | ||
dev_boxes2[(col_start + threadIdx.x) * 5 + 1]; | ||
block_boxes2[threadIdx.x * 5 + 2] = | ||
dev_boxes2[(col_start + threadIdx.x) * 5 + 2]; | ||
block_boxes2[threadIdx.x * 5 + 3] = | ||
dev_boxes2[(col_start + threadIdx.x) * 5 + 3]; | ||
block_boxes2[threadIdx.x * 5 + 4] = | ||
dev_boxes2[(col_start + threadIdx.x) * 5 + 4]; | ||
} | ||
__syncthreads(); | ||
|
||
if (threadIdx.x < row_size && threadIdx.y < col_size) { | ||
int offset = (row_start + threadIdx.x) * n_boxes2 + col_start + threadIdx.y; | ||
dev_ious[offset] = single_box_iou_rotated<T>( | ||
block_boxes1 + threadIdx.x * 5, block_boxes2 + threadIdx.y * 5); | ||
} | ||
} | ||
|
||
#endif |
Oops, something went wrong.