Skip to content

Commit

Permalink
Merge 0e71a74 into 8cac7c2
Browse files Browse the repository at this point in the history
  • Loading branch information
DCNSW authored Sep 29, 2021
2 parents 8cac7c2 + 0e71a74 commit a6a7156
Show file tree
Hide file tree
Showing 12 changed files with 686 additions and 1 deletion.
2 changes: 2 additions & 0 deletions docs/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ We implement common CUDA ops used in detection, segmentation, etc.
- CornerPool
- Deformable Convolution v1/v2
- Deformable RoIPool
- FurthestPointSample
- FurthestPointSampleWithDist
- GeneralizedAttention
- MaskedConv
- NMS
Expand Down
2 changes: 2 additions & 0 deletions docs_zh_CN/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
- CornerPool
- Deformable Convolution v1/v2
- Deformable RoIPool
- FurthestPointSample
- FurthestPointSampleWithDist
- GeneralizedAttention
- MaskedConv
- NMS
Expand Down
6 changes: 5 additions & 1 deletion mmcv/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from .deprecated_wrappers import MaxPool2d_deprecated as MaxPool2d
from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss,
sigmoid_focal_loss, softmax_focal_loss)
from .furthest_point_sample import (furthest_point_sample,
furthest_point_sample_with_dist)
from .fused_bias_leakyrelu import FusedBiasLeakyReLU, fused_bias_leakyrelu
from .info import (get_compiler_version, get_compiling_cuda_version,
get_onnxruntime_op_path)
Expand All @@ -29,6 +31,7 @@
from .pixel_group import pixel_group
from .point_sample import (SimpleRoIAlign, point_sample,
rel_roi_point_to_rel_img_point)
from .points_sampler import PointsSampler
from .psa_mask import PSAMask
from .roi_align import RoIAlign, roi_align
from .roi_align_rotated import RoIAlignRotated, roi_align_rotated
Expand All @@ -55,5 +58,6 @@
'ball_query', 'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'contour_expand',
'MultiScaleDeformableAttention', 'BorderAlign', 'border_align',
'Correlation'
'furthest_point_sample', 'furthest_point_sample_with_dist',
'PointsSampler', 'Correlation'
]
152 changes: 152 additions & 0 deletions mmcv/ops/csrc/common/cuda/furthest_point_sample_cuda_kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
// Copyright (c) OpenMMLab. All rights reserved
#ifndef FURTHEST_POINT_SAMPLE_CUDA_KERNEL_CUH
#define FURTHEST_POINT_SAMPLE_CUDA_KERNEL_CUH

#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif

#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))

__device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i,
int idx1, int idx2) {
const float v1 = dists[idx1], v2 = dists[idx2];
const int i1 = dists_i[idx1], i2 = dists_i[idx2];
dists[idx1] = max(v1, v2);
dists_i[idx1] = v2 > v1 ? i2 : i1;
}

template <unsigned int block_size>
__global__ void furthest_point_sampling_forward_cuda_kernel(
int b, int n, int m, const float *__restrict__ dataset,
float *__restrict__ temp, int *__restrict__ idxs) {
// dataset: (B, N, 3)
// tmp: (B, N)
// output:
// idx: (B, M)

if (m <= 0) return;
__shared__ float dists[block_size];
__shared__ int dists_i[block_size];

int batch_index = blockIdx.x;
dataset += batch_index * n * 3;
temp += batch_index * n;
idxs += batch_index * m;

int tid = threadIdx.x;
const int stride = block_size;

int old = 0;
if (threadIdx.x == 0) idxs[0] = old;

__syncthreads();
for (int j = 1; j < m; j++) {
int besti = 0;
float best = -1;
float x1 = dataset[old * 3 + 0];
float y1 = dataset[old * 3 + 1];
float z1 = dataset[old * 3 + 2];
for (int k = tid; k < n; k += stride) {
float x2, y2, z2;
x2 = dataset[k * 3 + 0];
y2 = dataset[k * 3 + 1];
z2 = dataset[k * 3 + 2];
// float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
// if (mag <= 1e-3)
// continue;

float d =
(x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);
float d2 = min(d, temp[k]);
temp[k] = d2;
besti = d2 > best ? k : besti;
best = d2 > best ? d2 : best;
}
dists[tid] = best;
dists_i[tid] = besti;
__syncthreads();

for (int block_size_thres = 1024; block_size_thres >= 2;
block_size_thres /= 2) {
int tid_thres = block_size_thres / 2;
if (block_size >= block_size_thres) {
__update(dists, dists_i, tid, tid + tid_thres);
}
__syncthreads();
}

old = dists_i[0];
if (tid == 0) idxs[j] = old;
}
}

