Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add box_iou_rotated, ml_nms_rotated and nms_rotated #625

Merged
merged 28 commits into from
Nov 25, 2020
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion mmcv/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .bbox import bbox_overlaps
from .box_iou_rotated import box_iou_rotated
from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
from .cc_attention import CrissCrossAttention
from .corner_pool import CornerPool
Expand All @@ -17,6 +18,7 @@
ModulatedDeformConv2dPack,
modulated_deform_conv2d)
from .nms import batched_nms, nms, nms_match, soft_nms
from .nms_rotated import ml_nms_rotated, nms_rotated
from .point_sample import (SimpleRoIAlign, point_sample,
rel_roi_point_to_rel_img_point)
from .psa_mask import PSAMask
Expand All @@ -38,5 +40,6 @@
'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d',
'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
'SAConv2d', 'TINShift', 'tin_shift'
'SAConv2d', 'TINShift', 'tin_shift', 'box_iou_rotated', 'ml_nms_rotated',
magicdream2222 marked this conversation as resolved.
Show resolved Hide resolved
'nms_rotated'
]
15 changes: 15 additions & 0 deletions mmcv/ops/box_iou_rotated.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import torch

from ..utils import ext_loader

ext_module = ext_loader.load_ext('_ext', ['box_iou_rotated'])


def box_iou_rotated(bboxes1, bboxes2):
magicdream2222 marked this conversation as resolved.
Show resolved Hide resolved
if torch.__version__ == 'parrots':
out = torch.zeros((bboxes1.shape[0], bboxes2.shape[0]),
dtype=torch.float32).to(bboxes1.device)
ext_module.box_iou_rotated(bboxes1, bboxes2, out)
else:
out = ext_module.box_iou_rotated(bboxes1, bboxes2)
return out
67 changes: 67 additions & 0 deletions mmcv/ops/csrc/box_iou_rotated_cuda.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
#ifndef BOX_IOU_ROTATED_CUDA_CUH
#define BOX_IOU_ROTATED_CUDA_CUH

#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif
#include "box_iou_rotated_utils.hpp"

// 2D block with 32 * 16 = 512 threads per block
const int BLOCK_DIM_X = 32;
const int BLOCK_DIM_Y = 16;

inline int divideUP(const int x, const int y) { return (((x) + (y)-1) / (y)); }

template <typename T>
__global__ void box_iou_rotated_cuda_kernel(const int n_boxes1,
const int n_boxes2,
const T* dev_boxes1,
const T* dev_boxes2, T* dev_ious) {
const int row_start = blockIdx.x * blockDim.x;
const int col_start = blockIdx.y * blockDim.y;

const int row_size = min(n_boxes1 - row_start, blockDim.x);
const int col_size = min(n_boxes2 - col_start, blockDim.y);

__shared__ float block_boxes1[BLOCK_DIM_X * 5];
__shared__ float block_boxes2[BLOCK_DIM_Y * 5];

// It's safe to copy using threadIdx.x since BLOCK_DIM_X >= BLOCK_DIM_Y
if (threadIdx.x < row_size && threadIdx.y == 0) {
block_boxes1[threadIdx.x * 5 + 0] =
dev_boxes1[(row_start + threadIdx.x) * 5 + 0];
block_boxes1[threadIdx.x * 5 + 1] =
dev_boxes1[(row_start + threadIdx.x) * 5 + 1];
block_boxes1[threadIdx.x * 5 + 2] =
dev_boxes1[(row_start + threadIdx.x) * 5 + 2];
block_boxes1[threadIdx.x * 5 + 3] =
dev_boxes1[(row_start + threadIdx.x) * 5 + 3];
block_boxes1[threadIdx.x * 5 + 4] =
dev_boxes1[(row_start + threadIdx.x) * 5 + 4];
}

if (threadIdx.x < col_size && threadIdx.y == 0) {
block_boxes2[threadIdx.x * 5 + 0] =
dev_boxes2[(col_start + threadIdx.x) * 5 + 0];
block_boxes2[threadIdx.x * 5 + 1] =
dev_boxes2[(col_start + threadIdx.x) * 5 + 1];
block_boxes2[threadIdx.x * 5 + 2] =
dev_boxes2[(col_start + threadIdx.x) * 5 + 2];
block_boxes2[threadIdx.x * 5 + 3] =
dev_boxes2[(col_start + threadIdx.x) * 5 + 3];
block_boxes2[threadIdx.x * 5 + 4] =
dev_boxes2[(col_start + threadIdx.x) * 5 + 4];
}
__syncthreads();

if (threadIdx.x < row_size && threadIdx.y < col_size) {
int offset = (row_start + threadIdx.x) * n_boxes2 + col_start + threadIdx.y;
dev_ious[offset] = single_box_iou_rotated<T>(
block_boxes1 + threadIdx.x * 5, block_boxes2 + threadIdx.y * 5);
}
}

#endif
Loading