// Modified from
// https://github.com/qiqihaer/3DSSD-pytorch/blob/master/lib/pointnet2/src/sampling_gpu.cu
template <unsigned int block_size>
__global__ void furthest_point_sampling_with_dist_forward_cuda_kernel(
int b, int n, int m, const float *__restrict__ dataset,
float *__restrict__ temp, int *__restrict__ idxs) {
// dataset: (B, N, N)
// tmp: (B, N)
// output:
// idx: (B, M)

if (m <= 0) return;
__shared__ float dists[block_size];
__shared__ int dists_i[block_size];

int batch_index = blockIdx.x;
dataset += batch_index * n * n;
temp += batch_index * n;
idxs += batch_index * m;

int tid = threadIdx.x;
const int stride = block_size;

int old = 0;
if (threadIdx.x == 0) idxs[0] = old;

__syncthreads();
for (int j = 1; j < m; j++) {
int besti = 0;
float best = -1;
// float x1 = dataset[old * 3 + 0];
// float y1 = dataset[old * 3 + 1];
// float z1 = dataset[old * 3 + 2];
for (int k = tid; k < n; k += stride) {
// float x2, y2, z2;
// x2 = dataset[k * 3 + 0];
// y2 = dataset[k * 3 + 1];
// z2 = dataset[k * 3 + 2];

// float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) *
// (z2 - z1);
float d = dataset[old * n + k];

float d2 = min(d, temp[k]);
temp[k] = d2;
besti = d2 > best ? k : besti;
best = d2 > best ? d2 : best;
}
dists[tid] = best;
dists_i[tid] = besti;
__syncthreads();

for (int block_size_thres = 1024; block_size_thres >= 2;
block_size_thres /= 2) {
int tid_thres = block_size_thres / 2;
if (block_size >= block_size_thres) {
__update(dists, dists_i, tid, tid + tid_thres);
}
__syncthreads();
}

old = dists_i[0];
if (tid == 0) idxs[j] = old;
}
}

#endif // FURTHEST_POINT_SAMPLE_CUDA_KERNEL_CUH
143 changes: 143 additions & 0 deletions mmcv/ops/csrc/pytorch/cuda/furthest_point_sample_cuda.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
// Modified from
// https://github.com/sshaoshuai/Pointnet2.PyTorch/tree/master/pointnet2/src/sampling_gpu.cu

#include <stdio.h>
#include <stdlib.h>

#include "furthest_point_sample_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"

inline int opt_n_threads(int work_size) {
const int pow_2 = std::log(static_cast<double>(work_size)) / std::log(2.0);

return max(min(1 << pow_2, 1024), 1);
}

void FurthestPointSamplingForwardCUDAKernelLauncher(int b, int n, int m,
const float *dataset,
float *temp, int *idxs) {
// dataset: (B, N, 3)
// tmp: (B, N)
// output:
// idx: (B, M)

cudaStream_t stream = at::cuda::getCurrentCUDAStream();

unsigned int n_threads = opt_n_threads(n);

switch (n_threads) {
case 1024:
furthest_point_sampling_forward_cuda_kernel<1024>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
case 512:
furthest_point_sampling_forward_cuda_kernel<512>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
case 256:
furthest_point_sampling_forward_cuda_kernel<256>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
case 128:
furthest_point_sampling_forward_cuda_kernel<128>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
case 64:
furthest_point_sampling_forward_cuda_kernel<64>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
case 32:
furthest_point_sampling_forward_cuda_kernel<32>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
case 16:
furthest_point_sampling_forward_cuda_kernel<16>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
case 8:
furthest_point_sampling_forward_cuda_kernel<8>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
case 4:
furthest_point_sampling_forward_cuda_kernel<4>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
case 2:
furthest_point_sampling_forward_cuda_kernel<2>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
case 1:
furthest_point_sampling_forward_cuda_kernel<1>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
default:
furthest_point_sampling_forward_cuda_kernel<512>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
}

AT_CUDA_CHECK(cudaGetLastError());
}

void FurthestPointSamplingWithDistForwardCUDAKernelLauncher(
int b, int n, int m, const float *dataset, float *temp, int *idxs) {
// dataset: (B, N, N)
// temp: (B, N)
// output:
// idx: (B, M)

cudaStream_t stream = at::cuda::getCurrentCUDAStream();

unsigned int n_threads = opt_n_threads(n);

switch (n_threads) {
case 1024:
furthest_point_sampling_with_dist_forward_cuda_kernel<1024>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
case 512:
furthest_point_sampling_with_dist_forward_cuda_kernel<512>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
case 256:
furthest_point_sampling_with_dist_forward_cuda_kernel<256>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
case 128:
furthest_point_sampling_with_dist_forward_cuda_kernel<128>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
case 64:
furthest_point_sampling_with_dist_forward_cuda_kernel<64>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
case 32:
furthest_point_sampling_with_dist_forward_cuda_kernel<32>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
case 16:
furthest_point_sampling_with_dist_forward_cuda_kernel<16>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
case 8:
furthest_point_sampling_with_dist_forward_cuda_kernel<8>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
case 4:
furthest_point_sampling_with_dist_forward_cuda_kernel<4>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
case 2:
furthest_point_sampling_with_dist_forward_cuda_kernel<2>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
case 1:
furthest_point_sampling_with_dist_forward_cuda_kernel<1>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
break;
default:
furthest_point_sampling_with_dist_forward_cuda_kernel<512>
<<<b, n_threads, 0, stream>>>(b, n, m, dataset, temp, idxs);
}

AT_CUDA_CHECK(cudaGetLastError());
}
Loading

0 comments on commit a6a7156

Please sign in to comment.