diff --git a/mmcv/ops/csrc/common/mlu/iou3d_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/iou3d_mlu_kernel.mlu deleted file mode 100644 index 84e53aa1f3..0000000000 --- a/mmcv/ops/csrc/common/mlu/iou3d_mlu_kernel.mlu +++ /dev/null @@ -1,431 +0,0 @@ -/************************************************************************* - * Copyright (C) 2022 Cambricon. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS - * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - *************************************************************************/ - -#include "common_mlu_helper.hpp" -#include "iou3d_utils.hpp" - -#define SIZE_SRAM_BUF (MAX_SRAM_SIZE) - -/* NRAM buffer - * Suppose we deal with N boxes at a time. ----------------------------------------------------------------- -| Basic |score (1N)+ |intersect_pts(48N)| | -| |valid_box(1N) |+ ordered_pts(48N)| temp_long(72N) | -| |+ temp_buffer(10N)| | | -|--------------------------|------------------|----------------| -| Reuse | null | null |rotated_pts(16N)| -|-------|------------------|------------------|----------------| - ---------------------------------------------------------------------------- -| Basic | dist_ram(24N) | valid_pts(24N) |box1(5N) |box1_buffer(5KB) | -| | |+ nums_in_ram(1N)|+ box2(5N)|+nram_save(5KB) | -|--------------------------|-----------------|----------|-----------------| -| Reuse | vec_buffer(5N) | null | null | null | -|-------|------------------|-----------------|----------|-----------------| -Total Basic Memory Size = 239N * sizeof(float) + 10KB -*/ - -__nram__ char nram_buffer[MAX_NRAM_SIZE]; -__mlu_shared__ char sram_buffer[SIZE_SRAM_BUF]; - -template <typename T> -__mlu_func__ void iou3D_detection(int32_t &result_box_num, int32_t *output_data, - const T *boxes_data, float *scores_data, - const int core_limit, const int input_box_num, - const float iou_threshold, - mluMemcpyDirection_t scores_load_dir, - mluMemcpyDirection_t scores_store_dir, - mluMemcpyDirection_t boxes_load_dir) { - // NRAM divide by (2+4*COMPUTE_COUNT_ALIGN) copies of NRAM, counted by bytes - const int nram_save_limit_count = 256; - int box_read_limit_count = 256; - float div_thresh_iou = 1.0 / iou_threshold; - // every box requires 239 * sizeof(float) space in nram; - const int32_t copies_of_nram = 239 * sizeof(float); - const int32_t limit = (MAX_NRAM_SIZE - 5 * box_read_limit_count * sizeof(T) - - nram_save_limit_count * sizeof(int32_t)) / - copies_of_nram; - - // x,y,z,dx,dy,dz,angle - const T *input_x_ptr = boxes_data; - const T *input_y_ptr = input_x_ptr + input_box_num; - const T *input_dx_ptr = input_y_ptr + 2 * input_box_num; - const T *input_dy_ptr = input_dx_ptr + input_box_num; - const T *input_angle_ptr = input_dy_ptr + 2 * input_box_num; - float *input_score_ptr = scores_data; - - // data split - int avg_cluster = 0; - int rem_cluster = 0; - int len_cluster = 0; - int cluster_offset = 0; - if (clusterDim > 0) { - // union - avg_cluster = input_box_num / clusterDim; - rem_cluster = input_box_num % clusterDim; - len_cluster = avg_cluster + (clusterId < rem_cluster ? 1 : 0); - cluster_offset = avg_cluster * clusterId + - (clusterId <= rem_cluster ?
clusterId : rem_cluster); - } else { - // block - len_cluster = input_box_num; - cluster_offset = 0; - } - int len_core = input_box_num; - int input_offset = 0; - if (core_limit > 1) { - int avg_core = len_cluster / coreDim; - int rem_core = len_cluster % coreDim; - len_core = avg_core + (coreId < rem_core ? 1 : 0); - int core_offset = - avg_core * coreId + (coreId <= rem_core ? coreId : rem_core); - input_offset = cluster_offset + core_offset; - } - - int32_t max_seg_pad = IOU3D_DOWN(limit, IOU3D_SIZE); - int repeat_iou_compute = len_core / max_seg_pad; - int remain_iou_compute = len_core % max_seg_pad; - - // basic consistent memory layout - void *score = ((char *)nram_buffer); - void *valid_box = ((char *)score) + 1 * max_seg_pad * sizeof(float); - void *temp_buffer = ((char *)valid_box) + 1 * max_seg_pad * sizeof(float); - void *intersect_pts_x = - ((char *)temp_buffer) + 10 * max_seg_pad * sizeof(float); - void *intersect_pts_y = - ((char *)intersect_pts_x) + 24 * max_seg_pad * sizeof(float); - void *ordered_pts_x = - ((char *)intersect_pts_y) + 24 * max_seg_pad * sizeof(float); - void *ordered_pts_y = - ((char *)ordered_pts_x) + 24 * max_seg_pad * sizeof(float); - void *temp_long_1 = - ((char *)ordered_pts_y) + 24 * max_seg_pad * sizeof(float); - void *temp_long_2 = ((char *)temp_long_1) + 24 * max_seg_pad * sizeof(float); - void *temp_long_3 = ((char *)temp_long_2) + 24 * max_seg_pad * sizeof(float); - void *dist_ram = ((char *)temp_long_3) + 24 * max_seg_pad * sizeof(float); - void *valid_pts = ((char *)dist_ram) + 24 * max_seg_pad * sizeof(float); - void *nums_in_ram = ((char *)valid_pts) + 24 * max_seg_pad * sizeof(float); - T *box1 = (T *)(((char *)nums_in_ram) + 1 * max_seg_pad * sizeof(float)); - T *box2 = (T *)(((char *)box1) + 5 * max_seg_pad * sizeof(float)); - void *box1_buffer = ((char *)box2) + 5 * max_seg_pad * sizeof(float); - int32_t *nram_save = - (int32_t *)(((char *)box1_buffer) + 5 * box_read_limit_count * sizeof(T)); - // nram_save ~ nram_save_limit_count * sizeof(int32_t) - int nram_save_count = 0; - - // reuse memory - void *rotated_pts1_x = ((char *)dist_ram); - void *rotated_pts1_y = - ((char *)rotated_pts1_x) + 4 * max_seg_pad * sizeof(float); - void *rotated_pts2_x = - ((char *)rotated_pts1_y) + 4 * max_seg_pad * sizeof(float); - void *rotated_pts2_y = - ((char *)rotated_pts2_x) + 4 * max_seg_pad * sizeof(float); - void *vec_buffer = ((char *)temp_long_1) + 5 * max_seg_pad * sizeof(float); - // vec_buffer ~ 16 * max_seg_pad * sizeof(float) - - // First, initialize ram with all 0, otherwise it could cause unexpected nan/inf results - __bang_write_zero((unsigned char *)nram_buffer, copies_of_nram * max_seg_pad); - // the numbers 8 and 0xff rely on box_read_limit_count being initialized to 256 - const int max_box_seg_id = (input_box_num - 1) >> 8; - const int last_rem_box_number = ((input_box_num - 1) & 0xff) + 1; - for (int32_t cur_box = 0; cur_box < input_box_num; ++cur_box) { - __sync_all(); - int box_seg_id = cur_box >> 8, box_id = cur_box & 0xff; - box_read_limit_count = box_seg_id == max_box_seg_id ?
last_rem_box_number - : box_read_limit_count; - if (box_id == 0) { - // x,y,z,dx,dy,dz,angle - int offset_num = box_seg_id << 8; - // x - __memcpy((char *)box1_buffer, input_x_ptr + offset_num, - box_read_limit_count * 1 * sizeof(T), boxes_load_dir, - box_read_limit_count * 1 * sizeof(T), - box_read_limit_count * 1 * sizeof(T), 0); - // y - __memcpy((char *)box1_buffer + box_read_limit_count * 1 * sizeof(T), - input_y_ptr + offset_num, box_read_limit_count * 1 * sizeof(T), - boxes_load_dir, box_read_limit_count * 1 * sizeof(T), - box_read_limit_count * 1 * sizeof(T), 0); - // dx - __memcpy((char *)box1_buffer + box_read_limit_count * 2 * sizeof(T), - input_dx_ptr + offset_num, box_read_limit_count * 1 * sizeof(T), - boxes_load_dir, box_read_limit_count * 1 * sizeof(T), - box_read_limit_count * 1 * sizeof(T), 0); - // dy - __memcpy((char *)box1_buffer + box_read_limit_count * 3 * sizeof(T), - input_dy_ptr + offset_num, box_read_limit_count * 1 * sizeof(T), - boxes_load_dir, box_read_limit_count * 1 * sizeof(T), - box_read_limit_count * 1 * sizeof(T), 0); - // angle - __memcpy((char *)box1_buffer + box_read_limit_count * 4 * sizeof(T), - input_angle_ptr + offset_num, - box_read_limit_count * 1 * sizeof(T), boxes_load_dir, - box_read_limit_count * 1 * sizeof(T), - box_read_limit_count * 1 * sizeof(T), 0); - } - if (((float *)input_score_ptr)[cur_box] == 0) { - continue; - } - // save result - nram_save[nram_save_count] = cur_box; - result_box_num++; - nram_save_count++; - if (clusterId == 0 && coreId == 0 && - nram_save_count == nram_save_limit_count) { - pvLock(); - __memcpy(output_data, nram_save, nram_save_count * sizeof(int32_t), - NRAM2GDRAM); - pvUnlock(); - output_data += nram_save_count; - nram_save_count = 0; - } - // prepare box1 - // x - __bang_write_value((float *)box1, max_seg_pad, - float(((T *)box1_buffer)[box_id])); - // y - __bang_write_value( - (float *)box1 + max_seg_pad, max_seg_pad, - float(((T *)box1_buffer)[box_id + 1 * box_read_limit_count])); - // dx - __bang_write_value( - (float *)box1 + max_seg_pad * 2, max_seg_pad, - float(((T *)box1_buffer)[box_id + 2 * box_read_limit_count])); - // dy - __bang_write_value( - (float *)box1 + max_seg_pad * 3, max_seg_pad, - float(((T *)box1_buffer)[box_id + 3 * box_read_limit_count])); - // angle - __bang_write_value( - (float *)box1 + max_seg_pad * 4, max_seg_pad, - float(((T *)box1_buffer)[box_id + 4 * box_read_limit_count])); - - float max_area = 1.0f * - ((T *)box1_buffer)[box_id + 2 * box_read_limit_count] * - ((T *)box1_buffer)[box_id + 3 * box_read_limit_count]; - // update score - - for (int i = 0; i <= repeat_iou_compute; i++) { - if (i == repeat_iou_compute && remain_iou_compute == 0) { - break; - } - int seg_len = max_seg_pad; - int cpy_len = - (i == repeat_iou_compute) ? remain_iou_compute : max_seg_pad; - // int half_offset = std::is_same::value ? max_seg_pad * 5 : 0; - int half_offset = (sizeof(T) == sizeof(half)) ? 
max_seg_pad * 5 : 0; - // score - __memcpy(score, input_score_ptr + input_offset + i * max_seg_pad, - cpy_len * sizeof(float), scores_load_dir, - cpy_len * sizeof(float), cpy_len * sizeof(float), 0); - // x - __memcpy(box2 + half_offset, input_x_ptr + input_offset + i * max_seg_pad, - cpy_len * 1 * sizeof(T), boxes_load_dir, cpy_len * 1 * sizeof(T), - cpy_len * 1 * sizeof(T), 0); - // y - __memcpy(box2 + half_offset + seg_len * 1, - input_y_ptr + input_offset + i * max_seg_pad, - cpy_len * 1 * sizeof(T), boxes_load_dir, cpy_len * 1 * sizeof(T), - cpy_len * 1 * sizeof(T), 0); - // dx - __memcpy(box2 + half_offset + seg_len * 2, - input_dx_ptr + input_offset + i * max_seg_pad, - cpy_len * 1 * sizeof(T), boxes_load_dir, cpy_len * 1 * sizeof(T), - cpy_len * 1 * sizeof(T), 0); - // dy - __memcpy(box2 + half_offset + seg_len * 3, - input_dy_ptr + input_offset + i * max_seg_pad, - cpy_len * 1 * sizeof(T), boxes_load_dir, cpy_len * 1 * sizeof(T), - cpy_len * 1 * sizeof(T), 0); - // angle - __memcpy(box2 + half_offset + seg_len * 4, - input_angle_ptr + input_offset + i * max_seg_pad, - cpy_len * 1 * sizeof(T), boxes_load_dir, cpy_len * 1 * sizeof(T), - cpy_len * 1 * sizeof(T), 0); - // if (std::is_same::value) { - if (sizeof(T) == sizeof(half)) { - __bang_half2float((float *)box2, (half *)(box2 + half_offset), - seg_len * 5); - } - - // Calculate rotated vertices - void *temp1_ram = ((char *)temp_buffer); - void *temp2_ram = ((char *)temp_buffer) + seg_len * sizeof(float); - void *temp3_ram = ((char *)temp_buffer) + 2 * seg_len * sizeof(float); - void *temp4_ram = ((char *)temp_buffer) + 3 * seg_len * sizeof(float); - getRotatedVertices((float *)rotated_pts1_x, (float *)rotated_pts1_y, - (float *)box1, (float *)temp1_ram, (float *)temp2_ram, - (float *)temp3_ram, (float *)temp4_ram, seg_len); - getRotatedVertices((float *)rotated_pts2_x, (float *)rotated_pts2_y, - (float *)box2, (float *)temp1_ram, (float *)temp2_ram, - (float *)temp3_ram, (float *)temp4_ram, seg_len); - - __bang_write_zero((float *)valid_pts, 24 * seg_len); - __bang_write_zero((float *)nums_in_ram, seg_len); - __bang_write_value(((float *)valid_box), seg_len, 1.0f); - void *vec1_x = ((char *)vec_buffer); - void *vec1_y = ((char *)vec1_x) + 4 * seg_len * sizeof(float); - void *vec2_x = ((char *)vec1_y) + 4 * seg_len * sizeof(float); - void *vec2_y = ((char *)vec2_x) + 4 * seg_len * sizeof(float); - void *temp5_ram = ((char *)temp_buffer) + 4 * seg_len * sizeof(float); - void *temp6_ram = ((char *)temp_buffer) + 5 * seg_len * sizeof(float); - void *temp7_ram = ((char *)temp_buffer) + 6 * seg_len * sizeof(float); - void *temp8_ram = ((char *)temp_buffer) + 7 * seg_len * sizeof(float); - void *temp9_ram = ((char *)temp_buffer) + 8 * seg_len * sizeof(float); - void *temp10_ram = ((char *)temp_buffer) + 9 * seg_len * sizeof(float); - - // Get all intersection points - getIntersectPts( - (float *)rotated_pts1_x, (float *)rotated_pts1_y, - (float *)rotated_pts2_x, (float *)rotated_pts2_y, (float *)vec1_x, - (float *)vec1_y, (float *)vec2_x, (float *)vec2_y, - (float *)intersect_pts_x, (float *)intersect_pts_y, - (float *)valid_pts, (float *)nums_in_ram, (float *)temp1_ram, - (float *)temp2_ram, (float *)temp3_ram, (float *)temp4_ram, - (float *)temp5_ram, (float *)temp6_ram, (float *)temp7_ram, - (float *)temp8_ram, (float *)temp9_ram, (float *)temp10_ram, seg_len); - - // Where nums_in <= 2, set valid_box to false - __bang_write_value((float *)temp9_ram, COMPUTE_COUNT_ALIGN, (float)2); - __bang_cycle_gt((float *)temp1_ram, (float 
*)nums_in_ram, - (float *)temp9_ram, seg_len, COMPUTE_COUNT_ALIGN); - __bang_and((float *)valid_box, (float *)valid_box, (float *)temp1_ram, - seg_len); - __bang_cycle_and((float *)valid_pts, (float *)valid_pts, - (float *)valid_box, 24 * seg_len, seg_len); - - // Convex-hull-graham to order the intersection points in clockwise order - // and find the contour area - - convexHullGraham( - (float *)intersect_pts_x, (float *)intersect_pts_y, - (float *)ordered_pts_x, (float *)ordered_pts_y, (float *)dist_ram, - (float *)valid_box, (float *)valid_pts, (float *)nums_in_ram, - (float *)temp7_ram, (float *)temp8_ram, (float *)temp9_ram, - (float *)temp_long_1, (float *)temp_long_2, (float *)temp_long_3, - seg_len, seg_len); - // Calculate polygon area - // set temp1 = intersection part area - polygonArea((float *)ordered_pts_x, (float *)ordered_pts_y, - (float *)valid_box, (float *)valid_pts, (float *)nums_in_ram, - (float *)temp1_ram, (float *)temp2_ram, (float *)temp3_ram, - (float *)temp4_ram, (float *)temp5_ram, (float *)temp6_ram, - (float *)temp7_ram, (float *)temp8_ram, (float *)temp9_ram, - seg_len); - // area - __bang_mul((float *)temp2_ram, (float *)box2 + seg_len * 2, - (float *)box2 + seg_len * 3, seg_len); - // get the area_U: area + max_area - area_I - __bang_add_scalar((float *)temp2_ram, (float *)temp2_ram, float(max_area), - seg_len); - __bang_sub((float *)temp2_ram, (float *)temp2_ram, (float *)temp1_ram, - seg_len); // area_U - if (iou_threshold > 0.0) { - __bang_mul_scalar((float *)temp1_ram, (float *)temp1_ram, - div_thresh_iou, seg_len); - } else { - __bang_mul_scalar((float *)temp2_ram, (float *)temp2_ram, iou_threshold, - seg_len); - } - __bang_ge((float *)temp1_ram, (float *)temp2_ram, (float *)temp1_ram, - seg_len); - __bang_mul((float *)score, (float *)score, (float *)temp1_ram, seg_len); - - pvLock(); - __memcpy(input_score_ptr + input_offset + i * max_seg_pad, score, - cpy_len * sizeof(float), scores_store_dir, - cpy_len * sizeof(float), cpy_len * sizeof(float), 0); - pvUnlock(); - } - } - if (clusterId == 0 && coreId == 0 && nram_save_count) { - pvLock(); - __memcpy(output_data, nram_save, nram_save_count * sizeof(int32_t), - NRAM2GDRAM); - pvUnlock(); - } -} -__mlu_global__ void MLUBlockorUnionIKernelOU3D( - const void *input_boxes, const int input_box_num, const float iou_threshold, - const cnrtDataType_t data_type_input, void *workspace, void *result_num, - void *output) { - int input_dwidth = (data_type_input == CNRT_FLOAT32) ? 
4 : 2; - mluMemcpyDirection_t scores_load_dir = GDRAM2NRAM; - mluMemcpyDirection_t scores_store_dir = NRAM2GDRAM; - mluMemcpyDirection_t boxes_load_dir = GDRAM2NRAM; - float *scores_data = (float *)workspace; - float *boxes_data = (float *)input_boxes; - const int cluster_score_size = input_box_num * sizeof(float); - const int cluster_boxes_size = input_box_num * 7 * input_dwidth; - char *sram_score = (char *)sram_buffer; - char *sram_boxes = (char *)sram_buffer + cluster_score_size; - if (clusterDim == 1 && SIZE_SRAM_BUF > cluster_score_size) { - scores_data = (float *)sram_score; - scores_load_dir = SRAM2NRAM; - scores_store_dir = NRAM2SRAM; - if (coreId == 0x80) { - __sramset((void *)sram_buffer, input_box_num, 1.0f); - } - } else { - if (coreId == 0) { - __gdramset(scores_data, input_box_num, 1.0f); - } - } - if (clusterDim == 1 && - SIZE_SRAM_BUF - cluster_score_size >= cluster_boxes_size) { - boxes_load_dir = SRAM2NRAM; - boxes_data = (float *)sram_boxes; - if (coreId == 0x80) { - __memcpy((char *)boxes_data, (char *)input_boxes, cluster_boxes_size, - GDRAM2SRAM); - } - } - __sync_cluster(); - - int32_t result_box_num = 0; - int32_t *out_data = (int32_t *)output; - - switch (data_type_input) { - default: { return; } - case CNRT_FLOAT16: { - iou3D_detection(result_box_num, out_data, (half *)boxes_data, scores_data, - taskDim, input_box_num, iou_threshold, scores_load_dir, - scores_store_dir, boxes_load_dir); - }; break; - case CNRT_FLOAT32: { - iou3D_detection(result_box_num, out_data, boxes_data, scores_data, - taskDim, input_box_num, iou_threshold, scores_load_dir, - scores_store_dir, boxes_load_dir); - }; break; - } - ((int32_t *)result_num)[0] = result_box_num; -} - -void KernelIou3d(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, - const cnrtDataType_t data_type_input, const void *boxes_dram, - const int input_box_num, const float iou_threshold, - void *workspace, void *output_size, void *output) { - switch (k_type) { - default: { return; } - case CNRT_FUNC_TYPE_BLOCK: - case CNRT_FUNC_TYPE_UNION1: - case CNRT_FUNC_TYPE_UNION2: - case CNRT_FUNC_TYPE_UNION4: - case CNRT_FUNC_TYPE_UNION8: - case CNRT_FUNC_TYPE_UNION16: { - MLUBlockorUnionIKernelOU3D<<<k_dim, k_type, queue>>>( - (void *)boxes_dram, input_box_num, iou_threshold, data_type_input, - workspace, output_size, output); - }; break; - } -}
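The suppression test in iou3D_detection above is easy to miss among the vectorized __bang_* calls: when iou_threshold > 0 the kernel never divides per element; it multiplies the intersection area by div_thresh_iou = 1/iou_threshold and compares against the union area, so "IoU <= threshold" becomes "area_U >= area_I / threshold". A minimal host-side C++ sketch of that predicate follows; the function and variable names here are illustrative and not part of the kernel:

#include <cstdio>

// Illustrative host-side reference (not part of the kernel): keep a
// candidate box iff its IoU with the current box does not exceed the
// threshold, using the same divide-free comparison as iou3D_detection.
static bool keep_box(float area_i, float area_u, float iou_threshold) {
  if (iou_threshold > 0.0f) {
    // IoU <= threshold  <=>  area_u >= area_i * (1 / threshold)
    return area_u >= area_i * (1.0f / iou_threshold);
  }
  // Non-positive threshold: the kernel scales area_u by the threshold instead.
  return area_u * iou_threshold >= area_i;
}

int main() {
  // IoU = 2/3 > 0.5, so the candidate is suppressed (its score is zeroed).
  std::printf("keep: %d\n", keep_box(2.0f, 3.0f, 0.5f));
  return 0;
}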
diff --git a/mmcv/ops/csrc/common/mlu/iou3d_utils.hpp b/mmcv/ops/csrc/common/mlu/iou3d_utils.hpp deleted file mode 100644 index b98ffe2fca..0000000000 --- a/mmcv/ops/csrc/common/mlu/iou3d_utils.hpp +++ /dev/null @@ -1,695 +0,0 @@ -/************************************************************************* - * Copyright (C) 2022 Cambricon. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS - * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - *************************************************************************/ - -#ifndef IOU3D_UTILS_HPP_ -#define IOU3D_UTILS_HPP_ -#include "common_mlu_helper.hpp" - -#define IOU3D_SIZE 64 -#define IOU3D_UP(x, y) (x / y + (int)(x % y > 0)) * y -#define IOU3D_DOWN(x, y) (x / y) * y -#define SIZE_NRAM_BUF (MAX_NRAM_SIZE) -#define SIZE_SRAM_BUF (MAX_SRAM_SIZE) -#define COMPUTE_COUNT_ALIGN 64 -#define INFO_NUM (5) // score, x1, y1, x2, y2 -#define REDUCE_NUM \ - (7) // score, x1, y1, x2, y2, max_index (reserve 2 num for half-type input) -#define SINGLE_BOX_DIM 5 -#define MEMORY_CORE (0x80) -__mlu_func__ void pvLock() { -#if __BANG_ARCH__ == 270 - if (coreId != MEMORY_CORE) { - __bang_lock(0, 0); - } -#endif -} - -__mlu_func__ void pvUnlock() { -#if __BANG_ARCH__ == 270 - if (coreId != MEMORY_CORE) { - __bang_unlock(0, 0); - } -#endif -} - -// cross2d(A, B) = A.x * B.y - A.y * B.x; -template <typename T> -inline __mlu_func__ void cross2d(T *result, const T *p1_x, const T *p1_y, - const T *p2_x, const T *p2_y, - const int &length, T *temp_ram) { - __bang_mul((T *)temp_ram, (T *)p1_x, (T *)p2_y, length); - __bang_mul((T *)result, (T *)p1_y, (T *)p2_x, length); - __bang_sub((T *)result, (T *)temp_ram, (T *)result, length); -} - -// dot2d(A, B) = A.x * B.x + A.y * B.y -template <typename T> -inline __mlu_func__ void dot2d(T *result, const T *p1_x, const T *p1_y, - const T *p2_x, const T *p2_y, const int &length, - T *temp_ram) { - __bang_mul((T *)temp_ram, (T *)p1_x, (T *)p2_x, length); - __bang_mul((T *)result, (T *)p1_y, (T *)p2_y, length); - __bang_add((T *)result, (T *)temp_ram, (T *)result, length); -} - -template <typename T> -__mlu_func__ void getRotatedVertices(T *pts_x, T *pts_y, T *box, T *temp1, - T *temp2, T *temp3, T *temp4, - const uint32_t &actual_compute_box_num) { -// T cosTheta2 = (T)cos(theta) * 0.5f; -- temp1 -// T sinTheta2 = (T)sin(theta) * 0.5f; -- temp2 -// theta is the box's 5th element: the rotation angle in radians; -#if __BANG_ARCH__ >= 300 - __bang_cos((float *)temp1, ((float *)box) + 4 * actual_compute_box_num, - actual_compute_box_num); - __bang_sin((float *)temp2, ((float *)box) + 4 * actual_compute_box_num, - actual_compute_box_num); -#else - __bang_taylor4_cos((T *)temp1, ((T *)box) + 4 * actual_compute_box_num, - (T *)temp3, (T *)temp4, actual_compute_box_num); - __bang_taylor4_sin((T *)temp2, ((T *)box) + 4 * actual_compute_box_num, - (T *)temp3, (T *)temp4, actual_compute_box_num); -#endif - __bang_mul_scalar((T *)temp1, (T *)temp1, (T)0.5, actual_compute_box_num); - __bang_mul_scalar((T *)temp2, (T *)temp2, (T)0.5, actual_compute_box_num); - - // Temp3 = sinTheta2 * box.h; - // Temp4 = cosTheta2 * box.w; - __bang_mul((T *)temp3, (T *)temp2, ((T *)box) + 3 * actual_compute_box_num, - actual_compute_box_num); - __bang_mul((T *)temp4, (T *)temp1, ((T *)box) + 2 * actual_compute_box_num, - actual_compute_box_num); - // pts[0].x = box.x_ctr - sinTheta2 * box.h - cosTheta2 * box.w; - // pts[1].x = box.x_ctr + sinTheta2 * box.h - cosTheta2 * box.w; - __bang_sub((T *)pts_x, (T *)box, (T *)temp3, actual_compute_box_num); - __bang_sub((T *)pts_x, (T *)pts_x, (T *)temp4, actual_compute_box_num); - __bang_add((T *)pts_x + 1 * actual_compute_box_num, (T *)box, (T *)temp3, - actual_compute_box_num); - __bang_sub((T *)pts_x + 1 * actual_compute_box_num, - (T *)pts_x + 1 * actual_compute_box_num, (T *)temp4, - actual_compute_box_num); - // Temp3 = cosTheta2 * box.h; - // Temp4 = sinTheta2 * box.w; - __bang_mul((T *)temp3, (T *)temp1, box + 3 * actual_compute_box_num, - actual_compute_box_num); - __bang_mul((T *)temp4, (T
*)temp2, box + 2 * actual_compute_box_num, - actual_compute_box_num); - // pts[0].y = box.y_ctr + cosTheta2 * box.h - sinTheta2 * box.w; - // pts[1].y = box.y_ctr - cosTheta2 * box.h - sinTheta2 * box.w; - __bang_add((T *)pts_y, (T *)box + 1 * actual_compute_box_num, (T *)temp3, - actual_compute_box_num); - __bang_sub((T *)pts_y, (T *)pts_y, (T *)temp4, actual_compute_box_num); - __bang_sub((T *)pts_y + 1 * actual_compute_box_num, - (T *)box + 1 * actual_compute_box_num, (T *)temp3, - actual_compute_box_num); - __bang_sub((T *)pts_y + 1 * actual_compute_box_num, - (T *)pts_y + 1 * actual_compute_box_num, (T *)temp4, - actual_compute_box_num); - // pts[2].x = 2 * box.x_ctr - pts[0].x; - // pts[3].x = 2 * box.x_ctr - pts[1].x; - __bang_add((T *)pts_x + 2 * actual_compute_box_num, (T *)box, (T *)box, - actual_compute_box_num); - __bang_sub((T *)pts_x + 2 * actual_compute_box_num, - (T *)pts_x + 2 * actual_compute_box_num, (T *)pts_x, - actual_compute_box_num); - __bang_add((T *)pts_x + 3 * actual_compute_box_num, (T *)box, (T *)box, - actual_compute_box_num); - __bang_sub((T *)pts_x + 3 * actual_compute_box_num, - (T *)pts_x + 3 * actual_compute_box_num, - (T *)pts_x + 1 * actual_compute_box_num, actual_compute_box_num); - // pts[2].y = 2 * box.y_ctr - pts[0].y; - // pts[3].y = 2 * box.y_ctr - pts[1].y; - __bang_add((T *)pts_y + 2 * actual_compute_box_num, - (T *)box + 1 * actual_compute_box_num, - (T *)box + 1 * actual_compute_box_num, actual_compute_box_num); - __bang_sub((T *)pts_y + 2 * actual_compute_box_num, - (T *)pts_y + 2 * actual_compute_box_num, (T *)pts_y, - actual_compute_box_num); - __bang_add((T *)pts_y + 3 * actual_compute_box_num, - (T *)box + 1 * actual_compute_box_num, - (T *)box + 1 * actual_compute_box_num, actual_compute_box_num); - __bang_sub((T *)pts_y + 3 * actual_compute_box_num, - (T *)pts_y + 3 * actual_compute_box_num, - (T *)pts_y + 1 * actual_compute_box_num, actual_compute_box_num); -} - -template <typename T> -__mlu_func__ void getIntersectPts(T *rotated_pts1_x, T *rotated_pts1_y, - T *rotated_pts2_x, T *rotated_pts2_y, - T *vec1_x, T *vec1_y, T *vec2_x, T *vec2_y, - T *intersect_pts_x, T *intersect_pts_y, - T *valid_pts, T *nums_in_ram, T *temp1_ram, - T *temp2_ram, T *temp3_ram, T *temp4_ram, - T *temp5_ram, T *temp6_ram, T *temp7_ram, - T *temp8_ram, T *temp9_ram, T *temp10_ram, - const uint32_t &actual_compute_box_num) { -// Initialize const data to ram -// temp3 = const 1e-14(@float), length = COMPUTE_COUNT_ALIGN -#if __BANG_ARCH__ >= 300 - __bang_write_value((T *)temp3_ram, COMPUTE_COUNT_ALIGN, (T)1e-14); -#else - // NOTE: Since active_reciphp function has strict value range, - // [2.2205e-16, 2e6]@float, [0.00391, 65504]@half - __bang_write_value((T *)temp3_ram, COMPUTE_COUNT_ALIGN, (float)1e-14); -#endif - // temp4 = const T(0), length = COMPUTE_COUNT_ALIGN - __bang_write_value((T *)temp4_ram, COMPUTE_COUNT_ALIGN, (T)0); - // temp5 = const T(1), length = COMPUTE_COUNT_ALIGN - __bang_write_value((T *)temp5_ram, COMPUTE_COUNT_ALIGN, (T)1); - - // Line vector, from p1 to p2 is: p1+(p2-p1)*t, t=[0,1] - // for i = 0~3, vec[i] = pts[(i+1)%4] - pts[i] - __bang_sub((T *)vec1_x, (T *)rotated_pts1_x + actual_compute_box_num, - (T *)rotated_pts1_x, 3 * actual_compute_box_num); - __bang_sub((T *)vec1_x + 3 * actual_compute_box_num, (T *)rotated_pts1_x, - (T *)rotated_pts1_x + 3 * actual_compute_box_num, - actual_compute_box_num); - __bang_sub((T *)vec1_y, (T *)rotated_pts1_y + actual_compute_box_num, - (T *)rotated_pts1_y, 3 * actual_compute_box_num); - __bang_sub((T *)vec1_y
+ 3 * actual_compute_box_num, (T *)rotated_pts1_y, - (T *)rotated_pts1_y + 3 * actual_compute_box_num, - actual_compute_box_num); - - __bang_sub((T *)vec2_x, (T *)rotated_pts2_x + actual_compute_box_num, - (T *)rotated_pts2_x, 3 * actual_compute_box_num); - __bang_sub((T *)vec2_x + 3 * actual_compute_box_num, (T *)rotated_pts2_x, - (T *)rotated_pts2_x + 3 * actual_compute_box_num, - actual_compute_box_num); - __bang_sub((T *)vec2_y, (T *)rotated_pts2_y + actual_compute_box_num, - (T *)rotated_pts2_y, 3 * actual_compute_box_num); - __bang_sub((T *)vec2_y + 3 * actual_compute_box_num, (T *)rotated_pts2_y, - (T *)rotated_pts2_y + 3 * actual_compute_box_num, - actual_compute_box_num); - - // First, line test - test all line combos for intersection, 4x4 possible - for (int i = 0; i < 4; i++) { - for (int j = 0; j < 4; j++) { - // T det = cross2d(vec2[j], vec1[i]) -- temp2 - cross2d((T *)temp2_ram, (T *)vec2_x + j * actual_compute_box_num, - (T *)vec2_y + j * actual_compute_box_num, - (T *)vec1_x + i * actual_compute_box_num, - (T *)vec1_y + i * actual_compute_box_num, - actual_compute_box_num, (T *)temp1_ram); - // temp8 = sign(det), since active_reciphp only receive positive values - __bang_active_sign((T *)temp8_ram, (T *)temp2_ram, - actual_compute_box_num); - // deal with parallel lines, temp2 = fabs(det), temp1 = temp2 > 1e-14 - __bang_active_abs((T *)temp2_ram, (T *)temp2_ram, actual_compute_box_num); - __bang_cycle_gt((T *)temp1_ram, (T *)temp2_ram, (T *)temp3_ram, - actual_compute_box_num, COMPUTE_COUNT_ALIGN); - // Where temp1 = false, set recip input to 1, avoiding recip(0), cause inf - __bang_not((T *)temp9_ram, (T *)temp1_ram, actual_compute_box_num); - __bang_mul((T *)temp2_ram, (T *)temp2_ram, (T *)temp1_ram, - actual_compute_box_num); - __bang_add((T *)temp2_ram, (T *)temp2_ram, (T *)temp9_ram, - actual_compute_box_num); -// temp2 = 1/temp2, use mult (1/temp2) instead of div temp2 -#if __BANG_ARCH__ >= 300 - __bang_recip((float *)temp2_ram, (float *)temp2_ram, - actual_compute_box_num); -#else - // NOTE: active_reciphp function has strict value range: - // [2.2205e-16, 2e6]@float, [0.00391, 65504]@half - __bang_active_reciphp((T *)temp2_ram, (T *)temp2_ram, - actual_compute_box_num); -#endif - // Restore temp2 invalid box value 1 and sign-bit - __bang_mul((T *)temp2_ram, (T *)temp2_ram, (T *)temp1_ram, - actual_compute_box_num); - __bang_mul((T *)temp2_ram, (T *)temp2_ram, (T *)temp8_ram, - actual_compute_box_num); - - // auto vec12 = pts2[j] - pts1[i], (temp6, temp7) = (x, y) - __bang_sub((T *)temp6_ram, - (T *)rotated_pts2_x + j * actual_compute_box_num, - (T *)rotated_pts1_x + i * actual_compute_box_num, - actual_compute_box_num); - __bang_sub((T *)temp7_ram, - (T *)rotated_pts2_y + j * actual_compute_box_num, - (T *)rotated_pts1_y + i * actual_compute_box_num, - actual_compute_box_num); - - // T t1 = cross2d(vec2[j], vec12) mult (1/det) -- temp8 - cross2d((T *)temp8_ram, (T *)vec2_x + j * actual_compute_box_num, - (T *)vec2_y + j * actual_compute_box_num, (T *)temp6_ram, - (T *)temp7_ram, actual_compute_box_num, (T *)temp9_ram); - __bang_mul((T *)temp8_ram, (T *)temp8_ram, (T *)temp2_ram, - actual_compute_box_num); - - // temp1 &= (t1 >= 0.0f && t1 <= 1.0f) -- temp9 - __bang_cycle_ge((T *)temp9_ram, (T *)temp8_ram, (T *)temp4_ram, - actual_compute_box_num, COMPUTE_COUNT_ALIGN); - __bang_and((T *)temp1_ram, (T *)temp1_ram, (T *)temp9_ram, - actual_compute_box_num); - __bang_cycle_le((T *)temp9_ram, (T *)temp8_ram, (T *)temp5_ram, - actual_compute_box_num, COMPUTE_COUNT_ALIGN); 
- __bang_and((T *)temp1_ram, (T *)temp1_ram, (T *)temp9_ram, - actual_compute_box_num); - - // T t2 = cross2d(vec1[i], vec12) mult temp2 -- temp9 - // NOTE: temp8(t1) is used after, reuse temp7(p2_y) as cross2d temp ram - cross2d((T *)temp9_ram, (T *)vec1_x + i * actual_compute_box_num, - (T *)vec1_y + i * actual_compute_box_num, (T *)temp6_ram, - (T *)temp7_ram, actual_compute_box_num, (T *)temp7_ram); - __bang_mul((T *)temp9_ram, (T *)temp9_ram, (T *)temp2_ram, - actual_compute_box_num); - - // temp1 &= (t2 >= 0.0f && t2 <= 1.0f) -- temp9 - __bang_cycle_ge((T *)temp7_ram, (T *)temp9_ram, (T *)temp4_ram, - actual_compute_box_num, COMPUTE_COUNT_ALIGN); - __bang_and((T *)temp1_ram, (T *)temp1_ram, (T *)temp7_ram, - actual_compute_box_num); - __bang_cycle_le((T *)temp7_ram, (T *)temp9_ram, (T *)temp5_ram, - actual_compute_box_num, COMPUTE_COUNT_ALIGN); - __bang_and((T *)temp1_ram, (T *)temp1_ram, (T *)temp7_ram, - actual_compute_box_num); - - // intersections = (pts1[i] + vec1[i] * t1) * temp1 - __bang_mul((T *)temp9_ram, (T *)vec1_x + i * actual_compute_box_num, - (T *)temp8_ram, actual_compute_box_num); - __bang_add((T *)temp9_ram, - (T *)rotated_pts1_x + i * actual_compute_box_num, - (T *)temp9_ram, actual_compute_box_num); - __bang_mul((T *)intersect_pts_x + (4 * i + j) * actual_compute_box_num, - (T *)temp9_ram, (T *)temp1_ram, actual_compute_box_num); - __bang_mul((T *)temp9_ram, (T *)vec1_y + i * actual_compute_box_num, - (T *)temp8_ram, actual_compute_box_num); - __bang_add((T *)temp9_ram, - (T *)rotated_pts1_y + i * actual_compute_box_num, - (T *)temp9_ram, actual_compute_box_num); - __bang_mul((T *)intersect_pts_y + (4 * i + j) * actual_compute_box_num, - (T *)temp9_ram, (T *)temp1_ram, actual_compute_box_num); - - // Assign `valid_pts` bit and accumulate `nums_in` of valid points of each - // box pair - __bang_or((T *)valid_pts + (4 * i + j) * actual_compute_box_num, - (T *)valid_pts + (4 * i + j) * actual_compute_box_num, - (T *)temp1_ram, actual_compute_box_num); - __bang_add((T *)nums_in_ram, (T *)nums_in_ram, (T *)temp1_ram, - actual_compute_box_num); - } - } - - // Check for vertices of rect1 inside rect2 - // temp5 = ABdotAB - dot2d((T *)temp5_ram, (T *)vec2_x, (T *)vec2_y, (T *)vec2_x, (T *)vec2_y, - actual_compute_box_num, (T *)temp9_ram); - // temp6 = ADdotAD - dot2d((T *)temp6_ram, (T *)vec2_x + 3 * actual_compute_box_num, - (T *)vec2_y + 3 * actual_compute_box_num, - (T *)vec2_x + 3 * actual_compute_box_num, - (T *)vec2_y + 3 * actual_compute_box_num, actual_compute_box_num, - (T *)temp9_ram); - // assume ABCD is the rectangle, and P is the point to be judged - // P is inside ABCD iff. 
P's projection on AB lines within AB - // and P's projection on AD lies within AD - for (int i = 0; i < 4; i++) { - // AP = pts1[i] - pts2[0] = (temp7, temp8) - __bang_sub((T *)temp7_ram, (T *)rotated_pts1_x + i * actual_compute_box_num, - (T *)rotated_pts2_x, actual_compute_box_num); - __bang_sub((T *)temp8_ram, (T *)rotated_pts1_y + i * actual_compute_box_num, - (T *)rotated_pts2_y, actual_compute_box_num); - - // temp9 = APdotAB = dot2d(AP, AB) - dot2d((T *)temp9_ram, (T *)temp7_ram, (T *)temp8_ram, (T *)vec2_x, - (T *)vec2_y, actual_compute_box_num, (T *)temp2_ram); - // temp10 = APdotAD = -dot2d(AP, DA) - dot2d((T *)temp10_ram, (T *)temp7_ram, (T *)temp8_ram, - (T *)vec2_x + 3 * actual_compute_box_num, - (T *)vec2_y + 3 * actual_compute_box_num, actual_compute_box_num, - (T *)temp2_ram); - __bang_mul_scalar((T *)temp10_ram, (T *)temp10_ram, (T)-1, - actual_compute_box_num); - - // ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) && (APdotAD <= - // ADdotAD)) - __bang_cycle_ge((T *)temp1_ram, (T *)temp9_ram, (T *)temp4_ram, - actual_compute_box_num, COMPUTE_COUNT_ALIGN); - __bang_cycle_ge((T *)temp2_ram, (T *)temp10_ram, (T *)temp4_ram, - actual_compute_box_num, COMPUTE_COUNT_ALIGN); - __bang_and((T *)temp1_ram, (T *)temp1_ram, (T *)temp2_ram, - actual_compute_box_num); - __bang_le((T *)temp2_ram, (T *)temp9_ram, (T *)temp5_ram, - actual_compute_box_num); - __bang_and((T *)temp1_ram, (T *)temp1_ram, (T *)temp2_ram, - actual_compute_box_num); - __bang_le((T *)temp2_ram, (T *)temp10_ram, (T *)temp6_ram, - actual_compute_box_num); - __bang_and((T *)temp1_ram, (T *)temp1_ram, (T *)temp2_ram, - actual_compute_box_num); - - // 16 means the 4x4 possible intersection points above - __bang_mul((T *)intersect_pts_x + (16 + i) * actual_compute_box_num, - (T *)temp1_ram, (T *)rotated_pts1_x + i * actual_compute_box_num, - actual_compute_box_num); - __bang_mul((T *)intersect_pts_y + (16 + i) * actual_compute_box_num, - (T *)temp1_ram, (T *)rotated_pts1_y + i * actual_compute_box_num, - actual_compute_box_num); - - // assign valid_pts bit and accumulate nums of valid points of each box pair - __bang_or((T *)valid_pts + (16 + i) * actual_compute_box_num, - (T *)valid_pts + (16 + i) * actual_compute_box_num, - (T *)temp1_ram, actual_compute_box_num); - __bang_add((T *)nums_in_ram, (T *)nums_in_ram, (T *)temp1_ram, - actual_compute_box_num); - } - - // Reverse the check - check for vertices of rect2 inside rect1 - // temp5 = ABdotAB - dot2d((T *)temp5_ram, (T *)vec1_x, (T *)vec1_y, (T *)vec1_x, (T *)vec1_y, - actual_compute_box_num, (T *)temp9_ram); - // temp6 = ADdotAD - dot2d((T *)temp6_ram, (T *)vec1_x + 3 * actual_compute_box_num, - (T *)vec1_y + 3 * actual_compute_box_num, - (T *)vec1_x + 3 * actual_compute_box_num, - (T *)vec1_y + 3 * actual_compute_box_num, actual_compute_box_num, - (T *)temp9_ram); - for (int i = 0; i < 4; i++) { - // AP = pts2[i] - pts1[0] = (temp7, temp8) - __bang_sub((T *)temp7_ram, (T *)rotated_pts2_x + i * actual_compute_box_num, - (T *)rotated_pts1_x, actual_compute_box_num); - __bang_sub((T *)temp8_ram, (T *)rotated_pts2_y + i * actual_compute_box_num, - (T *)rotated_pts1_y, actual_compute_box_num); - - // temp9 = APdotAB = dot2d(AP, AB) - dot2d((T *)temp9_ram, (T *)temp7_ram, (T *)temp8_ram, (T *)vec1_x, - (T *)vec1_y, actual_compute_box_num, (T *)temp2_ram); - // temp10 = APdotAD = -dot2d(AP, DA) - dot2d((T *)temp10_ram, (T *)temp7_ram, (T *)temp8_ram, - (T *)vec1_x + 3 * actual_compute_box_num, - (T *)vec1_y + 3 * actual_compute_box_num, 
actual_compute_box_num, - (T *)temp2_ram); - __bang_mul_scalar((T *)temp10_ram, (T *)temp10_ram, (T)-1, - actual_compute_box_num); - - // ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) && (APdotAD <= - // ADdotAD)) - __bang_cycle_ge((T *)temp1_ram, (T *)temp9_ram, (T *)temp4_ram, - actual_compute_box_num, COMPUTE_COUNT_ALIGN); - __bang_cycle_ge((T *)temp2_ram, (T *)temp10_ram, (T *)temp4_ram, - actual_compute_box_num, COMPUTE_COUNT_ALIGN); - __bang_and((T *)temp1_ram, (T *)temp1_ram, (T *)temp2_ram, - actual_compute_box_num); - __bang_le((T *)temp2_ram, (T *)temp9_ram, (T *)temp5_ram, - actual_compute_box_num); - __bang_and((T *)temp1_ram, (T *)temp1_ram, (T *)temp2_ram, - actual_compute_box_num); - __bang_le((T *)temp2_ram, (T *)temp10_ram, (T *)temp6_ram, - actual_compute_box_num); - __bang_and((T *)temp1_ram, (T *)temp1_ram, (T *)temp2_ram, - actual_compute_box_num); - - // 20 means the (4x4+4) possible intersection points above - __bang_mul((T *)intersect_pts_x + (20 + i) * actual_compute_box_num, - (T *)temp1_ram, (T *)rotated_pts2_x + i * actual_compute_box_num, - actual_compute_box_num); - __bang_mul((T *)intersect_pts_y + (20 + i) * actual_compute_box_num, - (T *)temp1_ram, (T *)rotated_pts2_y + i * actual_compute_box_num, - actual_compute_box_num); - - // assign valid_pts bit and accumulate nums of valid points of each box pair - __bang_or((T *)valid_pts + (20 + i) * actual_compute_box_num, - (T *)valid_pts + (20 + i) * actual_compute_box_num, - (T *)temp1_ram, actual_compute_box_num); - __bang_add((T *)nums_in_ram, (T *)nums_in_ram, (T *)temp1_ram, - actual_compute_box_num); - } -} - -template <typename T> -__mlu_func__ void convexHullGraham( - T *intersect_pts_x, T *intersect_pts_y, T *ordered_pts_x, T *ordered_pts_y, - T *dist_ram, T *valid_box, T *valid_pts, T *nums_in_ram, T *temp1_ram, - T *temp2_ram, T *temp3_ram, T *temp_long_1, T *temp_long_2, T *temp_long_3, - const uint32_t &actual_box_num, const uint32_t &actual_compute_box_num) { - // Step1. Find the point with minimum y; if more than one point has the same - // minimum y, - // pick the one with the minimum x.
- // set p[i].y to max_y_value if not valid_pts, to avoid invalid result - // 24 means all possible intersection points - __bang_max((T *)temp2_ram, (T *)intersect_pts_y, 24 * actual_compute_box_num); - __bang_write_value((T *)temp3_ram, COMPUTE_COUNT_ALIGN, ((T *)temp2_ram)[0]); - __bang_not((T *)temp_long_1, (T *)valid_pts, 24 * actual_compute_box_num); - __bang_cycle_mul((T *)temp_long_1, (T *)temp_long_1, (T *)temp3_ram, - 24 * actual_compute_box_num, COMPUTE_COUNT_ALIGN); - __bang_mul((T *)temp_long_2, (T *)intersect_pts_y, (T *)valid_pts, - 24 * actual_compute_box_num); - __bang_add((T *)temp_long_2, (T *)temp_long_2, (T *)temp_long_1, - 24 * actual_compute_box_num); - // temp2 = min_y_value(temp_long_2), use min_pool, channel=box_num, h=1, w=24 - __bang_minpool((T *)temp2_ram, (T *)temp_long_2, actual_compute_box_num, 1, - 24, 1, 24, 1, 24); - __bang_mul((T *)temp2_ram, (T *)temp2_ram, (T *)valid_box, - actual_compute_box_num); - - // set p[i].x to max_x_value if not min_y point - __bang_max((T *)temp1_ram, (T *)intersect_pts_x, 24 * actual_compute_box_num); - __bang_write_value((T *)temp3_ram, COMPUTE_COUNT_ALIGN, ((T *)temp1_ram)[0]); - __bang_cycle_eq((T *)temp_long_1, (T *)temp_long_2, (T *)temp2_ram, - 24 * actual_compute_box_num, actual_compute_box_num); - __bang_and((T *)temp_long_1, (T *)temp_long_1, (T *)valid_pts, - 24 * actual_compute_box_num); - __bang_not((T *)temp_long_3, (T *)temp_long_1, 24 * actual_compute_box_num); - __bang_cycle_mul((T *)temp_long_3, (T *)temp_long_3, (T *)temp3_ram, - 24 * actual_compute_box_num, COMPUTE_COUNT_ALIGN); - __bang_mul((T *)temp_long_1, (T *)intersect_pts_x, (T *)temp_long_1, - 24 * actual_compute_box_num); - __bang_add((T *)temp_long_1, (T *)temp_long_1, (T *)temp_long_3, - 24 * actual_compute_box_num); - // temp3 = min_x_value(temp_long_1), use min_pool, channel=box_num, h=1, w=24 - __bang_minpool((T *)temp3_ram, (T *)temp_long_1, actual_compute_box_num, 1, - 24, 1, 24, 1, 24); - __bang_mul((T *)temp3_ram, (T *)temp3_ram, (T *)valid_box, - actual_compute_box_num); - - // Step2. All points subtract starting-point (for sorting in the next step) - __bang_cycle_sub((T *)ordered_pts_x, (T *)intersect_pts_x, (T *)temp3_ram, - 24 * actual_compute_box_num, actual_compute_box_num); - __bang_cycle_sub((T *)ordered_pts_y, (T *)intersect_pts_y, (T *)temp2_ram, - 24 * actual_compute_box_num, actual_compute_box_num); - __bang_mul((T *)ordered_pts_x, (T *)ordered_pts_x, (T *)valid_pts, - 24 * actual_compute_box_num); - __bang_mul((T *)ordered_pts_y, (T *)ordered_pts_y, (T *)valid_pts, - 24 * actual_compute_box_num); - - // Step3. 
Sort every intersection point according to their relative - // cross-product values (essentially sorting according to angles) - // If the angles are the same, sort according to distance to origin - dot2d((T *)dist_ram, (T *)ordered_pts_x, (T *)ordered_pts_y, - (T *)ordered_pts_x, (T *)ordered_pts_y, 24 * actual_compute_box_num, - (T *)temp_long_3); - - T temp, temp_nums_in, temp_dist_1, temp_dist_2; - T temp1_x, temp1_y; - T temp2_x, temp2_y; - for (int i = 0; i < actual_box_num; i++) { - if (((T *)valid_box)[i]) { - // make sure all nums_in[i] points are at the front - for (int ii = 0; ii < 23; ii++) { - for (int jj = ii + 1; jj < 24; jj++) { - int ii_index = ii * actual_compute_box_num + i; - int jj_index = jj * actual_compute_box_num + i; - // ii point is not valid and jj point is valid, swap jj for ii - if ((!((T *)valid_pts)[ii_index]) && ((T *)valid_pts)[jj_index]) { - ((T *)ordered_pts_x)[ii_index] = ((T *)ordered_pts_x)[jj_index]; - ((T *)ordered_pts_y)[ii_index] = ((T *)ordered_pts_y)[jj_index]; - ((T *)dist_ram)[ii_index] = ((T *)dist_ram)[jj_index]; - ((T *)valid_pts)[ii_index] = true; - ((T *)ordered_pts_x)[jj_index] = 0; - ((T *)ordered_pts_y)[jj_index] = 0; - ((T *)dist_ram)[jj_index] = 0; - ((T *)valid_pts)[jj_index] = false; - break; - } - } - } - temp_nums_in = ((T *)nums_in_ram)[i]; - // make original q[0] = min_x, min_y before sort - for (int ii = 1; ii < temp_nums_in; ii++) { - int ii_index = ii * actual_compute_box_num + i; - if (((T *)dist_ram)[ii_index] == 0) { - // swap q[ii_index] and q[0] - ((T *)ordered_pts_x)[ii_index] = ((T *)ordered_pts_x)[i]; - ((T *)ordered_pts_y)[ii_index] = ((T *)ordered_pts_y)[i]; - ((T *)dist_ram)[ii_index] = ((T *)dist_ram)[i]; - ((T *)ordered_pts_x)[i] = 0; - ((T *)ordered_pts_y)[i] = 0; - ((T *)dist_ram)[i] = 0; - break; - } - } - for (int ii = 1; ii < temp_nums_in - 1; ii++) { - for (int jj = ii + 1; jj < temp_nums_in; jj++) { - int ii_index = ii * actual_compute_box_num + i; - int jj_index = jj * actual_compute_box_num + i; - temp1_x = ((T *)ordered_pts_x)[ii_index]; - temp1_y = ((T *)ordered_pts_y)[ii_index]; - temp2_x = ((T *)ordered_pts_x)[jj_index]; - temp2_y = ((T *)ordered_pts_y)[jj_index]; - // calculate cross product and sort q (ordered_pts) - temp = (temp1_x * temp2_y) - (temp1_y * temp2_x); - temp_dist_1 = ((T *)dist_ram)[ii_index]; - temp_dist_2 = ((T *)dist_ram)[jj_index]; - if ((temp < (T)-1e-6) || - ((fabs(temp) < (T)1e-6) && (temp_dist_1 > temp_dist_2))) { - ((T *)ordered_pts_x)[ii_index] = temp2_x; - ((T *)ordered_pts_y)[ii_index] = temp2_y; - ((T *)ordered_pts_x)[jj_index] = temp1_x; - ((T *)ordered_pts_y)[jj_index] = temp1_y; - ((T *)dist_ram)[ii_index] = temp_dist_2; - ((T *)dist_ram)[jj_index] = temp_dist_1; - } - } - } - - // Step4: - // Make sure there are at least 2 points(that don't overlap with each - // other) in the stack - int k; // index of the non-overlapped second point - for (k = 1; k < temp_nums_in; k++) { - if (((T *)dist_ram)[k * actual_compute_box_num + i] > (T)1e-8) { - break; - } - } - if (k == temp_nums_in) { - // We reach the end, which means the convex hull is just one point - // set valid_box = 0, to get ious = 0 - ((T *)valid_box)[i] = 0; - continue; - } - // q[1] = q[k]; - ((T *)ordered_pts_x)[actual_compute_box_num + i] = - ((T *)ordered_pts_x)[k * actual_compute_box_num + i]; - ((T *)ordered_pts_y)[actual_compute_box_num + i] = - ((T *)ordered_pts_y)[k * actual_compute_box_num + i]; - - // Step 5: - // Finally we can start the scanning process. 
- // When a non-convex relationship between the 3 points is found - // (either concave shape or duplicated points), - // we pop the previous point from the stack - // until the 3-point relationship is convex again, or - // until the stack only contains two points - int m = 2; // 2 points in the stack - for (int j = k + 1; j < temp_nums_in; j++) { - // while (m > 1 && cross2d(q[j] - q[m - 2], q[m - 1] - q[m - 2]) >= - // 0) { - // m--; - // } - temp1_x = ((T *)ordered_pts_x)[j * actual_compute_box_num + i] - - ((T *)ordered_pts_x)[(m - 2) * actual_compute_box_num + i]; - temp1_y = ((T *)ordered_pts_y)[j * actual_compute_box_num + i] - - ((T *)ordered_pts_y)[(m - 2) * actual_compute_box_num + i]; - temp2_x = ((T *)ordered_pts_x)[(m - 1) * actual_compute_box_num + i] - - ((T *)ordered_pts_x)[(m - 2) * actual_compute_box_num + i]; - temp2_y = ((T *)ordered_pts_y)[(m - 1) * actual_compute_box_num + i] - - ((T *)ordered_pts_y)[(m - 2) * actual_compute_box_num + i]; - temp = (temp1_x * temp2_y) - (temp1_y * temp2_x); - while ((m > 1) && (temp >= 0)) { - m--; - if (m > 1) { - temp1_x = - ((T *)ordered_pts_x)[j * actual_compute_box_num + i] - - ((T *)ordered_pts_x)[(m - 2) * actual_compute_box_num + i]; - temp1_y = - ((T *)ordered_pts_y)[j * actual_compute_box_num + i] - - ((T *)ordered_pts_y)[(m - 2) * actual_compute_box_num + i]; - temp2_x = - ((T *)ordered_pts_x)[(m - 1) * actual_compute_box_num + i] - - ((T *)ordered_pts_x)[(m - 2) * actual_compute_box_num + i]; - temp2_y = - ((T *)ordered_pts_y)[(m - 1) * actual_compute_box_num + i] - - ((T *)ordered_pts_y)[(m - 2) * actual_compute_box_num + i]; - temp = (temp1_x * temp2_y) - (temp1_y * temp2_x); - } - } - // q[m++] = q[j]; - ((T *)ordered_pts_x)[m * actual_compute_box_num + i] = - ((T *)ordered_pts_x)[j * actual_compute_box_num + i]; - ((T *)ordered_pts_y)[m * actual_compute_box_num + i] = - ((T *)ordered_pts_y)[j * actual_compute_box_num + i]; - m++; - } - // set last(24-m) valid_pts to false, to erase invalid q in polygon area - for (int j = m; j < temp_nums_in; j++) { - ((T *)valid_pts)[j * actual_compute_box_num + i] = 0; - } - ((T *)nums_in_ram)[i] = m; - } - } -} - -template <typename T> -__mlu_func__ void polygonArea(T *ordered_pts_x, T *ordered_pts_y, T *valid_box, - T *valid_pts, T *nums_in_ram, T *temp1_ram, - T *temp2_ram, T *temp3_ram, T *temp4_ram, - T *temp5_ram, T *temp6_ram, T *temp7_ram, - T *temp8_ram, T *temp9_ram, - const uint32_t &actual_compute_box_num) { - // Set where nums_in <= 2, valid_box = false - __bang_write_value((T *)temp9_ram, COMPUTE_COUNT_ALIGN, (T)2); - __bang_cycle_gt((T *)temp1_ram, (T *)nums_in_ram, (T *)temp9_ram, - actual_compute_box_num, COMPUTE_COUNT_ALIGN); - __bang_and((T *)valid_box, (T *)valid_box, (T *)temp1_ram, - actual_compute_box_num); - - // temp1 = area, initialize with all 0 - __bang_write_zero((T *)temp1_ram, actual_compute_box_num); - __bang_max((T *)temp7_ram, (T *)nums_in_ram, actual_compute_box_num); - - // temp_nums_in = max(nums_in) - T temp_nums_in = ((T *)temp7_ram)[0]; - for (int i = 1; i < temp_nums_in - 1; i++) { - // q[i] - q[0]: (temp6, temp7) - __bang_sub((T *)temp6_ram, (T *)ordered_pts_x + i * actual_compute_box_num, - (T *)ordered_pts_x, actual_compute_box_num); - __bang_sub((T *)temp7_ram, (T *)ordered_pts_y + i * actual_compute_box_num, - (T *)ordered_pts_y, actual_compute_box_num); - __bang_mul((T *)temp6_ram, (T *)temp6_ram, - (T *)valid_pts + (i + 1) * actual_compute_box_num, - actual_compute_box_num); - __bang_mul((T *)temp7_ram, (T *)temp7_ram, - (T *)valid_pts + (i + 1) *
actual_compute_box_num, - actual_compute_box_num); - // q[i + 1] - q[0]: (temp8, temp9) - __bang_sub((T *)temp8_ram, - (T *)ordered_pts_x + (i + 1) * actual_compute_box_num, - (T *)ordered_pts_x, actual_compute_box_num); - __bang_sub((T *)temp9_ram, - (T *)ordered_pts_y + (i + 1) * actual_compute_box_num, - (T *)ordered_pts_y, actual_compute_box_num); - __bang_mul((T *)temp8_ram, (T *)temp8_ram, - (T *)valid_pts + (i + 1) * actual_compute_box_num, - actual_compute_box_num); - __bang_mul((T *)temp9_ram, (T *)temp9_ram, - (T *)valid_pts + (i + 1) * actual_compute_box_num, - actual_compute_box_num); - // area += fabs(cross2d(q[i] - q[0], q[i + 1] - q[0])); - __bang_mul((T *)temp4_ram, (T *)temp6_ram, (T *)temp9_ram, - actual_compute_box_num); - __bang_mul((T *)temp5_ram, (T *)temp7_ram, (T *)temp8_ram, - actual_compute_box_num); - __bang_sub((T *)temp3_ram, (T *)temp4_ram, (T *)temp5_ram, - actual_compute_box_num); - __bang_active_abs((T *)temp3_ram, (T *)temp3_ram, actual_compute_box_num); - __bang_add((T *)temp1_ram, (T *)temp1_ram, (T *)temp3_ram, - actual_compute_box_num); - } - // Set where valid_box = false, intersection = 0 - __bang_mul((T *)temp1_ram, (T *)temp1_ram, (T *)valid_box, - actual_compute_box_num); - // area = area / 2.0 - __bang_mul_scalar((T *)temp1_ram, (T *)temp1_ram, (T)0.5, - actual_compute_box_num); -} - -#endif // IOU3D_UTILS_HPP_
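The vertex construction in getRotatedVertices above is fully specified by its inline comments, but the strided __bang_* calls obscure it. The following scalar C++ sketch restates the same formulas for a single box laid out as (x_ctr, y_ctr, w, h, angle), matching the kernel's layout; the struct and function names are illustrative only:

#include <cmath>

struct Point2D {
  float x, y;
};

// Illustrative scalar reference (not part of the kernel) for
// getRotatedVertices: pts[0] and pts[1] come from the half-extent
// combinations, pts[2] and pts[3] from reflecting them through the
// box center, exactly as the kernel's comments describe.
void rotatedVerticesRef(const float box[5], Point2D pts[4]) {
  const float cosTheta2 = std::cos(box[4]) * 0.5f;
  const float sinTheta2 = std::sin(box[4]) * 0.5f;
  pts[0].x = box[0] - sinTheta2 * box[3] - cosTheta2 * box[2];
  pts[0].y = box[1] + cosTheta2 * box[3] - sinTheta2 * box[2];
  pts[1].x = box[0] + sinTheta2 * box[3] - cosTheta2 * box[2];
  pts[1].y = box[1] - cosTheta2 * box[3] - sinTheta2 * box[2];
  pts[2].x = 2.0f * box[0] - pts[0].x;
  pts[2].y = 2.0f * box[1] - pts[0].y;
  pts[3].x = 2.0f * box[0] - pts[1].x;
  pts[3].y = 2.0f * box[1] - pts[1].y;
}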
diff --git a/mmcv/ops/csrc/common/mlu/ms_deform_attn_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/ms_deform_attn_mlu_kernel.mlu deleted file mode 100644 index 40ad6396a6..0000000000 --- a/mmcv/ops/csrc/common/mlu/ms_deform_attn_mlu_kernel.mlu +++ /dev/null @@ -1,2094 +0,0 @@ -/************************************************************************* - * Copyright (C) 2022 by Cambricon. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS - * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - *************************************************************************/ - -#include "common_mlu_helper.hpp" -#include <math.h> - -/**************************************************************************************** - * - * NRAM partition forward: - * | spatial_shapes | data_value_p1_ping | data_value_p2_ping | - * | data_value_p3_ping | data_value_p4_ping | data_col_ping | - * | data_value_p1_pong | data_value_p2_pong | data_value_p3_pong | - * | data_value_p4_pong | data_col_pong | auxiliary_a | - * | auxiliary_b | - * | 128bytes | deal_size | deal_size | - * | deal_size | deal_size | deal_size | - * | deal_size | deal_size | deal_size | - * | deal_size | deal_size | deal_size | - * | deal_size | - * - ****************************************************************************************/ - -/**************************************************************************************** - * - * NRAM partition backward: - * default kernel - * | grad_output_nram | grad_output_nram_temp | grad_weight | - * | grad_h_weight | grad_w_weight | top_grad | - * | top_grad_temp | spatial_shapes_nram | sampling_loc_nram | - * | deal_size | deal_size | deal_size | - * | deal_size | deal_size | deal_size | - * | deal_size | deal_size | 64bytes | - * - * small channel kernel - * | nram_grad_output_tl | nram_grad_output_tr | nram_grad_output_bl | - * | nram_grad_output_br | grad_temp1 | grad_temp2 | - * | grad_temp3 | grad_temp4 | nram_loc_w | - * | nram_loc_h | nram_h_low | nram_w_low | - * | nram_h_high | nram_w_high | nram_h_low_temp | - * | nram_h_high_temp | nram_hw | nram_hh | - * | nram_lw | nram_lh | nram_h_low_ptr_offset | - * | nram_h_high_ptr_offset | nram_w_low_ptr_offset | nram_w_high_ptr_offset | - * | nram_w1 | nram_w2 | nram_w3 | - * | nram_w4 | nram_grad_weight | nram_base_ptr | - * | nram_offset_temp | nram_offset1 | nram_offset2 | - * | nram_offset3 | nram_offset4 | nram_w_low_temp | - * | nram_spatial_shapes | nram_level_start_index | nram_h_stride | - ****************************************************************************************/ - -#define TWELVE_SPLIT 12 -#define ALIGN_NUM 32 -#define ALIGN_NUM_FOR_REDUCE 32 -#define ELE_COUNT 32 -#define LEN_FLOAT sizeof(float) - -__nram__ char nram_buffer[MAX_NRAM_SIZE]; - -template <typename T> -__mlu_func__ void loadNeighborPointsData( - const T *data_value_gdram, T *data_value_p1_nram, T *data_value_p2_nram, - T *data_value_p3_nram, T *data_value_p4_nram, const size_t &deal_num, - const int32_t &width, const int32_t &height, const int32_t &num_heads, - const int32_t &channels, const T &x, const T &y, const int32_t &head_idx) { - const int32_t w_low = floorf(x); - const int32_t h_low = floorf(y); - const int32_t w_high = w_low + 1; - const int32_t h_high = h_low + 1; - - const int32_t w_stride = num_heads * channels; - const int32_t h_stride = width * w_stride; - const int32_t h_low_ptr_offset = h_low * h_stride; - const int32_t h_high_ptr_offset = h_low_ptr_offset + h_stride; - const int32_t w_low_ptr_offset = w_low * w_stride; - const int32_t w_high_ptr_offset = w_low_ptr_offset + w_stride; - const int32_t base_ptr_offset = head_idx * channels; - - // top-left point - if (h_low >= 0 && w_low >= 0) { - const int32_t v1_offset = - h_low_ptr_offset + w_low_ptr_offset + base_ptr_offset; - __memcpy_async(data_value_p1_nram, data_value_gdram + v1_offset, - deal_num * sizeof(T), GDRAM2NRAM); - } - - // top-right point - if (h_low >= 0 && w_high <= width - 1) { - const int32_t v2_offset = - h_low_ptr_offset + w_high_ptr_offset + base_ptr_offset; -
__memcpy_async(data_value_p2_nram, data_value_gdram + v2_offset, - deal_num * sizeof(T), GDRAM2NRAM); - } - - // bottom-left point - if (h_high <= height - 1 && w_low >= 0) { - const int32_t v3_offset = - h_high_ptr_offset + w_low_ptr_offset + base_ptr_offset; - __memcpy_async(data_value_p3_nram, data_value_gdram + v3_offset, - deal_num * sizeof(T), GDRAM2NRAM); - } - - // bottom-right point - if (h_high <= height - 1 && w_high <= width - 1) { - const int32_t v4_offset = - h_high_ptr_offset + w_high_ptr_offset + base_ptr_offset; - __memcpy_async(data_value_p4_nram, data_value_gdram + v4_offset, - deal_num * sizeof(T), GDRAM2NRAM); - } -} - -template <typename T> -__mlu_func__ void computeMsDeformAttn( - T *data_value_p1_nram, T *data_value_p2_nram, T *data_value_p3_nram, - T *data_value_p4_nram, T *sample_point_value, T *auxiliary_b, - T *data_col_nram, const T &weight, const size_t &deal_num, - const int32_t &width, const int32_t &height, const T &x, const T &y) { - const int32_t w_low = floorf(x); - const int32_t h_low = floorf(y); - const int32_t w_high = w_low + 1; - const int32_t h_high = h_low + 1; - - const T lw = x - w_low; - const T lh = y - h_low; - const T hw = 1 - lw; - const T hh = 1 - lh; - const T w1 = hh * hw; - const T w2 = hh * lw; - const T w3 = lh * hw; - const T w4 = lh * lw; - - __bang_write_value((T *)sample_point_value, deal_num, (T)0); - - // top-left point - if (h_low >= 0 && w_low >= 0) { - // sample_point_value += v1 * w1 - __bang_mul_scalar((T *)auxiliary_b, (T *)data_value_p1_nram, (T)w1, - deal_num); - __bang_add((T *)sample_point_value, (T *)sample_point_value, - (T *)auxiliary_b, deal_num); - } - - // top-right point - if (h_low >= 0 && w_high <= width - 1) { - // sample_point_value += v2 * w2 - __bang_mul_scalar((T *)auxiliary_b, (T *)data_value_p2_nram, (T)w2, - deal_num); - __bang_add((T *)sample_point_value, (T *)sample_point_value, - (T *)auxiliary_b, deal_num); - } - - // bottom-left point - if (h_high <= height - 1 && w_low >= 0) { - // sample_point_value += v3 * w3 - __bang_mul_scalar((T *)auxiliary_b, (T *)data_value_p3_nram, (T)w3, - deal_num); - __bang_add((T *)sample_point_value, (T *)sample_point_value, - (T *)auxiliary_b, deal_num); - } - - // bottom-right point - if (h_high <= height - 1 && w_high <= width - 1) { - // sample_point_value += v4 * w4 - __bang_mul_scalar((T *)auxiliary_b, (T *)data_value_p4_nram, (T)w4, - deal_num); - __bang_add((T *)sample_point_value, (T *)sample_point_value, - (T *)auxiliary_b, deal_num); - } - - __bang_mul_scalar((T *)sample_point_value, (T *)sample_point_value, (T)weight, - deal_num); - __bang_add((T *)data_col_nram, (T *)data_col_nram, (T *)sample_point_value, - deal_num); -} - -template <typename T> -__mlu_global__ void MLUKernelMsDeformAttnForwardDefault( - const char *data_value_gdram, const char *data_spatial_shapes_gdram, - const char *data_level_start_index_gdram, - const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram, - const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, - const int32_t channels, const int32_t num_levels, const int32_t num_queries, - const int32_t num_points, char *data_col_gdram) { - if (coreId == 0x80) { - return; - } - - const size_t spatial_size = PAD_UP(2 * sizeof(int32_t), NFU_ALIGN_SIZE); - const size_t span_num_deal = - PAD_DOWN((MAX_NRAM_SIZE - spatial_size) / TWELVE_SPLIT / sizeof(T), - NFU_ALIGN_SIZE); - const size_t align_num = NFU_ALIGN_SIZE; - const int32_t channels_seg_num = channels / span_num_deal; - const size_t channels_rem = channels %
span_num_deal; - const size_t channels_align_rem = CEIL_ALIGN(channels_rem, align_num); - char *data_spatial_shapes_nram = nram_buffer; - char *ping_data_value_p1_nram = data_spatial_shapes_nram + spatial_size; - char *ping_data_value_p2_nram = - ping_data_value_p1_nram + span_num_deal * sizeof(T); - char *ping_data_value_p3_nram = - ping_data_value_p2_nram + span_num_deal * sizeof(T); - char *ping_data_value_p4_nram = - ping_data_value_p3_nram + span_num_deal * sizeof(T); - char *ping_data_col_nram = - ping_data_value_p4_nram + span_num_deal * sizeof(T); - char *pong_data_value_p1_nram = - ping_data_col_nram + span_num_deal * sizeof(T); - char *pong_data_value_p2_nram = - pong_data_value_p1_nram + span_num_deal * sizeof(T); - char *pong_data_value_p3_nram = - pong_data_value_p2_nram + span_num_deal * sizeof(T); - char *pong_data_value_p4_nram = - pong_data_value_p3_nram + span_num_deal * sizeof(T); - char *pong_data_col_nram = - pong_data_value_p4_nram + span_num_deal * sizeof(T); - char *auxiliary_a = pong_data_col_nram + span_num_deal * sizeof(T); - char *auxiliary_b = auxiliary_a + span_num_deal * sizeof(T); - const size_t ping_pong_gap = 5 * span_num_deal * sizeof(T); - size_t data_col_ping_pong_idx = 0; - - int32_t block_num_per_core = (batch_size * num_queries * num_heads) / taskDim; - const int32_t block_num_rem = - (batch_size * num_queries * num_heads) % taskDim; - const int32_t idx_start = taskId < (block_num_rem + 1) - ? taskId * (block_num_per_core + 1) - : taskId * block_num_per_core + block_num_rem; - block_num_per_core = - taskId < block_num_rem - ? (batch_size * num_queries * num_heads) / taskDim + 1 - : (batch_size * num_queries * num_heads) / taskDim; - - for (int32_t cur_idx = idx_start; cur_idx < idx_start + block_num_per_core; - ++cur_idx) { - // cur_idx = batch_idx * num_queries * num_heads + query_idx * num_heads + - // head_idx - const int32_t head_idx = cur_idx % num_heads; - const int32_t batch_idx = (cur_idx / num_heads) / num_queries; - - const char *data_value_gdram_start = - data_value_gdram + - batch_idx * num_keys * num_heads * channels * sizeof(T); - const char *data_sampling_loc_gdram_start = - data_sampling_loc_gdram + - cur_idx * num_levels * num_points * 2 * sizeof(T); - const char *data_attn_weight_gdram_start = - data_attn_weight_gdram + cur_idx * num_levels * num_points * sizeof(T); - char *data_col_gdram_start = - data_col_gdram + cur_idx * channels * sizeof(T); - - for (int32_t c_seg_idx = 0; c_seg_idx < channels_seg_num; ++c_seg_idx) { - __bang_write_value( - (T *)(ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap), - span_num_deal, (T)0); - // load data - // level_idx = 0, point_idx = 0 - __memcpy(data_spatial_shapes_nram, data_spatial_shapes_gdram, - 2 * sizeof(int32_t), GDRAM2NRAM); - int32_t spatial_h = ((int32_t *)data_spatial_shapes_nram)[0]; - int32_t spatial_w = ((int32_t *)data_spatial_shapes_nram)[1]; - const char *data_value_ptr = - data_value_gdram_start + c_seg_idx * span_num_deal * sizeof(T); - T loc_w = ((T *)data_sampling_loc_gdram_start)[0]; - T loc_h = ((T *)data_sampling_loc_gdram_start)[1]; - T weight = ((T *)data_attn_weight_gdram_start)[0]; - T x = loc_w * spatial_w - 0.5; - T y = loc_h * spatial_h - 0.5; - if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) { - loadNeighborPointsData( - (T *)data_value_ptr, (T *)ping_data_value_p1_nram, - (T *)ping_data_value_p2_nram, (T *)ping_data_value_p3_nram, - (T *)ping_data_value_p4_nram, span_num_deal, spatial_w, spatial_h, - num_heads, channels, x, y, head_idx); - } - 
T spatial_h_next_point = 0; - T spatial_w_next_point = 0; - T weight_next_point = 0; - T x_next_point = 0; - T y_next_point = 0; - __asm__ volatile("sync;"); - - for (int32_t level_idx = 0; level_idx < num_levels; ++level_idx) { - for (int32_t point_idx = 0; point_idx < num_points; ++point_idx) { - // load data - if (point_idx == num_points - 1 && level_idx == num_levels - 1) { - // last point no need to load data, continue to compute - } else if (point_idx == num_points - 1) { - const int32_t level_start_id = - ((int32_t *)data_level_start_index_gdram)[level_idx + 1]; - const int32_t spatial_h_ptr = (level_idx + 1) << 1; - __memcpy( - data_spatial_shapes_nram, - data_spatial_shapes_gdram + spatial_h_ptr * sizeof(int32_t), - 2 * sizeof(int32_t), GDRAM2NRAM); - spatial_h_next_point = ((int32_t *)data_spatial_shapes_nram)[0]; - spatial_w_next_point = ((int32_t *)data_spatial_shapes_nram)[1]; - data_value_ptr = data_value_gdram_start + - (level_start_id * num_heads * channels + - c_seg_idx * span_num_deal) * - sizeof(T); - loc_w = ((T *)data_sampling_loc_gdram_start) - [(level_idx * num_points + point_idx + 1) * 2]; - loc_h = ((T *)data_sampling_loc_gdram_start) - [(level_idx * num_points + point_idx + 1) * 2 + 1]; - weight_next_point = - ((T *)data_attn_weight_gdram_start)[level_idx * num_points + - point_idx + 1]; - x_next_point = loc_w * spatial_w_next_point - 0.5; - y_next_point = loc_h * spatial_h_next_point - 0.5; - if (y_next_point > -1 && x_next_point > -1 && - y_next_point < spatial_h_next_point && - x_next_point < spatial_w_next_point) { - loadNeighborPointsData( - (T *)data_value_ptr, - (T *)(ping_data_value_p1_nram + - ((level_idx * num_points + point_idx + 1) % 2) * - ping_pong_gap), - (T *)(ping_data_value_p2_nram + - ((level_idx * num_points + point_idx + 1) % 2) * - ping_pong_gap), - (T *)(ping_data_value_p3_nram + - ((level_idx * num_points + point_idx + 1) % 2) * - ping_pong_gap), - (T *)(ping_data_value_p4_nram + - ((level_idx * num_points + point_idx + 1) % 2) * - ping_pong_gap), - span_num_deal, spatial_w_next_point, spatial_h_next_point, - num_heads, channels, x_next_point, y_next_point, head_idx); - } - } else { - spatial_h_next_point = spatial_h; - spatial_w_next_point = spatial_w; - loc_w = ((T *)data_sampling_loc_gdram_start) - [(level_idx * num_points + point_idx + 1) * 2]; - loc_h = ((T *)data_sampling_loc_gdram_start) - [(level_idx * num_points + point_idx + 1) * 2 + 1]; - weight_next_point = - ((T *)data_attn_weight_gdram_start)[level_idx * num_points + - point_idx + 1]; - x_next_point = loc_w * spatial_w - 0.5; - y_next_point = loc_h * spatial_h - 0.5; - if (y_next_point > -1 && x_next_point > -1 && - y_next_point < spatial_h && x_next_point < spatial_w) { - loadNeighborPointsData( - (T *)data_value_ptr, - (T *)(ping_data_value_p1_nram + - ((level_idx * num_points + point_idx + 1) % 2) * - ping_pong_gap), - (T *)(ping_data_value_p2_nram + - ((level_idx * num_points + point_idx + 1) % 2) * - ping_pong_gap), - (T *)(ping_data_value_p3_nram + - ((level_idx * num_points + point_idx + 1) % 2) * - ping_pong_gap), - (T *)(ping_data_value_p4_nram + - ((level_idx * num_points + point_idx + 1) % 2) * - ping_pong_gap), - span_num_deal, spatial_w, spatial_h, num_heads, channels, - x_next_point, y_next_point, head_idx); - } - } - - // compute - if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) { - computeMsDeformAttn( - (T *)(ping_data_value_p1_nram + - ((level_idx * num_points + point_idx) % 2) * - ping_pong_gap), - (T *)(ping_data_value_p2_nram + - ((level_idx * 
num_points + point_idx) % 2) * - ping_pong_gap), - (T *)(ping_data_value_p3_nram + - ((level_idx * num_points + point_idx) % 2) * - ping_pong_gap), - (T *)(ping_data_value_p4_nram + - ((level_idx * num_points + point_idx) % 2) * - ping_pong_gap), - (T *)auxiliary_a, (T *)auxiliary_b, - (T *)(ping_data_col_nram + - data_col_ping_pong_idx * ping_pong_gap), - weight, span_num_deal, spatial_w, spatial_h, x, y); - } - - spatial_w = spatial_w_next_point; - spatial_h = spatial_h_next_point; - weight = weight_next_point; - x = x_next_point; - y = y_next_point; - __asm__ volatile("sync;"); - } - } - // store - __memcpy_async( - data_col_gdram_start + c_seg_idx * span_num_deal * sizeof(T), - ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap, - span_num_deal * sizeof(T), NRAM2GDRAM); - data_col_ping_pong_idx = (data_col_ping_pong_idx + 1) % 2; - } - - if (channels_rem > 0) { - __bang_write_value( - (T *)(ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap), - channels_align_rem, (T)0); - // load data - // level_idx = 0, point_idx = 0 - __memcpy(data_spatial_shapes_nram, data_spatial_shapes_gdram, - 2 * sizeof(int32_t), GDRAM2NRAM); - int32_t spatial_h = ((int32_t *)data_spatial_shapes_nram)[0]; - int32_t spatial_w = ((int32_t *)data_spatial_shapes_nram)[1]; - const char *data_value_ptr = - data_value_gdram_start + channels_seg_num * span_num_deal * sizeof(T); - T loc_w = ((T *)data_sampling_loc_gdram_start)[0]; - T loc_h = ((T *)data_sampling_loc_gdram_start)[1]; - T weight = ((T *)data_attn_weight_gdram_start)[0]; - T x = loc_w * spatial_w - 0.5; - T y = loc_h * spatial_h - 0.5; - if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) { - loadNeighborPointsData( - (T *)data_value_ptr, (T *)ping_data_value_p1_nram, - (T *)ping_data_value_p2_nram, (T *)ping_data_value_p3_nram, - (T *)ping_data_value_p4_nram, channels_rem, spatial_w, spatial_h, - num_heads, channels, x, y, head_idx); - } - T spatial_h_next_point = 0; - T spatial_w_next_point = 0; - T weight_next_point = 0; - T x_next_point = 0; - T y_next_point = 0; - __asm__ volatile("sync;"); - - for (int32_t level_idx = 0; level_idx < num_levels; ++level_idx) { - for (int32_t point_idx = 0; point_idx < num_points; ++point_idx) { - // load data - if (point_idx == num_points - 1 && level_idx == num_levels - 1) { - // last point no need to load data, continue to compute - } else if (point_idx == num_points - 1) { - const int32_t level_start_id = - ((int32_t *)data_level_start_index_gdram)[level_idx + 1]; - const int32_t spatial_h_ptr = (level_idx + 1) << 1; - __memcpy( - data_spatial_shapes_nram, - data_spatial_shapes_gdram + spatial_h_ptr * sizeof(int32_t), - 2 * sizeof(int32_t), GDRAM2NRAM); - spatial_h_next_point = ((int32_t *)data_spatial_shapes_nram)[0]; - spatial_w_next_point = ((int32_t *)data_spatial_shapes_nram)[1]; - data_value_ptr = data_value_gdram_start + - (level_start_id * num_heads * channels + - channels_seg_num * span_num_deal) * - sizeof(T); - loc_w = ((T *)data_sampling_loc_gdram_start) - [(level_idx * num_points + point_idx + 1) * 2]; - loc_h = ((T *)data_sampling_loc_gdram_start) - [(level_idx * num_points + point_idx + 1) * 2 + 1]; - weight_next_point = - ((T *)data_attn_weight_gdram_start)[level_idx * num_points + - point_idx + 1]; - x_next_point = loc_w * spatial_w_next_point - 0.5; - y_next_point = loc_h * spatial_h_next_point - 0.5; - if (y_next_point > -1 && x_next_point > -1 && - y_next_point < spatial_h_next_point && - x_next_point < spatial_w_next_point) { - loadNeighborPointsData( - (T 
*)data_value_ptr, - (T *)(ping_data_value_p1_nram + - ((level_idx * num_points + point_idx + 1) % 2) * - ping_pong_gap), - (T *)(ping_data_value_p2_nram + - ((level_idx * num_points + point_idx + 1) % 2) * - ping_pong_gap), - (T *)(ping_data_value_p3_nram + - ((level_idx * num_points + point_idx + 1) % 2) * - ping_pong_gap), - (T *)(ping_data_value_p4_nram + - ((level_idx * num_points + point_idx + 1) % 2) * - ping_pong_gap), - channels_rem, spatial_w_next_point, spatial_h_next_point, - num_heads, channels, x_next_point, y_next_point, head_idx); - } - } else { - spatial_w_next_point = spatial_w; - spatial_h_next_point = spatial_h; - loc_w = ((T *)data_sampling_loc_gdram_start) - [(level_idx * num_points + point_idx + 1) * 2]; - loc_h = ((T *)data_sampling_loc_gdram_start) - [(level_idx * num_points + point_idx + 1) * 2 + 1]; - weight_next_point = - ((T *)data_attn_weight_gdram_start)[level_idx * num_points + - point_idx + 1]; - x_next_point = loc_w * spatial_w - 0.5; - y_next_point = loc_h * spatial_h - 0.5; - if (y_next_point > -1 && x_next_point > -1 && - y_next_point < spatial_h && x_next_point < spatial_w) { - loadNeighborPointsData( - (T *)data_value_ptr, - (T *)(ping_data_value_p1_nram + - ((level_idx * num_points + point_idx + 1) % 2) * - ping_pong_gap), - (T *)(ping_data_value_p2_nram + - ((level_idx * num_points + point_idx + 1) % 2) * - ping_pong_gap), - (T *)(ping_data_value_p3_nram + - ((level_idx * num_points + point_idx + 1) % 2) * - ping_pong_gap), - (T *)(ping_data_value_p4_nram + - ((level_idx * num_points + point_idx + 1) % 2) * - ping_pong_gap), - channels_rem, spatial_w, spatial_h, num_heads, channels, - x_next_point, y_next_point, head_idx); - } - } - - // compute - if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) { - computeMsDeformAttn( - (T *)(ping_data_value_p1_nram + - ((level_idx * num_points + point_idx) % 2) * - ping_pong_gap), - (T *)(ping_data_value_p2_nram + - ((level_idx * num_points + point_idx) % 2) * - ping_pong_gap), - (T *)(ping_data_value_p3_nram + - ((level_idx * num_points + point_idx) % 2) * - ping_pong_gap), - (T *)(ping_data_value_p4_nram + - ((level_idx * num_points + point_idx) % 2) * - ping_pong_gap), - (T *)auxiliary_a, (T *)auxiliary_b, - (T *)(ping_data_col_nram + - data_col_ping_pong_idx * ping_pong_gap), - weight, channels_align_rem, spatial_w, spatial_h, x, y); - } - - spatial_w = spatial_w_next_point; - spatial_h = spatial_h_next_point; - weight = weight_next_point; - x = x_next_point; - y = y_next_point; - __asm__ volatile("sync;"); - } - } - // store - __memcpy_async( - data_col_gdram_start + channels_seg_num * span_num_deal * sizeof(T), - ping_data_col_nram + data_col_ping_pong_idx * ping_pong_gap, - channels_rem * sizeof(T), NRAM2GDRAM); - data_col_ping_pong_idx = (data_col_ping_pong_idx + 1) % 2; - } - } - __asm__ volatile("sync;"); - return; -} - -__mlu_func__ void genMask0101(float *mask_ram, int32_t size) { - int32_t align_num = NFU_ALIGN_SIZE / sizeof(float); - for (int32_t i = 0; i < align_num; ++i) { - mask_ram[i] = i % 2; - } - __asm__ volatile("sync;"); - __memcpy(mask_ram + align_num, mask_ram, NFU_ALIGN_SIZE, NRAM2NRAM, - NFU_ALIGN_SIZE, 0, size / align_num - 2); - __asm__ volatile("sync;"); -} - -template <typename T> -__mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel( - const char *data_value_gdram, const char *data_spatial_shapes_gdram, - const char *data_level_start_index_gdram, - const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram, - const int32_t batch_size, const int32_t num_keys, const 
int32_t num_heads, - const int32_t channels, const int32_t num_levels, const int32_t num_queries, - const int32_t num_points, char *data_col_gdram) { -#if __BANG_ARCH__ >= 300 - if (coreId == 0x80) { - return; - } - - size_t block_num_per_core, batch_start, deal_g, offset_g; - size_t block_num_rem = 0; - const size_t grid_total = num_queries * num_heads * num_levels * num_points; - if (batch_size >= taskDim) { - block_num_rem = batch_size % taskDim; - block_num_per_core = taskId < block_num_rem ? batch_size / taskDim + 1 - : batch_size / taskDim; - batch_start = taskId < block_num_rem - ? taskId * block_num_per_core - : taskId * block_num_per_core + block_num_rem; - deal_g = grid_total; - offset_g = 0; - } else { - size_t skip_n = taskDim / batch_size; - batch_start = taskId / skip_n; - block_num_per_core = batch_start >= batch_size ? 0 : 1; - deal_g = PAD_UP(grid_total / skip_n, num_levels * num_points); - size_t id = taskId % skip_n; - offset_g = id * deal_g; - deal_g = id < (skip_n - 1) ? deal_g : grid_total - deal_g * (skip_n - 1); - } - - const int32_t float_align = NFU_ALIGN_SIZE / sizeof(float); - int32_t deal_num; - int32_t cut_channel_iter = 2; - - const size_t spatial_size = - PAD_UP(num_levels * 2 * sizeof(int32_t), NFU_ALIGN_SIZE); - const size_t level_start_index_size = - PAD_UP(num_levels * sizeof(int32_t), NFU_ALIGN_SIZE); - - int32_t channel = channels; - int32_t mult; - while (true) { - deal_num = (MAX_NRAM_SIZE - spatial_size - level_start_index_size) / - (8 * channel + 7) / sizeof(T); - deal_num = PAD_DOWN(deal_num, float_align); - deal_num = PAD_DOWN(deal_num, num_levels * num_points); - if (deal_num > 0) { - break; - } else { - channel = channels / cut_channel_iter; - cut_channel_iter += 2; - } - } - mult = channel; - - const int32_t c_rep = channels / channel; - const int32_t c_rem = channels % channel; - - const int32_t g_rep = deal_g / deal_num; - const int32_t g_rem = deal_g % deal_num; - - // nram buffer alloc - char *data_spatial_shapes_nram = nram_buffer; - char *data_level_start_index_nram = data_spatial_shapes_nram + spatial_size; - char *input_tl = data_level_start_index_nram + level_start_index_size; - char *input_tr = input_tl + deal_num * mult * sizeof(T); - char *input_bl = input_tr + deal_num * mult * sizeof(T); - char *input_br = input_bl + deal_num * mult * sizeof(T); - char *weight_tl = input_tl + 4 * deal_num * mult * sizeof(T); - char *weight_tr = weight_tl + deal_num * mult * sizeof(T); - char *weight_bl = weight_tr + deal_num * mult * sizeof(T); - char *weight_br = weight_bl + deal_num * mult * sizeof(T); - char *mask_tl = weight_br + deal_num * mult * sizeof(T); - char *mask_tr = mask_tl + deal_num * sizeof(T); - char *mask_bl = mask_tr + deal_num * sizeof(T); - char *mask_br = mask_bl + deal_num * sizeof(T); - char *point_ram = mask_br + deal_num * sizeof(T); - char *index_tl = point_ram + deal_num * sizeof(T); - char *index_bl = index_tl + deal_num * sizeof(T); - - // nram space reuse - char *grid_ram = weight_tl; - char *mask_ram = weight_bl; - char *coord_x = input_bl; - char *coord_y = coord_x + deal_num * sizeof(T); - char *coord_x_low = input_tl; - char *coord_y_low = coord_x_low + deal_num * sizeof(T); - char *coord_x_low_int = weight_tl; - char *coord_y_low_int = weight_tr; - char *spatial_x = mask_tl; - char *spatial_y = mask_tr; - char *spatial_x_float = weight_bl; - char *spatial_y_float = weight_br; - char *spatial_x_temp = mask_bl; - char *spatial_y_temp = mask_br; - char *base_ptr_offset = weight_tl; - char *auxiliary_a = 
point_ram; - char *auxiliary_b = weight_bl; - - __memcpy_async(data_spatial_shapes_nram, data_spatial_shapes_gdram, - num_levels * 2 * sizeof(int32_t), GDRAM2NRAM); - __memcpy_async(data_level_start_index_nram, data_level_start_index_gdram, - num_levels * sizeof(int32_t), GDRAM2NRAM); - __asm__ volatile("sync;"); - - for (int32_t batch_idx = batch_start; - batch_idx < batch_start + block_num_per_core; ++batch_idx) { - for (int32_t grid_iter = 0; grid_iter <= g_rep; ++grid_iter) { - int32_t io_data_num = deal_num; - const int32_t grid_off_base = - batch_idx * grid_total + offset_g + grid_iter * deal_num; - if (grid_iter == g_rep) { - if (g_rem == 0) { - continue; - } else { - io_data_num = g_rem; - } - } - - char *data_col_gdram_start = - data_col_gdram + (batch_idx * num_queries * num_heads * channels + - (offset_g + grid_iter * deal_num) / - (num_levels * num_points) * channels) * - sizeof(float); - - // load data_sampling_loc - __memcpy_async( - grid_ram, data_sampling_loc_gdram + grid_off_base * 2 * sizeof(float), - io_data_num * 2 * sizeof(float), GDRAM2NRAM); - genMask0101((float *)mask_ram, deal_num * 2); - __asm__ volatile("sync;"); - - // generate x and y coordinate vector - // generate spatial_x and spatial_y spatial vector - __bang_collect((float *)coord_y, (float *)grid_ram, (float *)mask_ram, - deal_num * 2); // y - __bang_collect((float *)spatial_x_temp, (float *)data_spatial_shapes_nram, - (float *)mask_ram, - num_levels * 2); // spatial_x - __bang_not((float *)mask_ram, (float *)mask_ram, deal_num * 2); - __bang_collect((float *)coord_x, (float *)grid_ram, (float *)mask_ram, - deal_num * 2); // x - __bang_collect((float *)spatial_y_temp, (float *)data_spatial_shapes_nram, - (float *)mask_ram, - num_levels * 2); // spatial_y - - for (int32_t i = 0; i < num_levels; i++) { - __bang_write_value((int32_t *)spatial_x + i * num_points, num_points, - ((int32_t *)spatial_x_temp)[i]); - __bang_write_value((int32_t *)spatial_y + i * num_points, num_points, - ((int32_t *)spatial_y_temp)[i]); - } - - __bang_int322float_rd((float *)spatial_x_float, (int32_t *)spatial_x, - num_levels * num_points, 0); - __bang_int322float_rd((float *)spatial_y_float, (int32_t *)spatial_y, - num_levels * num_points, 0); - - // map x from [0, 1] to [0, spatial_x]; map y from [0, 1] to [0, - // spatial_y] - __bang_cycle_mul((float *)coord_x, (float *)coord_x, - (float *)spatial_x_float, deal_num, - num_levels * num_points); - __bang_sub_scalar((float *)coord_x, (float *)coord_x, (float)0.5, - deal_num); - __bang_cycle_mul((float *)coord_y, (float *)coord_y, - (float *)spatial_y_float, deal_num, - num_levels * num_points); - __bang_sub_scalar((float *)coord_y, (float *)coord_y, (float)0.5, - deal_num); - - __bang_floor((float *)coord_x_low, (float *)coord_x, deal_num); - __bang_floor((float *)coord_y_low, (float *)coord_y, deal_num); - - // calc index_tl - const int32_t w_stride = num_heads * channels; - __bang_float2int32_rd((int32_t *)coord_x_low_int, (float *)coord_x_low, - deal_num, 0); - __bang_float2int32_rd((int32_t *)coord_y_low_int, (float *)coord_y_low, - deal_num, 0); - __bang_cycle_mul((int32_t *)index_tl, (int32_t *)coord_y_low_int, - (int32_t *)spatial_x, deal_num, num_levels * num_points); - __bang_add((int32_t *)index_tl, (int32_t *)index_tl, - (int32_t *)coord_x_low_int, deal_num); - __bang_mul_scalar((int32_t *)index_tl, (int32_t *)index_tl, w_stride, - deal_num); - - const int32_t deal_lp_num = deal_num / (num_levels * num_points); - const int32_t h_rep = deal_lp_num / num_heads; - const 
int32_t h_rem = deal_lp_num % num_heads; - const int32_t head_start = - ((offset_g + grid_iter * deal_num) / (num_levels * num_points)) % - num_heads; - for (int32_t iter = 0; iter < num_heads; ++iter) { - ((int32_t *)base_ptr_offset)[iter] = - ((head_start + iter) % num_heads) * channels; - } - if (h_rep > 0) { - __memcpy((int32_t *)base_ptr_offset + num_heads, - (int32_t *)base_ptr_offset, num_heads * sizeof(int32_t), - NRAM2NRAM, num_heads * sizeof(int32_t), 0, h_rep - 1); - } - if (h_rep > 0 && h_rem > 0) { - __memcpy((int32_t *)base_ptr_offset + h_rep * num_heads, - (int32_t *)base_ptr_offset, h_rem * sizeof(int32_t), - NRAM2NRAM); - } - __bang_transpose((int32_t *)auxiliary_a, (int32_t *)index_tl, deal_lp_num, - num_levels * num_points); - __bang_cycle_add((int32_t *)auxiliary_a, (int32_t *)auxiliary_a, - (int32_t *)base_ptr_offset, deal_num, deal_lp_num); - __bang_transpose((int32_t *)index_tl, (int32_t *)auxiliary_a, - num_levels * num_points, deal_lp_num); - - // calc index_bl - __bang_mul_scalar((int32_t *)auxiliary_a, (int32_t *)spatial_x, w_stride, - deal_num); - __bang_cycle_add((int32_t *)index_bl, (int32_t *)index_tl, - (int32_t *)auxiliary_a, deal_num, - num_levels * num_points); - - // calc mask_tl, mask_tr, mask_bl, mask_br - __bang_sub_scalar((float *)spatial_x_float, (float *)spatial_x_float, - (float)1.0, deal_num); - __bang_sub_scalar((float *)spatial_y_float, (float *)spatial_y_float, - (float)1.0, deal_num); - // mask_tl : 0 <= coord_x_low < spatial_x && 0 <= coord_y_low < spatial_y - __bang_ge_scalar((float *)mask_bl, (float *)coord_x_low, (float)0, - deal_num); - __bang_cycle_le((float *)mask_br, (float *)coord_x_low, - (float *)spatial_x_float, deal_num, - num_levels * num_points); - __bang_and((float *)mask_bl, (float *)mask_bl, (float *)mask_br, - deal_num); - - __bang_ge_scalar((float *)mask_tr, (float *)coord_y_low, (float)0, - deal_num); - __bang_cycle_le((float *)mask_br, (float *)coord_y_low, - (float *)spatial_y_float, deal_num, - num_levels * num_points); - __bang_and((float *)mask_tr, (float *)mask_tr, (float *)mask_br, - deal_num); - __bang_and((float *)mask_tl, (float *)mask_tr, (float *)mask_bl, - deal_num); - - // mask_tr : 0 <= coord_x_high < spatial_x && 0 <= coord_y_low < spatial_y - __bang_ge_scalar((float *)mask_br, (float *)coord_x_low, (float)(-1.0), - deal_num); - __bang_cycle_lt((float *)auxiliary_a, (float *)coord_x_low, - (float *)spatial_x_float, deal_num, - num_levels * num_points); - __bang_and((float *)mask_br, (float *)mask_br, (float *)auxiliary_a, - deal_num); - __bang_and((float *)mask_tr, (float *)mask_tr, (float *)mask_br, - deal_num); - - // mask_bl : 0 <= coord_x_low < spatial_x && 0 <= coord_y_high < spatial_y - __bang_ge_scalar((float *)auxiliary_a, (float *)coord_y_low, - (float)(-1.0), deal_num); - __bang_cycle_lt((float *)auxiliary_b, (float *)coord_y_low, - (float *)spatial_y_float, deal_num, - num_levels * num_points); - __bang_and((float *)auxiliary_a, (float *)auxiliary_a, - (float *)auxiliary_b, deal_num); - __bang_and((float *)mask_bl, (float *)mask_bl, (float *)auxiliary_a, - deal_num); - - // mask_br : 0 <= coord_x_high < spatial_x && 0 <= coord_y_high < - // spatial_y - __bang_and((float *)mask_br, (float *)mask_br, (float *)auxiliary_a, - deal_num); - - // calc inner point num - __bang_mul_scalar((float *)weight_tl, (float *)mask_tl, (float)7.0, - deal_num); - __bang_mul_scalar((float *)weight_tr, (float *)mask_tr, (float)5.0, - deal_num); - __bang_add((float *)weight_tl, (float *)weight_tl, (float *)weight_tr, 
- deal_num); - __bang_mul_scalar((float *)weight_tr, (float *)mask_bl, (float)3.0, - deal_num); - __bang_add((float *)point_ram, (float *)weight_tr, (float *)mask_br, - deal_num); - __bang_add((float *)point_ram, (float *)point_ram, (float *)weight_tl, - deal_num); - - // calc interpolation weight - __bang_sub((float *)weight_bl, (float *)coord_x_low, (float *)coord_x, - deal_num); - __bang_sub((float *)weight_br, (float *)coord_y_low, (float *)coord_y, - deal_num); - __bang_add_scalar((float *)weight_bl, (float *)weight_bl, (float)1.0, - deal_num); - __bang_add_scalar((float *)weight_br, (float *)weight_br, (float)1.0, - deal_num); - - __bang_sub((float *)weight_tl, (float *)coord_x, (float *)coord_x_low, - deal_num); - __bang_sub((float *)weight_tr, (float *)coord_y, (float *)coord_y_low, - deal_num); - __bang_mul((float *)input_tl, (float *)weight_bl, (float *)weight_br, - deal_num); - __bang_mul((float *)input_tl + deal_num, (float *)weight_br, - (float *)weight_tl, deal_num); - __bang_mul((float *)input_tl + 2 * deal_num, (float *)weight_bl, - (float *)weight_tr, deal_num); - __bang_mul((float *)input_tl + 3 * deal_num, (float *)weight_tl, - (float *)weight_tr, deal_num); - - __asm__ volatile("sync;"); - - // extend weight - const int32_t w_rep = channel / ELE_COUNT * ELE_COUNT; - const int32_t w_rem = channel % ELE_COUNT; - if (w_rem != 0) { - const int32_t data_sz = 1 * sizeof(float); - const int32_t dst_str = channel * sizeof(float); - for (int32_t iter = w_rep; iter < channel; ++iter) { - __memcpy_async((float *)weight_tl + iter, (float *)input_tl, data_sz, - NRAM2NRAM, dst_str, data_sz, 4 * deal_num - 1); - } - } - if (w_rep != 0) { - for (int32_t i = 0; i < 4 * deal_num; i++) { - __bang_write_value((float *)weight_tl + i * channel, w_rep, - ((float *)input_tl)[i]); - } - } - - __asm__ volatile("sync;"); - - const char *data_value_gdram_start = - data_value_gdram + - batch_idx * num_keys * num_heads * channels * sizeof(float); - const int32_t c_str = deal_num * channel * sizeof(float); - const int32_t cs_str = num_heads * channels * sizeof(float); - - for (int32_t c_iter = 0; c_iter <= c_rep; ++c_iter) { - int32_t c_real_num = channel; - if (c_iter == c_rep) { - if (c_rem == 0) { - continue; - } else { - c_real_num = c_rem; - } - } - - __bang_write_zero((float *)input_tl, 4 * deal_num * channel); - __asm__ volatile("sync;"); - - // load data_value - for (int32_t p_idx = 0; p_idx < io_data_num; ++p_idx) { - const int32_t inner_point_num = (int32_t)((float *)point_ram)[p_idx]; - const int32_t tl_offset = ((int32_t *)index_tl)[p_idx]; - const int32_t bl_offset = ((int32_t *)index_bl)[p_idx]; - const int32_t level_start_id = - ((int32_t *)data_level_start_index_nram)[(p_idx / num_points) % - num_levels]; - const char *data_value_ptr = - data_value_gdram_start + - (level_start_id * num_heads * channels + c_iter * channel) * - sizeof(float); - - switch (inner_point_num) { - case 16: // 4 points are cached. - __memcpy_async((float *)input_tl + p_idx * channel, - (float *)data_value_ptr + tl_offset, - c_real_num * sizeof(float), GDRAM2NRAM, c_str, - cs_str, 1); - __memcpy_async((float *)input_bl + p_idx * channel, - (float *)data_value_ptr + bl_offset, - c_real_num * sizeof(float), GDRAM2NRAM, c_str, - cs_str, 1); - break; - case 12: // 2 points are cached. (top_left, top_right) - __memcpy_async((float *)input_tl + p_idx * channel, - (float *)data_value_ptr + tl_offset, - c_real_num * sizeof(float), GDRAM2NRAM, c_str, - cs_str, 1); - break; - case 4: // 2 points are cached. 
(bottom_left, bottom_right) - __memcpy_async((float *)input_bl + p_idx * channel, - (float *)data_value_ptr + bl_offset, - c_real_num * sizeof(float), GDRAM2NRAM, c_str, - cs_str, 1); - break; - case 10: // 2 points are cached. (top_left, bottom_left) - __memcpy_async((float *)input_tl + p_idx * channel, - (float *)data_value_ptr + tl_offset, - c_real_num * sizeof(float), GDRAM2NRAM); - __memcpy_async((float *)input_bl + p_idx * channel, - (float *)data_value_ptr + bl_offset, - c_real_num * sizeof(float), GDRAM2NRAM); - break; - case 6: // 2 points are cached. (top_right, bottom_right) - __memcpy_async( - (float *)input_tr + p_idx * channel, - (float *)data_value_ptr + tl_offset + num_heads * channels, - c_real_num * sizeof(float), GDRAM2NRAM); - __memcpy_async( - (float *)input_br + p_idx * channel, - (float *)data_value_ptr + bl_offset + num_heads * channels, - c_real_num * sizeof(float), GDRAM2NRAM); - break; - case 7: // 1 point is cached. (top_left) - __memcpy_async((float *)input_tl + p_idx * channel, - (float *)data_value_ptr + tl_offset, - c_real_num * sizeof(float), GDRAM2NRAM); - break; - case 5: // 1 point is cached. (top_right) - __memcpy_async( - (float *)input_tr + p_idx * channel, - (float *)data_value_ptr + tl_offset + num_heads * channels, - c_real_num * sizeof(float), GDRAM2NRAM); - break; - case 3: // 1 point is cached. (bottom_left) - __memcpy_async((float *)input_bl + p_idx * channel, - (float *)data_value_ptr + bl_offset, - c_real_num * sizeof(float), GDRAM2NRAM); - break; - case 1: // 1 point is cached. (bottom_right) - __memcpy_async( - (float *)input_br + p_idx * channel, - (float *)data_value_ptr + bl_offset + num_heads * channels, - c_real_num * sizeof(float), GDRAM2NRAM); - break; - default: - continue; - } - } - - __asm__ volatile("sync;"); - - // interpolation - __bang_mul((float *)input_tl, (float *)input_tl, (float *)weight_tl, - 4 * deal_num * channel); - __bang_add((float *)input_tl, (float *)input_tl, (float *)input_bl, - 2 * deal_num * channel); - __bang_add((float *)input_tl, (float *)input_tl, (float *)input_tr, - deal_num * channel); - - // load attention weight - void *attn_weight = mask_tl; - __memcpy((float *)attn_weight, - (float *)data_attn_weight_gdram + grid_off_base, - io_data_num * sizeof(float), GDRAM2NRAM); - - // calc data_col, muladd attention weight - __bang_transpose((float *)input_tr, (float *)input_tl, deal_num, - channel); - __bang_cycle_mul((float *)input_tr, (float *)input_tr, - (float *)attn_weight, deal_num * channel, deal_num); - __bang_transpose((float *)input_tl, (float *)input_tr, channel, - deal_num); - __bang_sumpool((float *)input_bl, (float *)input_tl, channel, 1, - io_data_num, 1, num_levels * num_points, - num_levels * num_points, 1); - - // store - __memcpy((float *)data_col_gdram_start + c_iter * channel, - (float *)input_bl, c_real_num * sizeof(float), NRAM2GDRAM, - channels * sizeof(float), channel * sizeof(float), - (io_data_num / (num_levels * num_points)) - 1); - } - } - } - __asm__ volatile("sync;"); -#endif - return; -} - -template __mlu_global__ void MLUKernelMsDeformAttnForwardDefault<float>( - const char *data_value_gdram, const char *data_spatial_shapes_gdram, - const char *data_level_start_index_gdram, - const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram, - const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, - const int32_t channels, const int32_t num_levels, const int32_t num_queries, - const int32_t num_points, char *data_col_gdram); - 
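Both forward kernels above compute the same reduction before the host-side wrappers below launch them: per (batch, query, head) and output channel c, data_col[c] is the sum over every (level, point) of the attention weight times a bilinear sample of that level's value map. A minimal scalar C++ sketch of one sample (bilinearSampleRef is an illustrative name; out-of-bounds corners are taken as zero, matching the zero-initialized NRAM buffers above):

// v_tl..v_br are the four corner values for one channel, in the same
// tl/tr/bl/br order used above; lh/lw are the fractional offsets.
inline float bilinearSampleRef(float v_tl, float v_tr, float v_bl, float v_br,
                               float lh, float lw) {
  const float hh = 1.0f - lh, hw = 1.0f - lw;
  return hh * hw * v_tl + hh * lw * v_tr + lh * hw * v_bl + lh * lw * v_br;
}
// data_col[c] += attn_weight[level][point] * bilinearSampleRef(...);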
-void KernelMsDeformAttnForwardDefault( - cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, - const cnrtDataType_t d_type, const char *data_value_gdram, - const char *data_spatial_shapes_gdram, - const char *data_level_start_index_gdram, - const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram, - const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, - const int32_t channels, const int32_t num_levels, const int32_t num_queries, - const int32_t num_points, char *data_col_gdram) { - MLUKernelMsDeformAttnForwardDefault<float><<<k_dim, k_type, queue>>>( - data_value_gdram, data_spatial_shapes_gdram, data_level_start_index_gdram, - data_sampling_loc_gdram, data_attn_weight_gdram, batch_size, num_keys, - num_heads, channels, num_levels, num_queries, num_points, data_col_gdram); -} - -template __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel<float>( - const char *data_value_gdram, const char *data_spatial_shapes_gdram, - const char *data_level_start_index_gdram, - const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram, - const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, - const int32_t channels, const int32_t num_levels, const int32_t num_queries, - const int32_t num_points, char *data_col_gdram); - -void KernelMsDeformAttnForwardSmallChannel( - cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, - const cnrtDataType_t d_type, const char *data_value_gdram, - const char *data_spatial_shapes_gdram, - const char *data_level_start_index_gdram, - const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram, - const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, - const int32_t channels, const int32_t num_levels, const int32_t num_queries, - const int32_t num_points, char *data_col_gdram) { - MLUKernelMsDeformAttnForwardSmallChannel<float><<<k_dim, k_type, queue>>>( - data_value_gdram, data_spatial_shapes_gdram, data_level_start_index_gdram, - data_sampling_loc_gdram, data_attn_weight_gdram, batch_size, num_keys, - num_heads, channels, num_levels, num_queries, num_points, data_col_gdram); -} - -template <typename T> -void __mlu_func__ msDeformAttnCol2imBilinear( - T *top_grad_temp, const int32_t &height, const int32_t &width, const T &w1, - const T &w2, const T &w3, const T &w4, const int32_t &h_low, - const int32_t &w_low, const int32_t &h_high, const int32_t &w_high, - const int32_t &base_ptr, const int32_t &h_low_ptr_offset, - const int32_t &w_low_ptr_offset, const int32_t &h_high_ptr_offset, - const int32_t &w_high_ptr_offset, const T &hh, const T &hw, const T &lh, - const T &lw, T *top_grad, const T &data_attn_weight, T *grad_h_weight, - T *grad_w_weight, T *grad_value, T *grad_output_nram, T *grad_weight, - T *grad_sampling_loc, T *grad_attn_weight, T *grad_output_nram_temp, - const int32_t &deal_num, const int32_t &deal_num_real, - const T *data_value_ptr) { - if (h_low >= 0 && w_low >= 0) { - int32_t offset1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; - __memcpy(grad_output_nram, data_value_ptr + offset1, - deal_num_real * sizeof(T), GDRAM2NRAM); - __bang_mul_scalar(grad_weight, grad_output_nram, hw, deal_num_real); - __bang_sub(grad_h_weight, grad_h_weight, grad_weight, deal_num_real); - __bang_mul_scalar(grad_weight, grad_output_nram, hh, deal_num_real); - __bang_sub(grad_w_weight, grad_w_weight, grad_weight, deal_num_real); - - __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real); - __bang_mul_scalar(top_grad_temp, top_grad_temp, w1, deal_num_real);
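The four guarded corner blocks in this function follow the standard bilinear backward rule. In scalar form, for one channel with incoming gradient g, attention weight a, and corner values v1..v4 in the tl/tr/bl/br order used here, a minimal C++ sketch (illustrative names; the kernel vectorizes this over deal_num_real channels):

struct LocGradsRef {
  float d_attn;  // contribution to grad_attn_weight
  float d_x;     // contribution to grad_sampling_loc (x)
  float d_y;     // contribution to grad_sampling_loc (y)
};

inline LocGradsRef bilinearBackwardRef(float v1, float v2, float v3, float v4,
                                       float lh, float lw, float hh, float hw,
                                       float g, float a, int height,
                                       int width) {
  const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
  // grad_value[corner i] += w_i * a * g   (the atomic adds below)
  const float interp = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
  const float d_lw = -hh * v1 + hh * v2 - lh * v3 + lh * v4;  // d interp / d lw
  const float d_lh = -hw * v1 - lw * v2 + hw * v3 + lw * v4;  // d interp / d lh
  return {interp * g,            // reduced over channels below
          width * d_lw * a * g,  // matches the width scaling of grad_w_weight
          height * d_lh * a * g};
}
- // for calc grad_attn_weight - 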
__bang_mul_scalar(grad_output_nram, grad_output_nram, w1, deal_num_real); - __bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset1), - (T *)top_grad_temp, deal_num_real); - } - if (h_low >= 0 && w_high <= width - 1) { - int32_t offset2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; - __memcpy(grad_output_nram_temp, data_value_ptr + offset2, - deal_num_real * sizeof(T), GDRAM2NRAM); - __bang_mul_scalar(grad_weight, grad_output_nram_temp, lw, deal_num_real); - __bang_sub(grad_h_weight, grad_h_weight, grad_weight, deal_num_real); - __bang_mul_scalar(grad_weight, grad_output_nram_temp, hh, deal_num_real); - __bang_add(grad_w_weight, grad_w_weight, grad_weight, deal_num_real); - - __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real); - __bang_mul_scalar(top_grad_temp, top_grad_temp, w2, deal_num_real); - - __bang_mul_scalar(grad_output_nram_temp, grad_output_nram_temp, w2, - deal_num_real); - __bang_add(grad_output_nram, grad_output_nram, grad_output_nram_temp, - deal_num_real); - __bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset2), - (T *)top_grad_temp, deal_num_real); - } - if (h_high <= height - 1 && w_low >= 0) { - int32_t offset3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; - __memcpy(grad_output_nram_temp, data_value_ptr + offset3, - deal_num_real * sizeof(T), GDRAM2NRAM); - __bang_mul_scalar(grad_weight, grad_output_nram_temp, hw, deal_num_real); - __bang_add(grad_h_weight, grad_h_weight, grad_weight, deal_num_real); - __bang_mul_scalar(grad_weight, grad_output_nram_temp, lh, deal_num_real); - __bang_sub(grad_w_weight, grad_w_weight, grad_weight, deal_num_real); - - __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real); - __bang_mul_scalar(top_grad_temp, top_grad_temp, w3, deal_num_real); - // for calc grad_attn_weight - __bang_mul_scalar(grad_output_nram_temp, grad_output_nram_temp, w3, - deal_num_real); - __bang_add(grad_output_nram, grad_output_nram, grad_output_nram_temp, - deal_num_real); - __bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset3), - (T *)top_grad_temp, deal_num_real); - } - if (h_high <= height - 1 && w_high <= width - 1) { - int32_t offset4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; - __memcpy(grad_output_nram_temp, data_value_ptr + offset4, - deal_num_real * sizeof(T), GDRAM2NRAM); - __bang_mul_scalar(grad_weight, grad_output_nram_temp, lw, deal_num_real); - __bang_add(grad_h_weight, grad_h_weight, grad_weight, deal_num_real); - __bang_mul_scalar(grad_weight, grad_output_nram_temp, lh, deal_num_real); - __bang_add(grad_w_weight, grad_w_weight, grad_weight, deal_num_real); - - __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real); - __bang_mul_scalar(top_grad_temp, top_grad_temp, w4, deal_num_real); - // for calc grad_attn_weight - __bang_mul_scalar(grad_output_nram_temp, grad_output_nram_temp, w4, - deal_num_real); - __bang_add(grad_output_nram, grad_output_nram, grad_output_nram_temp, - deal_num_real); - - __bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset4), - (T *)top_grad_temp, deal_num_real); - } - __bang_mul(grad_output_nram, grad_output_nram, top_grad, deal_num_real); -#if __BANG_ARCH__ >= 322 - recursiveSumPool(grad_output_nram, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE); -#else - const int32_t align_num_on_200 = NFU_ALIGN_SIZE / LEN_FLOAT; - recursiveSumPool(grad_output_nram, align_num_on_200, - deal_num / align_num_on_200, ALIGN_NUM_FOR_REDUCE); - __bang_reduce_sum(grad_output_nram, grad_output_nram, - NFU_ALIGN_SIZE / 
LEN_FLOAT); -#endif - __bang_atomic_add((T *)grad_output_nram, (T *)grad_attn_weight, - (T *)grad_output_nram, 1); - __bang_mul_scalar(grad_w_weight, grad_w_weight, width, deal_num_real); - __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real); - __bang_mul(grad_w_weight, grad_w_weight, top_grad_temp, deal_num_real); -#if __BANG_ARCH__ >= 322 - recursiveSumPool(grad_w_weight, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE); -#else - recursiveSumPool(grad_w_weight, align_num_on_200, deal_num / align_num_on_200, - ALIGN_NUM_FOR_REDUCE); - __bang_reduce_sum(grad_w_weight, grad_w_weight, NFU_ALIGN_SIZE / LEN_FLOAT); -#endif - __bang_atomic_add((T *)grad_w_weight, (T *)(grad_sampling_loc), - (T *)grad_w_weight, 1); - - __bang_mul_scalar(grad_h_weight, grad_h_weight, height, deal_num_real); - __bang_mul(grad_h_weight, grad_h_weight, top_grad_temp, deal_num_real); -#if __BANG_ARCH__ >= 322 - recursiveSumPool(grad_h_weight, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE); -#else - recursiveSumPool(grad_h_weight, align_num_on_200, deal_num / align_num_on_200, - ALIGN_NUM_FOR_REDUCE); - __bang_reduce_sum(grad_h_weight, grad_h_weight, NFU_ALIGN_SIZE / LEN_FLOAT); -#endif - __bang_atomic_add((T *)grad_h_weight, (T *)(grad_sampling_loc + 1), - (T *)grad_h_weight, 1); -} - -__mlu_global__ void MLUUnion1KernelMsDeformAttnBackwarDefaultKernel( - const float *data_value, const int32_t *spatial_shapes, - const int32_t *data_level_start_index, const float *data_sampling_loc, - const float *data_attn_weight, const float *grad_output, - const int32_t batch, const int32_t spatial_size, const int32_t num_heads, - const int32_t channels, const int32_t num_levels, const int32_t num_query, - const int32_t num_points, float *grad_value, float *grad_sampling_loc, - float *grad_attn_weight) { - if (coreId == 0x80) { - return; - } - const int32_t split_num = 8; - const int32_t spatial_shapes_size = 64; - int32_t deal_num = PAD_DOWN( - (MAX_NRAM_SIZE - spatial_shapes_size) / split_num / LEN_FLOAT, ALIGN_NUM); - float *grad_output_nram = (float *)nram_buffer; - float *grad_output_nram_temp = (float *)nram_buffer + deal_num; - float *grad_weight = (float *)nram_buffer + 2 * deal_num; - float *grad_h_weight = (float *)nram_buffer + 3 * deal_num; - float *grad_w_weight = (float *)nram_buffer + 4 * deal_num; - float *top_grad = (float *)nram_buffer + 5 * deal_num; - float *top_grad_temp = (float *)nram_buffer + 6 * deal_num; - int32_t *spatial_shapes_nram = - (int32_t *)((float *)nram_buffer + 7 * deal_num); - float *sampling_loc_nram = - (float *)nram_buffer + 7 * deal_num + 2 * sizeof(int32_t); - const int32_t total_num = batch * num_query * num_heads * num_levels; - int32_t num_per_core = total_num / taskDim; - int32_t num_rem = total_num % taskDim; - num_per_core = num_per_core + int32_t(taskId < num_rem); - int32_t start_per_core = num_rem > taskId ? 
(taskId * num_per_core) - : (num_rem + taskId * num_per_core); - int32_t end_per_core = start_per_core + num_per_core; - const int32_t C_repeat = channels / deal_num; - const int32_t C_tail = channels % deal_num; - const int32_t qid_stride = num_heads * channels; - int32_t base_ptr = 0; - for (int32_t num_loop = start_per_core; num_loop < end_per_core; ++num_loop) { - const int32_t l_col = num_loop % num_levels; - const int32_t m_col = num_loop / num_levels % num_heads; - const int32_t q_col = num_loop / num_levels / num_heads % num_query; - const int32_t b_col = num_loop / num_query / num_heads / num_levels; - int32_t data_weight_ptr = num_loop * num_points; - int32_t data_loc_w_ptr = data_weight_ptr << 1; - const int32_t value_offset = b_col * spatial_size * num_heads * channels; - const int32_t level_start_id = data_level_start_index[l_col]; - int32_t spatial_h_ptr = l_col << 1; - int32_t grad_output_offset = b_col * num_query * num_heads * channels + - q_col * num_heads * channels + - m_col * channels; - __memcpy(spatial_shapes_nram, spatial_shapes + spatial_h_ptr, - 2 * sizeof(int32_t), GDRAM2NRAM); - const int32_t spatial_h = spatial_shapes_nram[0]; - const int32_t spatial_w = spatial_shapes_nram[1]; - const int32_t value_ptr_offset = value_offset + level_start_id * qid_stride; - const float *data_value_ptr = data_value + value_ptr_offset; - float *grad_value_ptr = grad_value + value_ptr_offset; - const int32_t grad_attn_weight_out = num_loop * num_points; - const int32_t grad_sampling_loc_out = num_loop * num_points * 2; - for (int32_t p_col = 0; p_col < num_points; ++p_col) { - __memcpy(sampling_loc_nram, data_sampling_loc + data_loc_w_ptr, - 2 * LEN_FLOAT, GDRAM2NRAM); - const float loc_w = sampling_loc_nram[0]; - const float loc_h = sampling_loc_nram[1]; - const float weight = data_attn_weight[data_weight_ptr]; - const float h_im = loc_h * spatial_h - 0.5; - const float w_im = loc_w * spatial_w - 0.5; - if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) { - const int32_t h_low = floorf(h_im); - const int32_t w_low = floorf(w_im); - const int32_t h_high = h_low + 1; - const int32_t w_high = w_low + 1; - - const float lh = h_im - h_low; - const float lw = w_im - w_low; - const float hh = 1.0 - lh; - const float hw = 1.0 - lw; - - const int32_t w_stride = num_heads * channels; - const int32_t h_stride = spatial_w * w_stride; - const int32_t h_low_ptr_offset = h_low * h_stride; - const int32_t h_high_ptr_offset = h_low_ptr_offset + h_stride; - const int32_t w_low_ptr_offset = w_low * w_stride; - const int32_t w_high_ptr_offset = w_low_ptr_offset + w_stride; - - float w1 = hh * hw; - float w2 = hh * lw; - float w3 = lh * hw; - float w4 = lh * lw; - - for (int32_t C_loop = 0; C_loop < C_repeat; ++C_loop) { - base_ptr = m_col * channels + C_loop * deal_num; - __bang_write_zero(grad_h_weight, PAD_UP(channels, ALIGN_NUM)); - __bang_write_zero(grad_w_weight, PAD_UP(channels, ALIGN_NUM)); - __bang_write_zero(grad_output_nram, PAD_UP(channels, ALIGN_NUM)); - __memcpy(top_grad, - grad_output + grad_output_offset + C_loop * deal_num, - deal_num * LEN_FLOAT, GDRAM2NRAM); - msDeformAttnCol2imBilinear( - top_grad_temp, spatial_h, spatial_w, w1, w2, w3, w4, h_low, w_low, - h_high, w_high, base_ptr, h_low_ptr_offset, w_low_ptr_offset, - h_high_ptr_offset, w_high_ptr_offset, hh, hw, lh, lw, top_grad, - weight, grad_h_weight, grad_w_weight, grad_value_ptr, - grad_output_nram, grad_weight, - grad_sampling_loc + grad_sampling_loc_out + p_col * 2, - grad_attn_weight + 
grad_attn_weight_out + p_col, - grad_output_nram_temp, deal_num, deal_num, data_value_ptr); - } - if (C_tail != 0) { - base_ptr = m_col * channels + C_repeat * deal_num; - __bang_write_zero(grad_h_weight, PAD_UP(channels, ALIGN_NUM)); - __bang_write_zero(grad_w_weight, PAD_UP(channels, ALIGN_NUM)); - __bang_write_zero(grad_output_nram, PAD_UP(channels, ALIGN_NUM)); - __memcpy(top_grad, - grad_output + grad_output_offset + C_repeat * deal_num, - C_tail * LEN_FLOAT, GDRAM2NRAM); - msDeformAttnCol2imBilinear( - top_grad_temp, spatial_h, spatial_w, w1, w2, w3, w4, h_low, w_low, - h_high, w_high, base_ptr, h_low_ptr_offset, w_low_ptr_offset, - h_high_ptr_offset, w_high_ptr_offset, hh, hw, lh, lw, top_grad, - weight, grad_h_weight, grad_w_weight, grad_value_ptr, - grad_output_nram, grad_weight, - grad_sampling_loc + grad_sampling_loc_out + p_col * 2, - grad_attn_weight + grad_attn_weight_out + p_col, - grad_output_nram_temp, deal_num, C_tail, data_value_ptr); - } - } - data_weight_ptr += 1; - data_loc_w_ptr += 2; - } - } -} - -void __mlu_func__ computeGridMaskAndOffset( - float *nram_grad_output_tl, float *nram_grad_output_tr, float *nram_loc_w, - float *nram_loc_h, float *nram_h_stride, int32_t *nram_spatial_shapes, - float *nram_w_low_temp, float *nram_h_high_temp, float *nram_w_low, - float *nram_h_low, float *nram_h_high, float *nram_w_high, float *nram_lh, - float *nram_lw, float *nram_hh, float *nram_hw, - float *nram_h_low_ptr_offset, float *nram_h_high_ptr_offset, - float *nram_w_low_ptr_offset, float *nram_w_high_ptr_offset, float *nram_w1, - float *nram_w2, float *nram_w3, float *nram_w4, float *nram_offset_temp, - float *nram_offset1, float *nram_offset2, float *nram_offset3, - float *nram_offset4, float *nram_base_ptr, float *nram_h_low_temp, - int32_t num_deal_grid, int32_t num_per_time_real, const int32_t num_heads, - const int32_t num_levels, const int32_t num_points, const int32_t w_stride, - const int32_t qid_stride) { -#if __BANG_ARCH__ >= 322 - // [num_levels, 2] --> [2, num_levels] - __bang_transpose(nram_grad_output_tl, nram_loc_w, num_deal_grid, 2); - __bang_transpose(nram_loc_w, nram_grad_output_tl, - num_per_time_real * num_heads * num_levels, num_points); - __bang_transpose(nram_loc_h, nram_grad_output_tl + num_deal_grid, - num_per_time_real * num_heads * num_levels, num_points); - __bang_int322float((float *)nram_spatial_shapes, - (int32_t *)nram_spatial_shapes, num_levels * 2, 0); - __bang_transpose(nram_grad_output_tr, (float *)nram_spatial_shapes, - num_levels, 2); - __bang_mul_scalar(nram_h_stride, nram_grad_output_tr + num_levels, w_stride, - num_levels); - __memcpy_async(nram_spatial_shapes, nram_grad_output_tr, - num_levels * 2 * sizeof(float), NRAM2NRAM); - __bang_cycle_mul(nram_loc_w, nram_loc_w, - (float *)nram_spatial_shapes + num_levels, num_deal_grid, - num_levels); - __bang_cycle_mul(nram_loc_h, nram_loc_h, (float *)(nram_spatial_shapes), - num_deal_grid, num_levels); - __bang_sub_scalar(nram_loc_w, nram_loc_w, 0.5, num_deal_grid); - __bang_sub_scalar(nram_loc_h, nram_loc_h, 0.5, num_deal_grid); - // get mask. 
(h_im > -1 && w_im > -1 && - // h_im < spatial_h && w_im < spatial_w) - __bang_cycle_lt(nram_w_low_temp, nram_loc_w, - (float *)(nram_spatial_shapes + num_levels), num_deal_grid, - num_levels); - __bang_cycle_lt(nram_h_high_temp, nram_loc_h, (float *)(nram_spatial_shapes), - num_deal_grid, num_levels); - __bang_and(nram_w_low_temp, nram_w_low_temp, nram_h_high_temp, num_deal_grid); - __bang_gt_scalar(nram_h_high_temp, nram_loc_h, -1, num_deal_grid); - __bang_and(nram_h_high_temp, nram_h_high_temp, nram_w_low_temp, - num_deal_grid); - __bang_gt_scalar(nram_w_low_temp, nram_loc_w, -1, num_deal_grid); - __bang_and(nram_h_high_temp, nram_h_high_temp, nram_w_low_temp, - num_deal_grid); - __bang_transpose(nram_w_low_temp, nram_h_high_temp, num_points, - num_per_time_real * num_heads * num_levels); - __memcpy_async(nram_h_high_temp, nram_w_low_temp, - num_deal_grid * sizeof(float), NRAM2NRAM); - __bang_transpose(nram_grad_output_tl, nram_loc_w, num_points, - num_per_time_real * num_heads * num_levels); - __memcpy_async(nram_loc_w, nram_grad_output_tl, num_deal_grid * sizeof(float), - NRAM2NRAM); - __bang_transpose(nram_grad_output_tl, nram_loc_h, num_points, - num_per_time_real * num_heads * num_levels); - __memcpy_async(nram_loc_h, nram_grad_output_tl, num_deal_grid * sizeof(float), - NRAM2NRAM); - __bang_floor(nram_w_low, nram_loc_w, num_deal_grid); - __bang_floor(nram_h_low, nram_loc_h, num_deal_grid); - __bang_add_scalar(nram_h_high, nram_h_low, 1, num_deal_grid); - __bang_add_scalar(nram_w_high, nram_w_low, 1, num_deal_grid); - __bang_sub(nram_lh, nram_loc_h, nram_h_low, num_deal_grid); - __bang_sub(nram_lw, nram_loc_w, nram_w_low, num_deal_grid); - __bang_fusion(FUSION_FMA, nram_hh, nram_lh, (float)(-1), 1, num_deal_grid); - __bang_fusion(FUSION_FMA, nram_hw, nram_lw, (float)(-1), 1, num_deal_grid); - __bang_transpose(nram_h_low_ptr_offset, nram_h_low, - num_per_time_real * num_heads * num_levels, num_points); - __bang_cycle_mul(nram_h_low_ptr_offset, nram_h_low_ptr_offset, nram_h_stride, - num_deal_grid, num_levels); - __bang_cycle_add(nram_h_high_ptr_offset, nram_h_low_ptr_offset, nram_h_stride, - num_deal_grid, num_levels); - __bang_transpose(nram_w_low_ptr_offset, nram_h_low_ptr_offset, num_points, - num_per_time_real * num_heads * num_levels); - __memcpy_async(nram_h_low_ptr_offset, nram_w_low_ptr_offset, - num_deal_grid * sizeof(float), NRAM2NRAM); - __bang_transpose(nram_w_low_ptr_offset, nram_h_high_ptr_offset, num_points, - num_per_time_real * num_heads * num_levels); - __memcpy_async(nram_h_high_ptr_offset, nram_w_low_ptr_offset, - num_deal_grid * sizeof(float), NRAM2NRAM); - __bang_mul_scalar(nram_w_low_ptr_offset, nram_w_low, qid_stride, - num_deal_grid); - __bang_add_scalar(nram_w_high_ptr_offset, nram_w_low_ptr_offset, qid_stride, - num_deal_grid); - __bang_mul(nram_w1, nram_hh, nram_hw, num_deal_grid); - __bang_mul(nram_w2, nram_hh, nram_lw, num_deal_grid); - __bang_mul(nram_w3, nram_lh, nram_hw, num_deal_grid); - __bang_mul(nram_w4, nram_lh, nram_lw, num_deal_grid); - __bang_add(nram_offset1, nram_h_low_ptr_offset, nram_w_low_ptr_offset, - num_deal_grid); - __bang_transpose(nram_offset_temp, nram_offset1, - num_per_time_real * num_heads, num_levels * num_points); - __bang_cycle_add(nram_offset_temp, nram_offset_temp, nram_base_ptr, - num_deal_grid, num_heads); - __bang_transpose(nram_offset1, nram_offset_temp, num_levels * num_points, - num_per_time_real * num_heads); - __bang_add(nram_offset2, nram_h_low_ptr_offset, nram_w_high_ptr_offset, - num_deal_grid); - 
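The transpose and cycle-add sequences above and below assemble, for every sampling point, the four flattened offsets into the value tensor. With the [num_keys, num_heads, channels] layout used throughout, the scalar form is as follows (a sketch; cornerOffset is an illustrative name):

#include <cstdint>

inline int32_t cornerOffset(int32_t h, int32_t w, int32_t head,
                            int32_t spatial_w, int32_t num_heads,
                            int32_t channels) {
  const int32_t qid_stride = num_heads * channels;  // stride of one pixel
  const int32_t h_stride = spatial_w * qid_stride;  // stride of one row
  return h * h_stride + w * qid_stride + head * channels;  // base_ptr term
}
// offset1 = cornerOffset(h_low,  w_low,  ...)   offset2 = cornerOffset(h_low,  w_high, ...)
// offset3 = cornerOffset(h_high, w_low,  ...)   offset4 = cornerOffset(h_high, w_high, ...)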
__bang_transpose(nram_offset_temp, nram_offset2, - num_per_time_real * num_heads, num_levels * num_points); - __bang_cycle_add(nram_offset_temp, nram_offset_temp, nram_base_ptr, - num_deal_grid, num_heads); - __bang_transpose(nram_offset2, nram_offset_temp, num_levels * num_points, - num_per_time_real * num_heads); - __bang_add(nram_offset3, nram_h_high_ptr_offset, nram_w_low_ptr_offset, - num_deal_grid); - __bang_transpose(nram_offset_temp, nram_offset3, - num_per_time_real * num_heads, num_levels * num_points); - __bang_cycle_add(nram_offset_temp, nram_offset_temp, nram_base_ptr, - num_deal_grid, num_heads); - __bang_transpose(nram_offset3, nram_offset_temp, num_levels * num_points, - num_per_time_real * num_heads); - __bang_add(nram_offset4, nram_h_high_ptr_offset, nram_w_high_ptr_offset, - num_deal_grid); - __bang_transpose(nram_offset_temp, nram_offset4, - num_per_time_real * num_heads, num_levels * num_points); - __bang_cycle_add(nram_offset_temp, nram_offset_temp, nram_base_ptr, - num_deal_grid, num_heads); - __bang_transpose(nram_offset4, nram_offset_temp, num_levels * num_points, - num_per_time_real * num_heads); - // h_low >= 0 && w_low >= 0 mask2 - float *mask1 = nram_h_low_ptr_offset; - float *mask2 = nram_h_high_ptr_offset; - float *mask3 = nram_w_low_ptr_offset; - float *mask4 = nram_w_high_ptr_offset; - __bang_ge_scalar(mask1, nram_h_low, 0, num_deal_grid); - __bang_ge_scalar(mask2, nram_w_low, 0, num_deal_grid); - __bang_and(mask2, mask1, mask2, num_deal_grid); - __bang_and(mask2, nram_h_high_temp, mask2, num_deal_grid); - // h_low >= 0 && w_high <= width - 1 mask1 - __bang_transpose(mask3, nram_w_high, - num_per_time_real * num_heads * num_levels, num_points); - __bang_sub_scalar(nram_spatial_shapes, nram_spatial_shapes, 1, - num_levels * 2); - __bang_cycle_le(mask3, mask3, (float *)(nram_spatial_shapes + num_levels), - num_deal_grid, num_levels); - __bang_transpose(mask4, mask3, num_points, - num_per_time_real * num_heads * num_levels); - __bang_and(mask1, mask1, mask4, num_deal_grid); - __bang_and(mask1, nram_h_high_temp, mask1, num_deal_grid); - // h_high <= height - 1 && w_high <= width - 1 mask3 - __bang_transpose(mask3, nram_h_high, - num_per_time_real * num_heads * num_levels, num_points); - __bang_cycle_le(mask3, mask3, (float *)(nram_spatial_shapes), num_deal_grid, - num_levels); - - __bang_transpose(nram_h_low_temp, mask3, num_points, - num_per_time_real * num_heads * num_levels); - __bang_and(mask4, mask4, nram_h_low_temp, num_deal_grid); - __bang_and(mask3, mask4, nram_h_high_temp, num_deal_grid); - // h_high <= height - 1 && w_low >= 0 mask4 - __bang_ge_scalar(nram_w_low_temp, nram_w_low, 0, num_deal_grid); - __bang_and(mask4, nram_h_low_temp, nram_w_low_temp, num_deal_grid); - __bang_and(mask4, mask4, nram_h_high_temp, num_deal_grid); -#endif -} - -void __mlu_func__ loadValue( - float *nram_grad_output_tl, float *nram_grad_output_tr, - float *nram_grad_output_bl, float *nram_grad_output_br, - const float *data_value, const float *grad_output, float *grad_temp1, - float *grad_temp2, float *mask1, float *mask2, float *mask3, float *mask4, - float *nram_offset1, float *nram_offset2, float *nram_offset3, - float *nram_offset4, float *nram_grad_weight, - int32_t *nram_level_start_index, int32_t offset_nram, - int32_t start_per_core, int32_t grid_loop, int32_t num_per_time_theory, - int32_t num_heads, int32_t deal_num_real, int32_t num_per_time_real, - int32_t num_deal_grid, const int32_t num_query, const int32_t num_levels, - const int32_t num_points, int32_t 
grid_offset, const int32_t spatial_size, - const int32_t qid_stride) { -#if __BANG_ARCH__ >= 322 - int32_t value_offset_temp = 0; - __bang_write_zero(nram_grad_output_tl, 4 * offset_nram); - __sync_io_move_compute(); - __memcpy_async( - grad_temp2, - grad_output + (start_per_core + grid_loop * num_per_time_theory) * - num_heads * deal_num_real, - num_per_time_real * num_heads * deal_num_real * sizeof(float), - GDRAM2NRAM); - for (int32_t loop = 0; loop < num_deal_grid; ++loop) { - const int32_t b_col = - (grid_offset + loop) / num_query / num_heads / num_levels / num_points; - const int32_t l_col = (grid_offset + loop) / num_points % num_levels; - const int32_t level_start_id = nram_level_start_index[l_col]; - value_offset_temp = - b_col * spatial_size * qid_stride + level_start_id * qid_stride; - if (mask2[loop]) { - __memcpy_async( - nram_grad_output_tl + loop * deal_num_real, - data_value + value_offset_temp + int32_t(nram_offset1[loop]), - deal_num_real * sizeof(float), GDRAM2NRAM); - } - if (mask1[loop]) { - __memcpy_async( - nram_grad_output_tr + loop * deal_num_real, - data_value + value_offset_temp + int32_t(nram_offset2[loop]), - deal_num_real * sizeof(float), GDRAM2NRAM); - } - if (mask4[loop]) { - __memcpy_async( - nram_grad_output_bl + loop * deal_num_real, - data_value + value_offset_temp + int32_t(nram_offset3[loop]), - deal_num_real * sizeof(float), GDRAM2NRAM); - } - if (mask3[loop]) { - __memcpy_async( - nram_grad_output_br + loop * deal_num_real, - data_value + value_offset_temp + int32_t(nram_offset4[loop]), - deal_num_real * sizeof(float), GDRAM2NRAM); - } - } - for (int32_t m = 0; m < deal_num_real; ++m) { - __memcpy_async(grad_temp1 + m * num_deal_grid, nram_grad_weight, - num_deal_grid * sizeof(float), NRAM2NRAM); - } - __sync_io_move_compute(); -#endif -} - -void __mlu_func__ computeGradValue( - float *grad_temp1, float *grad_temp2, float *grad_temp3, float *grad_temp4, - float *mask1, float *mask2, float *mask3, float *mask4, float *nram_offset1, - float *nram_offset2, float *nram_offset3, float *nram_offset4, - int32_t *nram_level_start_index, int32_t deal_num_real, - const float *grad_value, float *nram_w1, float *nram_w2, float *nram_w3, - float *nram_w4, int32_t num_per_time_real, const int32_t num_heads, - const int32_t num_levels, const int32_t num_points, const int32_t num_query, - int32_t num_deal_grid, int32_t grid_offset, const int32_t spatial_size, - const int32_t qid_stride, float *nram_grid_offset1, - float *nram_grid_offset2) { -#if __BANG_ARCH__ >= 322 - __bang_transpose(grad_temp3, grad_temp1, - deal_num_real * num_per_time_real * num_heads, - num_levels * num_points); - __bang_transpose(grad_temp1, grad_temp2, num_per_time_real * num_heads, - deal_num_real); - __bang_cycle_mul(grad_temp3, grad_temp3, grad_temp1, - num_deal_grid * deal_num_real, - deal_num_real * num_per_time_real * num_heads); - __bang_transpose(grad_temp4, grad_temp3, num_levels * num_points, - deal_num_real * num_per_time_real * num_heads); - __bang_cycle_mul(grad_temp1, grad_temp4, nram_w1, - num_deal_grid * deal_num_real, num_deal_grid); - __bang_transpose(grad_temp3, grad_temp1, deal_num_real, num_deal_grid); - for (int32_t loop = 0; loop < num_deal_grid; ++loop) { - nram_grid_offset1[loop] = ((loop + grid_offset) / num_query / num_heads / - num_levels / num_points) * - spatial_size * qid_stride; - } - __bang_transpose(nram_grid_offset2, nram_grid_offset1, - num_per_time_real * num_heads * num_levels, num_points); - __bang_int322float((float *)nram_level_start_index, 
nram_level_start_index, - num_levels, 0); - __bang_mul_scalar(nram_grid_offset1, (float *)nram_level_start_index, - qid_stride, num_levels); - __bang_cycle_add(nram_grid_offset2, nram_grid_offset2, nram_grid_offset1, - num_deal_grid, num_levels); - __bang_transpose(nram_grid_offset1, nram_grid_offset2, num_points, - num_per_time_real * num_heads * num_levels); - __bang_add(nram_offset1, nram_offset1, nram_grid_offset1, num_deal_grid); - __bang_add(nram_offset2, nram_offset2, nram_grid_offset1, num_deal_grid); - __bang_add(nram_offset3, nram_offset3, nram_grid_offset1, num_deal_grid); - __bang_add(nram_offset4, nram_offset4, nram_grid_offset1, num_deal_grid); - for (int32_t loop = 0; loop < num_deal_grid; ++loop) { - if (mask2[loop]) { - __bang_atomic_add((float *)(grad_temp3 + loop * deal_num_real), - (float *)(grad_value + int32_t(nram_offset1[loop])), - (float *)(grad_temp3 + loop * deal_num_real), - deal_num_real); - } - } - __bang_cycle_mul(grad_temp1, grad_temp4, nram_w2, - num_deal_grid * deal_num_real, num_deal_grid); - __bang_transpose(grad_temp3, grad_temp1, deal_num_real, num_deal_grid); - for (int32_t loop = 0; loop < num_deal_grid; ++loop) { - if (mask1[loop]) { - __bang_atomic_add((float *)(grad_temp3 + loop * deal_num_real), - (float *)(grad_value + int32_t(nram_offset2[loop])), - (float *)(grad_temp3 + loop * deal_num_real), - deal_num_real); - } - } - __bang_cycle_mul(grad_temp1, grad_temp4, nram_w3, - num_deal_grid * deal_num_real, num_deal_grid); - __bang_transpose(grad_temp3, grad_temp1, deal_num_real, num_deal_grid); - for (int32_t loop = 0; loop < num_deal_grid; ++loop) { - if (mask4[loop]) { - __bang_atomic_add((float *)(grad_temp3 + loop * deal_num_real), - (float *)(grad_value + int32_t(nram_offset3[loop])), - (float *)(grad_temp3 + loop * deal_num_real), - deal_num_real); - } - } - - __bang_cycle_mul(grad_temp1, grad_temp4, nram_w4, - num_deal_grid * deal_num_real, num_deal_grid); - __bang_transpose(grad_temp3, grad_temp1, deal_num_real, num_deal_grid); - for (int32_t loop = 0; loop < num_deal_grid; ++loop) { - if (mask3[loop]) { - __bang_atomic_add((float *)(grad_temp3 + loop * deal_num_real), - (float *)(grad_value + int32_t(nram_offset4[loop])), - (float *)(grad_temp3 + loop * deal_num_real), - deal_num_real); - } - } -#endif -} - -void __mlu_func__ computeGradAttnWeight( - float *grad_w_weight, float *grad_weight, float *nram_grad_output_tl, - float *nram_grad_output_tr, float *nram_grad_output_bl, - float *nram_grad_output_br, float *grad_temp1, float *grad_temp2, - const float *grad_attn_weight, float *nram_hw, float *nram_hh, - float *nram_lw, float *nram_lh, float *grad_h_weight, float *nram_w1, - float *nram_w2, float *nram_w3, float *nram_w4, int32_t offset_nram, - int32_t num_deal_grid, int32_t deal_num_real, int32_t num_per_time_real, - const int32_t num_heads, const int32_t num_levels, const int32_t num_points, - int32_t grid_offset, float *nram_h_high_temp) { -#if __BANG_ARCH__ >= 322 - __bang_write_zero(grad_w_weight, 2 * offset_nram); - - // grad_output_nram_tl - __bang_transpose(grad_weight, nram_grad_output_tl, num_deal_grid, - deal_num_real); - __bang_cycle_mul(nram_grad_output_tl, grad_weight, nram_hw, - num_deal_grid * deal_num_real, num_deal_grid); - __bang_sub(grad_h_weight, grad_h_weight, nram_grad_output_tl, - num_deal_grid * deal_num_real); - __bang_cycle_mul(nram_grad_output_tl, grad_weight, nram_hh, - num_deal_grid * deal_num_real, num_deal_grid); - __bang_sub(grad_w_weight, grad_w_weight, nram_grad_output_tl, - num_deal_grid * 
deal_num_real); - __bang_cycle_mul(nram_grad_output_tl, grad_weight, nram_w1, - num_deal_grid * deal_num_real, num_deal_grid); - // nram_grad_output_tr - __bang_transpose(grad_weight, nram_grad_output_tr, num_deal_grid, - deal_num_real); - __bang_cycle_mul(nram_grad_output_tr, grad_weight, nram_lw, - num_deal_grid * deal_num_real, num_deal_grid); - __bang_sub(grad_h_weight, grad_h_weight, nram_grad_output_tr, - num_deal_grid * deal_num_real); - __bang_cycle_mul(nram_grad_output_tr, grad_weight, nram_hh, - num_deal_grid * deal_num_real, num_deal_grid); - __bang_add(grad_w_weight, grad_w_weight, nram_grad_output_tr, - num_deal_grid * deal_num_real); - __bang_cycle_mul(nram_grad_output_tr, grad_weight, nram_w2, - num_deal_grid * deal_num_real, num_deal_grid); - __bang_add(nram_grad_output_tl, nram_grad_output_tl, nram_grad_output_tr, - num_deal_grid * deal_num_real); - // nram_grad_output_tl - __bang_transpose(grad_weight, nram_grad_output_bl, num_deal_grid, - deal_num_real); - __bang_cycle_mul(nram_grad_output_bl, grad_weight, nram_hw, - num_deal_grid * deal_num_real, num_deal_grid); - __bang_add(grad_h_weight, grad_h_weight, nram_grad_output_bl, - num_deal_grid * deal_num_real); - __bang_cycle_mul(nram_grad_output_bl, grad_weight, nram_lh, - num_deal_grid * deal_num_real, num_deal_grid); - __bang_sub(grad_w_weight, grad_w_weight, nram_grad_output_bl, - num_deal_grid * deal_num_real); - __bang_cycle_mul(nram_grad_output_bl, grad_weight, nram_w3, - num_deal_grid * deal_num_real, num_deal_grid); - __bang_add(nram_grad_output_tl, nram_grad_output_tl, nram_grad_output_bl, - num_deal_grid * deal_num_real); - // nram_grad_output_br - __bang_transpose(grad_weight, nram_grad_output_br, num_deal_grid, - deal_num_real); - __bang_cycle_mul(nram_grad_output_br, grad_weight, nram_lw, - num_deal_grid * deal_num_real, num_deal_grid); - __bang_add(grad_h_weight, grad_h_weight, nram_grad_output_br, - num_deal_grid * deal_num_real); - __bang_cycle_mul(nram_grad_output_br, grad_weight, nram_lh, - num_deal_grid * deal_num_real, num_deal_grid); - __bang_add(grad_w_weight, grad_w_weight, nram_grad_output_br, - num_deal_grid * deal_num_real); - __bang_cycle_mul(nram_grad_output_br, grad_weight, nram_w4, - num_deal_grid * deal_num_real, num_deal_grid); - __bang_add(nram_grad_output_tl, nram_grad_output_tl, nram_grad_output_br, - num_deal_grid * deal_num_real); - __bang_transpose(nram_grad_output_br, nram_grad_output_tl, deal_num_real, - num_deal_grid); - __bang_transpose(nram_grad_output_tr, nram_grad_output_br, - num_per_time_real * num_heads, - num_points * num_levels * deal_num_real); - __bang_transpose(grad_temp1, grad_temp2, num_per_time_real * num_heads, - deal_num_real); - __bang_cycle_mul(nram_grad_output_tr, nram_grad_output_tr, grad_temp1, - num_deal_grid * deal_num_real, - num_per_time_real * num_heads * deal_num_real); - __bang_transpose(nram_grad_output_br, nram_grad_output_tr, - num_points * num_levels * deal_num_real, - num_per_time_real * num_heads); - - __bang_transpose((float *)nram_grad_output_tr, (float *)nram_grad_output_br, - num_deal_grid, deal_num_real); - recursiveSumPool(nram_grad_output_tr, num_deal_grid, deal_num_real, - ALIGN_NUM); - __bang_float2int32((int *)nram_h_high_temp, nram_h_high_temp, num_deal_grid, - 0); - __nram__ int table[2] = {0, (int)0xffffffff}; - __bang_lut_s32((int *)nram_h_high_temp, (int *)nram_h_high_temp, (int *)table, - num_deal_grid, 64); - __bang_band((char *)nram_grad_output_tr, (char *)nram_grad_output_tr, - (char *)nram_h_high_temp, num_deal_grid * 
sizeof(float)); - - __bang_atomic_add((float *)nram_grad_output_tr, - (float *)grad_attn_weight + grid_offset, - (float *)nram_grad_output_tr, num_deal_grid); -#endif -} - -void __mlu_func__ computeGradSampingLoc( - const float *grad_sampling_loc, float *nram_grad_output_tl, - float *nram_grad_output_tr, float *grad_h_weight, float *grad_w_weight, - int32_t *nram_spatial_shapes, float *grad_temp1, float *grad_temp2, - float *nram_grad_weight, int32_t num_deal_grid, int32_t deal_num_real, - int32_t num_per_time_real, const int32_t num_heads, - const int32_t num_levels, const int32_t num_points, int32_t grid_offset, - float *nram_h_high_temp) { -#if __BANG_ARCH__ >= 322 - __bang_transpose(nram_grad_output_tl, grad_h_weight, - num_per_time_real * num_heads * num_levels * deal_num_real, - num_points); - __bang_cycle_mul(nram_grad_output_tl, nram_grad_output_tl, - (float *)nram_spatial_shapes, num_deal_grid * deal_num_real, - num_levels); - __bang_transpose(grad_h_weight, nram_grad_output_tl, - num_points * deal_num_real, - num_per_time_real * num_heads * num_levels); - for (int32_t m = 0; m < deal_num_real; ++m) { - __memcpy_async(grad_temp1 + m * num_deal_grid, nram_grad_weight, - num_deal_grid * sizeof(float), NRAM2NRAM); - } - __sync_io_move_compute(); - __bang_transpose(nram_grad_output_tr, grad_temp1, - deal_num_real * num_per_time_real * num_heads, - num_levels * num_points); - __bang_transpose(grad_temp1, grad_temp2, num_per_time_real * num_heads, - deal_num_real); - __bang_cycle_mul(nram_grad_output_tr, nram_grad_output_tr, grad_temp1, - num_deal_grid * deal_num_real, - deal_num_real * num_per_time_real * num_heads); - __bang_transpose(grad_temp1, nram_grad_output_tr, - num_levels * num_points * deal_num_real, - num_per_time_real * num_heads); - __bang_mul(grad_h_weight, grad_h_weight, grad_temp1, - num_deal_grid * deal_num_real); - __bang_transpose(nram_grad_output_tl, grad_h_weight, num_deal_grid, - deal_num_real); - __memcpy_async(grad_h_weight, nram_grad_output_tl, - num_deal_grid * deal_num_real * sizeof(float), NRAM2NRAM); - recursiveSumPool(grad_h_weight, num_deal_grid, deal_num_real, ALIGN_NUM); - __nram__ int table[2] = {0, (int)0xffffffff}; - __bang_lut_s32((int *)nram_h_high_temp, (int *)nram_h_high_temp, (int *)table, - num_deal_grid, 64); - __bang_band((char *)grad_h_weight, (char *)grad_h_weight, - (char *)nram_h_high_temp, num_deal_grid * sizeof(float)); - __bang_transpose(nram_grad_output_tl, grad_w_weight, - num_per_time_real * num_heads * num_levels * deal_num_real, - num_points); - __bang_cycle_mul(nram_grad_output_tl, nram_grad_output_tl, - (float *)(nram_spatial_shapes + num_levels), - num_deal_grid * deal_num_real, num_levels); - __bang_transpose(grad_w_weight, nram_grad_output_tl, - num_points * deal_num_real, - num_per_time_real * num_heads * num_levels); - __bang_mul(grad_w_weight, grad_w_weight, grad_temp1, - num_deal_grid * deal_num_real); - __bang_transpose(nram_grad_output_tl, grad_w_weight, num_deal_grid, - deal_num_real); - __memcpy(grad_w_weight, nram_grad_output_tl, - num_deal_grid * deal_num_real * sizeof(float), NRAM2NRAM); - recursiveSumPool(grad_w_weight, num_deal_grid, deal_num_real, ALIGN_NUM); - __bang_lut_s32((int *)nram_h_high_temp, (int *)nram_h_high_temp, (int *)table, - num_deal_grid, 64); - __bang_band((char *)grad_w_weight, (char *)grad_w_weight, - (char *)nram_h_high_temp, num_deal_grid * sizeof(float)); - - __memcpy(grad_w_weight + num_deal_grid, grad_h_weight, - num_deal_grid * sizeof(float), NRAM2NRAM); - 
__bang_transpose(nram_grad_output_tl, grad_w_weight, 2, num_deal_grid); - __bang_atomic_add((float *)nram_grad_output_tl, - (float *)grad_sampling_loc + grid_offset * 2, - (float *)nram_grad_output_tl, 2 * num_deal_grid); - -#endif -} - -__mlu_global__ void MLUUnion1KernelMsDeformAttnBackwardSmallChannelsKernel( - const float *data_value, const int32_t *spatial_shapes, - const int32_t *data_level_start_index, const float *data_sampling_loc, - const float *data_attn_weight, const float *grad_output, - const int32_t batch, const int32_t spatial_size, const int32_t num_heads, - const int32_t channels, const int32_t num_levels, const int32_t num_query, - const int32_t num_points, float *grad_value, float *grad_sampling_loc, - float *grad_attn_weight) { -#if __BANG_ARCH__ > 322 - const int32_t split_grid_num = 28; - const int32_t split_num_c = 8; - const int32_t C_align = PAD_UP(channels, ALIGN_NUM); - - const int32_t num_hlp = num_heads * num_levels * num_points; - int32_t num_per_time_theory = (MAX_NRAM_SIZE - num_levels * sizeof(float) - - 3 * num_levels * sizeof(int32_t)) / - sizeof(float) / - (split_num_c * C_align + split_grid_num) / - PAD_UP((num_hlp), ALIGN_NUM); - - int32_t deal_grid_num_theory = num_per_time_theory * num_hlp; - - const int32_t offset_nram = num_per_time_theory * C_align * num_hlp; - const int32_t offset_nram_calc = PAD_UP(deal_grid_num_theory, ALIGN_NUM); - float *nram_grad_output_tl = (float *)nram_buffer; - float *nram_grad_output_tr = (float *)nram_buffer + offset_nram; - float *nram_grad_output_bl = (float *)nram_buffer + 2 * offset_nram; - float *nram_grad_output_br = (float *)nram_buffer + 3 * offset_nram; - - float *grad_temp1 = (float *)nram_buffer + 4 * offset_nram; - float *grad_temp2 = (float *)nram_buffer + 5 * offset_nram; - float *grad_temp3 = (float *)nram_buffer + 6 * offset_nram; - float *grad_temp4 = (float *)nram_buffer + 7 * offset_nram; - - float *nram_loc_w = (float *)nram_buffer + split_num_c * offset_nram; - float *nram_loc_h = - (float *)nram_buffer + split_num_c * offset_nram + offset_nram_calc; - float *nram_h_low = - (float *)nram_buffer + split_num_c * offset_nram + 2 * offset_nram_calc; - float *nram_w_low = - (float *)nram_buffer + split_num_c * offset_nram + 3 * offset_nram_calc; - float *nram_h_high = - (float *)nram_buffer + split_num_c * offset_nram + 4 * offset_nram_calc; - float *nram_w_high = - (float *)nram_buffer + split_num_c * offset_nram + 5 * offset_nram_calc; - float *nram_h_low_temp = - (float *)nram_buffer + split_num_c * offset_nram + 6 * offset_nram_calc; - float *nram_h_high_temp = - (float *)nram_buffer + split_num_c * offset_nram + 7 * offset_nram_calc; - - float *nram_hw = - (float *)nram_buffer + split_num_c * offset_nram + 8 * offset_nram_calc; - float *nram_hh = - (float *)nram_buffer + split_num_c * offset_nram + 9 * offset_nram_calc; - float *nram_lw = - (float *)nram_buffer + split_num_c * offset_nram + 10 * offset_nram_calc; - float *nram_lh = - (float *)nram_buffer + split_num_c * offset_nram + 11 * offset_nram_calc; - - float *nram_h_low_ptr_offset = - (float *)nram_buffer + split_num_c * offset_nram + 12 * offset_nram_calc; - float *nram_h_high_ptr_offset = - (float *)nram_buffer + split_num_c * offset_nram + 13 * offset_nram_calc; - float *nram_w_low_ptr_offset = - (float *)nram_buffer + split_num_c * offset_nram + 14 * offset_nram_calc; - float *nram_w_high_ptr_offset = - (float *)nram_buffer + split_num_c * offset_nram + 15 * offset_nram_calc; - - float *nram_w1 = - (float *)nram_buffer + split_num_c * 
offset_nram + 16 * offset_nram_calc; - float *nram_w2 = - (float *)nram_buffer + split_num_c * offset_nram + 17 * offset_nram_calc; - float *nram_w3 = - (float *)nram_buffer + split_num_c * offset_nram + 18 * offset_nram_calc; - float *nram_w4 = - (float *)nram_buffer + split_num_c * offset_nram + 19 * offset_nram_calc; - - float *nram_grad_weight = - (float *)nram_buffer + split_num_c * offset_nram + 20 * offset_nram_calc; - float *nram_base_ptr = - (float *)nram_buffer + split_num_c * offset_nram + 21 * offset_nram_calc; - float *nram_offset_temp = - (float *)nram_buffer + split_num_c * offset_nram + 22 * offset_nram_calc; - - float *nram_offset1 = - (float *)nram_buffer + split_num_c * offset_nram + 23 * offset_nram_calc; - float *nram_offset2 = - (float *)nram_buffer + split_num_c * offset_nram + 24 * offset_nram_calc; - float *nram_offset3 = - (float *)nram_buffer + split_num_c * offset_nram + 25 * offset_nram_calc; - float *nram_offset4 = - (float *)nram_buffer + split_num_c * offset_nram + 26 * offset_nram_calc; - - float *nram_w_low_temp = - (float *)nram_buffer + split_num_c * offset_nram + 27 * offset_nram_calc; - int32_t *nram_spatial_shapes = - (int32_t *)((float *)nram_buffer + split_num_c * offset_nram + - 28 * offset_nram_calc); - int32_t *nram_level_start_index = - (int32_t *)(nram_spatial_shapes + 2 * num_levels); - float *nram_h_stride = (float *)(nram_level_start_index + 3 * num_levels); - const int32_t total_num = batch * num_query; - int32_t num_per_core = total_num / taskDim; - int32_t num_rem = total_num % taskDim; - num_per_core = num_per_core + int32_t(taskId < num_rem); - num_per_time_theory = - num_per_core > num_per_time_theory ? num_per_time_theory : num_per_core; - int32_t num_deal_grid = num_per_time_theory * num_hlp; - - if (num_per_core == 0) return; - int32_t start_per_core = num_rem > taskId ? 
(taskId * num_per_core) - : (num_rem + taskId * num_per_core); - - const int32_t qid_stride = num_heads * channels; - int32_t deal_num_real = channels; - - const int32_t repeat_times = num_per_core / num_per_time_theory; - const int32_t tail_num = num_per_core % num_per_time_theory; - - int32_t num_per_time_real = num_per_time_theory; - - for (int32_t loop = 0; loop < num_heads; ++loop) { - nram_base_ptr[loop] = loop * channels; - } - const int32_t w_stride = num_heads * channels; - for (int32_t grid_loop = 0; grid_loop < repeat_times + 1; grid_loop += 1) { - int32_t grid_offset = - (start_per_core + grid_loop * num_per_time_theory) * num_hlp; - if (grid_loop == repeat_times) { - if (tail_num == 0) { - continue; - } else { - grid_offset = - (start_per_core + repeat_times * num_per_time_theory) * num_hlp; - num_per_time_real = tail_num; - num_deal_grid = tail_num * num_hlp; - } - } - - __memcpy_async(nram_spatial_shapes, spatial_shapes, - num_levels * 2 * sizeof(int32_t), GDRAM2NRAM); - __memcpy_async(nram_level_start_index, data_level_start_index, - num_levels * sizeof(int32_t), GDRAM2NRAM); - __memcpy_async(nram_loc_w, data_sampling_loc + grid_offset * 2, - num_deal_grid * 2 * sizeof(float), GDRAM2NRAM); - __memcpy(nram_grad_weight, data_attn_weight + grid_offset, - num_deal_grid * sizeof(float), GDRAM2NRAM); - computeGridMaskAndOffset( - nram_grad_output_tl, nram_grad_output_tr, nram_loc_w, nram_loc_h, - nram_h_stride, nram_spatial_shapes, nram_w_low_temp, nram_h_high_temp, - nram_w_low, nram_h_low, nram_h_high, nram_w_high, nram_lh, nram_lw, - nram_hh, nram_hw, nram_h_low_ptr_offset, nram_h_high_ptr_offset, - nram_w_low_ptr_offset, nram_w_high_ptr_offset, nram_w1, nram_w2, - nram_w3, nram_w4, nram_offset_temp, nram_offset1, nram_offset2, - nram_offset3, nram_offset4, nram_base_ptr, nram_h_low_temp, - num_deal_grid, num_per_time_real, num_heads, num_levels, num_points, - w_stride, qid_stride); - float *mask1 = nram_h_low_ptr_offset; - float *mask2 = nram_h_high_ptr_offset; - float *mask3 = nram_w_low_ptr_offset; - float *mask4 = nram_w_high_ptr_offset; - loadValue(nram_grad_output_tl, nram_grad_output_tr, nram_grad_output_bl, - nram_grad_output_br, data_value, grad_output, grad_temp1, - grad_temp2, mask1, mask2, mask3, mask4, nram_offset1, - nram_offset2, nram_offset3, nram_offset4, nram_grad_weight, - nram_level_start_index, offset_nram, start_per_core, grid_loop, - num_per_time_theory, num_heads, deal_num_real, num_per_time_real, - num_deal_grid, num_query, num_levels, num_points, grid_offset, - spatial_size, qid_stride); - float *nram_grid_offset1 = nram_loc_h; - float *nram_grid_offset2 = nram_loc_w; - computeGradValue( - grad_temp1, grad_temp2, grad_temp3, grad_temp4, mask1, mask2, mask3, - mask4, nram_offset1, nram_offset2, nram_offset3, nram_offset4, - nram_level_start_index, deal_num_real, grad_value, nram_w1, nram_w2, - nram_w3, nram_w4, num_per_time_real, num_heads, num_levels, num_points, - num_query, num_deal_grid, grid_offset, spatial_size, qid_stride, - nram_grid_offset1, nram_grid_offset2); - - // compute grad_weight - float *grad_weight = grad_temp1; - float *grad_h_weight = grad_temp4; - float *grad_w_weight = grad_temp3; - computeGradAttnWeight( - grad_w_weight, grad_weight, nram_grad_output_tl, nram_grad_output_tr, - nram_grad_output_bl, nram_grad_output_br, grad_temp1, grad_temp2, - grad_attn_weight, nram_hw, nram_hh, nram_lw, nram_lh, grad_h_weight, - nram_w1, nram_w2, nram_w3, nram_w4, offset_nram, num_deal_grid, - deal_num_real, num_per_time_real, num_heads, 
num_levels, num_points, - grid_offset, nram_h_high_temp); - - // compute grad_sampling_loc - computeGradSampingLoc(grad_sampling_loc, nram_grad_output_tl, - nram_grad_output_tr, grad_h_weight, grad_w_weight, - nram_spatial_shapes, grad_temp1, grad_temp2, - nram_grad_weight, num_deal_grid, deal_num_real, - num_per_time_real, num_heads, num_levels, num_points, - grid_offset, nram_h_high_temp); - } -#endif -} - -__mlu_global__ void MLUUnion1KernelMsDeformAttnBackwarDefaultKernel( - const float *data_value, const int32_t *spatial_shapes, - const int32_t *data_level_start_index, const float *data_sampling_loc, - const float *data_attn_weight, const float *grad_output, - const int32_t batch, const int32_t spatial_size, const int32_t num_heads, - const int32_t channels, const int32_t num_levels, const int32_t num_query, - const int32_t num_points, float *grad_value, float *grad_sampling_loc, - float *grad_attn_weight); - -__mlu_global__ void MLUUnion1KernelMsDeformAttnBackwardSmallChannelsKernel( - const float *data_value, const int32_t *spatial_shapes, - const int32_t *data_level_start_index, const float *data_sampling_loc, - const float *data_attn_weight, const float *grad_output, - const int32_t batch, const int32_t spatial_size, const int32_t num_heads, - const int32_t channels, const int32_t num_levels, const int32_t num_query, - const int32_t num_points, float *grad_value, float *grad_sampling_loc, - float *grad_attn_weight); - -void KernelMsDeformAttnBackwardDefaultKernel( - cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, - const cnrtDataType_t d_type, const float *data_value, - const int32_t *spatial_shapes, const int32_t *data_level_start_index, - const float *data_sampling_loc, const float *data_attn_weight, - const float *grad_output, const int32_t batch, const int32_t spatial_size, - const int32_t num_heads, const int32_t channels, const int32_t num_levels, - const int32_t num_query, const int32_t num_points, float *grad_value, - float *grad_sampling_loc, float *grad_attn_weight) { - MLUUnion1KernelMsDeformAttnBackwarDefaultKernel<<>>( - data_value, spatial_shapes, data_level_start_index, data_sampling_loc, - data_attn_weight, grad_output, batch, spatial_size, num_heads, channels, - num_levels, num_query, num_points, grad_value, grad_sampling_loc, - grad_attn_weight); -} - -void KernelMsDeformAttnBackwardSmallChannelsKernel( - cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, - const cnrtDataType_t d_type, const float *data_value, - const int32_t *spatial_shapes, const int32_t *data_level_start_index, - const float *data_sampling_loc, const float *data_attn_weight, - const float *grad_output, const int32_t batch, const int32_t spatial_size, - const int32_t num_heads, const int32_t channels, const int32_t num_levels, - const int32_t num_query, const int32_t num_points, float *grad_value, - float *grad_sampling_loc, float *grad_attn_weight) { - MLUUnion1KernelMsDeformAttnBackwardSmallChannelsKernel<<>>( - data_value, spatial_shapes, data_level_start_index, data_sampling_loc, - data_attn_weight, grad_output, batch, spatial_size, num_heads, channels, - num_levels, num_query, num_points, grad_value, grad_sampling_loc, - grad_attn_weight); -} diff --git a/mmcv/ops/csrc/common/mlu/nms_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/nms_mlu_kernel.mlu deleted file mode 100644 index dcc722d854..0000000000 --- a/mmcv/ops/csrc/common/mlu/nms_mlu_kernel.mlu +++ /dev/null @@ -1,483 +0,0 @@ -/************************************************************************* - * 
Copyright (C) 2021 Cambricon. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS - * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - *************************************************************************/ -#include "nms_utils.hpp" - -#define COORD_DIM (4) - -#define SIZE_NRAM_BUF (MAX_NRAM_SIZE + REM_FOR_STACK - 62 * 1024) -#define SIZE_SRAM_BUF (MAX_SRAM_SIZE) - -__nram__ int8_t nram_buffer[SIZE_NRAM_BUF]; -__mlu_shared__ int8_t sram_buffer[SIZE_SRAM_BUF]; - -enum Addr { SRAM, GDRAM }; - -template -__mlu_func__ void nms_detection( - uint32_t &output_box_num, const int output_mode, OUT_DT *output_dram, - IN_DT *input_data_score, const IN_DT *input_data_box, const Addr input_ram, - IN_DT *sram, const int core_limit, const int input_num_boxes, - const int max_output_size, const float thresh_iou, const float thresh_score, - const float offset, const int algo) { - // global value - int32_t *exit_flag = (int32_t *)(sram + 28); - exit_flag[0] = 0; - // score, x1, y1, x2, y2, inter_x1, inter_y1, inter_x2, inter_y2 - int nms_buffer_count1 = 9; - // temp nram buffer to store selected target. - int nram_save_limit_count = 256; - float div_thresh_iou = 1.0 / thresh_iou; - - // input data ptr - const IN_DT *input_x1_ptr = input_data_box; - const IN_DT *input_y1_ptr = input_x1_ptr + input_num_boxes; - const IN_DT *input_x2_ptr = input_y1_ptr + input_num_boxes; - const IN_DT *input_y2_ptr = input_x2_ptr + input_num_boxes; - - int limit = 0; // find limit when GDRAM or SRAM - int max_seg_pad = 0; // the max length every repeat - int repeat = 0; - int remain = 0; - int remain_pad = 0; - int input_offset = 0; // offset of input_data for current core - int nram_save_count = 0; - - if (output_mode == 0) { - limit = (SIZE_NRAM_BUF - NFU_ALIGN_SIZE /*for max_box*/ * sizeof(IN_DT) - - nram_save_limit_count * sizeof(OUT_DT)) / - (nms_buffer_count1 * sizeof(IN_DT)); - } else { - // 5 maens: score, x1, y1, x2, y2 - limit = (SIZE_NRAM_BUF - NFU_ALIGN_SIZE /*for max_box*/ * sizeof(IN_DT) - - nram_save_limit_count * 5 * sizeof(OUT_DT)) / - (nms_buffer_count1 * sizeof(IN_DT)); - } - - int max_seg_iou_compute = 0; - int repeat_iou_compute = 0; - int remain_iou_compute = 0; - int remain_pad_iou_compute = 0; - - getComputeParamsBlockOrU1(sizeof(IN_DT), input_num_boxes, limit, core_limit, - input_offset, max_seg_pad, repeat, remain, - remain_pad, max_seg_iou_compute, repeat_iou_compute, - remain_iou_compute, remain_pad_iou_compute); - - // init the data ptr - IN_DT *score = (IN_DT *)nram_buffer; - IN_DT *x1 = score + max_seg_pad; - IN_DT *y1 = x1 + max_seg_pad; - IN_DT *x2 = y1 + max_seg_pad; - IN_DT *y2 = x2 + max_seg_pad; - IN_DT *inter_x1 = y2 + max_seg_pad; - IN_DT *inter_y1 = inter_x1 + max_seg_pad; - IN_DT *inter_x2 = inter_y1 + max_seg_pad; - IN_DT *inter_y2 = inter_x2 + max_seg_pad; - IN_DT *max_box = inter_y2 + max_seg_pad; // the max score, x1, y1, x2, y2 - OUT_DT *nram_save = - (OUT_DT *)((char *)max_box + - NFU_ALIGN_SIZE); // offset two line from max_box - -#if __BANG_ARCH__ >= 300 - float max_box_x1 = 0; - float max_box_y1 = 0; - float max_box_x2 = 0; - float max_box_y2 = 0; -#endif - mluMemcpyDirection_t load_dir = 
SRAM2NRAM; - mluMemcpyDirection_t store_dir = NRAM2SRAM; - load_dir = (input_ram == SRAM) ? SRAM2NRAM : GDRAM2NRAM; - store_dir = (input_ram == SRAM) ? NRAM2SRAM : NRAM2GDRAM; - - for (int keep = 0; keep < max_output_size; - keep++) { // loop until the max_score <= 0 - if (core_limit != 1) { - __sync_cluster(); // sync before current loop - } - - /******FIND MAX START******/ - int max_index = 0; // the max score index - int global_max_index = 0; // for U1 - float max_area = 0; // the max socre area - max_box[0] = 0; // init 0 - findCoreMaxBox(input_data_score, score, inter_x1, max_box, input_x1_ptr, - input_y1_ptr, input_x2_ptr, input_y2_ptr, load_dir, - input_offset, repeat, remain, remain_pad, max_seg_pad, - max_index); - - if (core_limit == 1) { -#if __BANG_ARCH__ >= 300 - calMaxArea(max_box, algo, offset, max_area, max_box_x1, max_box_y1, - max_box_x2, max_box_y2); -#else - calMaxArea(max_box, algo, offset, max_area); -#endif - input_data_score[max_index] = 0; - global_max_index = max_index; - } else if (core_limit == 4) { - __sync_cluster(); - findClusterMaxBox(sram, max_box, inter_x1, input_data_score, core_limit); - -#if __BANG_ARCH__ >= 300 - calMaxArea(max_box, algo, offset, max_area, max_box_x1, max_box_y1, - max_box_x2, max_box_y2); -#else - calMaxArea(max_box, algo, offset, max_area); -#endif - global_max_index = ((uint32_t *)(max_box + 5))[0]; - input_data_score[global_max_index] = 0; - } - // by now, we get: max_score|max_index|max_box|max_area - /******FIND MAX END******/ - - storeResult(max_box, nram_save, output_dram, keep, nram_save_limit_count, - max_output_size, thresh_score, output_mode, nram_save_count, - output_box_num); - - // if the max score <= 0, end - if (core_limit == 1) { - if (float(max_box[0]) <= thresh_score) { - break; - } - } else { - if (float(max_box[0]) <= thresh_score) { - if (coreId == 0) { - exit_flag[0] = 1; - } - } - __sync_cluster(); - if (exit_flag[0] == 1) { - break; - } - } -/******NMS STORE END******/ -#if __BANG_ARCH__ >= 300 - scoreUpdate(input_data_score, load_dir, store_dir, input_x1_ptr, - input_y1_ptr, input_x2_ptr, input_y2_ptr, x1, y1, x2, y2, score, - inter_x1, inter_y1, inter_x2, inter_y2, max_box, max_box_x1, - max_box_y1, max_box_x2, max_box_y2, nram_save, - repeat_iou_compute, remain_iou_compute, remain_pad_iou_compute, - max_seg_iou_compute, max_seg_pad, thresh_iou, div_thresh_iou, - input_offset, offset, max_area, input_num_boxes, algo); -#else - scoreUpdate(input_data_score, load_dir, store_dir, input_x1_ptr, - input_y1_ptr, input_x2_ptr, input_y2_ptr, x1, y1, x2, y2, score, - inter_x1, inter_y1, inter_x2, inter_y2, max_box, max_box[1], - max_box[2], max_box[3], max_box[4], nram_save, - repeat_iou_compute, remain_iou_compute, remain_pad_iou_compute, - max_seg_iou_compute, max_seg_pad, thresh_iou, div_thresh_iou, - input_offset, offset, max_area, input_num_boxes, algo); -#endif - } // for max_output_size -} - -__mlu_global__ void MLUUnion1KernelNMS( - const void *input_boxes, const void *input_confidence, - const int input_num_boxes, const int max_output_size, - const float iou_threshold, const float confidence_threshold, - const int output_mode, void *workspace, void *result_num, void *output, - const cnrtDataType_t data_type_input, const float offset, const int algo) { - if (data_type_input == CNRT_FLOAT16) { - __memcpy(workspace, input_confidence, input_num_boxes * sizeof(half), - GDRAM2GDRAM); - } else if (data_type_input == CNRT_FLOAT32) { - __memcpy(workspace, input_confidence, input_num_boxes * sizeof(float), - 
GDRAM2GDRAM); - } else { - } - - uint32_t output_box_num = 0; - float *score_data = (float *)workspace; - float *boxes_data = (float *)input_boxes; - float *sram = (float *)sram_buffer; - - if (output_mode == 0) { - if (data_type_input == CNRT_FLOAT32) { - nms_detection(output_box_num, output_mode, (uint32_t *)output, score_data, - boxes_data, GDRAM, sram, taskDim, input_num_boxes, - max_output_size, iou_threshold, confidence_threshold, - offset, algo); - } else { - nms_detection(output_box_num, output_mode, (uint32_t *)output, - (half *)score_data, (half *)boxes_data, GDRAM, (half *)sram, - taskDim, input_num_boxes, max_output_size, iou_threshold, - confidence_threshold, offset, algo); - } - } else { - if (data_type_input == CNRT_FLOAT32) { - nms_detection(output_box_num, output_mode, (float *)output, score_data, - boxes_data, GDRAM, sram, taskDim, input_num_boxes, - max_output_size, iou_threshold, confidence_threshold, - offset, algo); - } else { - nms_detection(output_box_num, output_mode, (half *)output, - (half *)score_data, (half *)boxes_data, GDRAM, (half *)sram, - taskDim, input_num_boxes, max_output_size, iou_threshold, - confidence_threshold, offset, algo); - } - } - ((uint32_t *)result_num)[0] = output_box_num; -} - -template -__mlu_func__ void nms_detection_ux( - int32_t *exit_flag, uint32_t &output_box_num, OUT_DT *output_dram, - IN_DT *score_data, const IN_DT *boxes_data, const Addr input_ram, - const int input_num_boxes, const int max_output_size, - const float thresh_iou, const float thresh_score, const float offset, - const int output_mode, const int algo, char *cdma_gdram) { - exit_flag[0] = 0; - - IN_DT *sram = (IN_DT *)sram_buffer; - - // score, x1, y1, x2, y2, inter_x1, inter_y1, inter_x2, inter_y2 - int nms_buffer_count1 = 9; - // temp nram buffer to store selected target. 
- int nram_save_limit_count = 256; - float div_thresh_iou = 1.0 / thresh_iou; - - // input data ptr - const IN_DT *input_x1_ptr = boxes_data; - const IN_DT *input_y1_ptr = input_x1_ptr + input_num_boxes; - const IN_DT *input_x2_ptr = input_y1_ptr + input_num_boxes; - const IN_DT *input_y2_ptr = input_x2_ptr + input_num_boxes; - - int limit = 0; // find limit when GDRAM or SRAM - int max_seg_pad = 0; // the max length every repeat - int repeat = 0; - int remain = 0; - int remain_pad = 0; - int nram_save_count = 0; - - if (output_mode == 0) { - limit = (SIZE_NRAM_BUF - NFU_ALIGN_SIZE /*for max_box*/ * sizeof(IN_DT) - - nram_save_limit_count * sizeof(OUT_DT)) / - (nms_buffer_count1 * sizeof(IN_DT)); - } else { - limit = (SIZE_NRAM_BUF - NFU_ALIGN_SIZE /*for max_box*/ * sizeof(IN_DT) - - nram_save_limit_count * INFO_NUM * sizeof(OUT_DT)) / - (nms_buffer_count1 * sizeof(IN_DT)); - } - - int input_offset = 0; - int max_seg_iou_compute = 0; - int repeat_iou_compute = 0; - int remain_iou_compute = 0; - int remain_pad_iou_compute = 0; - - getComputeParamsUx(sizeof(IN_DT), input_num_boxes, limit, input_offset, - max_seg_pad, repeat, remain, remain_pad, - max_seg_iou_compute, repeat_iou_compute, - remain_iou_compute, remain_pad_iou_compute); - // init the nram ptr - IN_DT *score = (IN_DT *)nram_buffer; - IN_DT *x1 = score + max_seg_pad; - IN_DT *y1 = x1 + max_seg_pad; - IN_DT *x2 = y1 + max_seg_pad; - IN_DT *y2 = x2 + max_seg_pad; - IN_DT *inter_x1 = y2 + max_seg_pad; - IN_DT *inter_y1 = inter_x1 + max_seg_pad; - IN_DT *inter_x2 = inter_y1 + max_seg_pad; - IN_DT *inter_y2 = inter_x2 + max_seg_pad; - IN_DT *max_box = inter_y2 + max_seg_pad; // the max score, x1, y1, x2, y2 - OUT_DT *nram_save = - (OUT_DT *)((char *)max_box + - NFU_ALIGN_SIZE); // offset two line from max_box -#if __BANG_ARCH__ >= 300 - float max_box_x1 = 0; - float max_box_y1 = 0; - float max_box_x2 = 0; - float max_box_y2 = 0; -#endif - mluMemcpyDirection_t load_dir = SRAM2NRAM; - mluMemcpyDirection_t store_dir = NRAM2SRAM; - load_dir = (input_ram == SRAM) ? SRAM2NRAM : GDRAM2NRAM; - store_dir = (input_ram == SRAM) ? NRAM2SRAM : NRAM2GDRAM; - - for (int keep = 0; keep < max_output_size; - keep++) { // loop until the max_score <= 0 - __sync_all(); - - int max_index = 0; - int global_max_index = 0; // for Ux - float max_area = 0; // the max socre area - max_box[0] = 0; // init 0 - - if (coreId == 0) { - findCoreMaxBox(score_data, score, inter_x1, max_box, input_x1_ptr, - input_y1_ptr, input_x2_ptr, input_y2_ptr, load_dir, - input_offset, repeat, remain, remain_pad, max_seg_pad, - max_index); - // copy max box info to sram - __memcpy(sram, max_box, REDUCE_NUM * sizeof(IN_DT), NRAM2SRAM); - } - __sync_all(); -#if __BANG_ARCH__ >= 590 - __memcpy((char *)cdma_gdram + REDUCE_NUM * clusterId * sizeof(IN_DT), sram, - REDUCE_NUM * sizeof(IN_DT), SRAM2GDRAM); - __sync_all(); - if (clusterId == 0 && coreId == 0) { - __bang_write_zero(inter_x1, NMS_SIZE); - __memcpy((char *)inter_x1, (char *)cdma_gdram, sizeof(IN_DT), GDRAM2NRAM, - sizeof(IN_DT), REDUCE_NUM * sizeof(IN_DT), clusterDim - 1); - __bang_max(max_box, inter_x1, NMS_SIZE); - int max_cluster = (sizeof(IN_DT) == sizeof(half)) - ? 
((uint16_t *)max_box)[1] - : ((uint32_t *)max_box)[1]; - __memcpy((char *)cdma_gdram, - (char *)cdma_gdram + max_cluster * REDUCE_NUM * sizeof(IN_DT), - REDUCE_NUM * sizeof(IN_DT), GDRAM2GDRAM); - } - __sync_all(); - __memcpy(max_box, cdma_gdram, REDUCE_NUM * sizeof(IN_DT), GDRAM2NRAM); -#else - findGlobalMaxBox(max_box, sram, inter_x1); -#endif - -#if __BANG_ARCH__ >= 300 - calMaxArea(max_box, algo, offset, max_area, max_box_x1, max_box_y1, - max_box_x2, max_box_y2); -#else - calMaxArea(max_box, algo, offset, max_area); -#endif - global_max_index = ((uint32_t *)(max_box + 5))[0]; - if (coreId != MEMORY_CORE) { - score_data[global_max_index] = 0; - } - - storeResult(max_box, nram_save, output_dram, keep, nram_save_limit_count, - max_output_size, thresh_score, output_mode, nram_save_count, - output_box_num); - - if (float(max_box[0]) <= thresh_score) { - if (clusterId == 0 && coreId == 0) { - exit_flag[0] = 1; // dram - } - } - __sync_all(); - if (exit_flag[0] == 1) { - break; - } -/******NMS STORE END******/ -#if __BANG_ARCH__ >= 300 - scoreUpdate(score_data, load_dir, store_dir, input_x1_ptr, input_y1_ptr, - input_x2_ptr, input_y2_ptr, x1, y1, x2, y2, score, inter_x1, - inter_y1, inter_x2, inter_y2, max_box, max_box_x1, max_box_y1, - max_box_x2, max_box_y2, nram_save, repeat_iou_compute, - remain_iou_compute, remain_pad_iou_compute, max_seg_iou_compute, - max_seg_pad, thresh_iou, div_thresh_iou, input_offset, offset, - max_area, input_num_boxes, algo); -#else - scoreUpdate(score_data, load_dir, store_dir, input_x1_ptr, input_y1_ptr, - input_x2_ptr, input_y2_ptr, x1, y1, x2, y2, score, inter_x1, - inter_y1, inter_x2, inter_y2, max_box, max_box[1], max_box[2], - max_box[3], max_box[4], nram_save, repeat_iou_compute, - remain_iou_compute, remain_pad_iou_compute, max_seg_iou_compute, - max_seg_pad, thresh_iou, div_thresh_iou, input_offset, offset, - max_area, input_num_boxes, algo); -#endif - } // for max_output_size -} - -__mlu_global__ void MLUUionXKernelNMS( - const void *input_boxes, const void *input_confidence, - const int input_num_boxes, const int max_output_size, - const float iou_threshold, const float confidence_threshold, - const float offset, const cnrtDataType_t data_type_input, - const int output_mode, const int algo, void *workspace, void *result_num, - void *output) { - int input_dwidth = (data_type_input == CNRT_FLOAT32) ? 4 : 2; - int32_t *exit_flag = (int32_t *)((char *)workspace + - INFO_NUM * input_num_boxes * input_dwidth); - char *cdma_addr = (char *)exit_flag + sizeof(int32_t); - int reduce_sram_size = NFU_ALIGN_SIZE * REDUCE_NUM * input_dwidth; - int availbale_sram_size = SIZE_SRAM_BUF - reduce_sram_size; - - int cluster_score_size = input_num_boxes * input_dwidth; - int cluster_boxes_size = input_num_boxes * 4 * input_dwidth; - char *sram_score = (char *)sram_buffer + reduce_sram_size; - char *sram_boxes = - (char *)sram_buffer + reduce_sram_size + cluster_score_size; - Addr input_ram = GDRAM; - if ((cluster_score_size + cluster_boxes_size) < availbale_sram_size) { - input_ram = SRAM; - __memcpy(sram_score, input_confidence, cluster_score_size, GDRAM2SRAM); - __memcpy(sram_boxes, input_boxes, cluster_boxes_size, GDRAM2SRAM); - } else { - __memcpy(workspace, input_confidence, cluster_score_size, GDRAM2GDRAM); - } - __sync_cluster(); - - uint32_t output_box_num = 0; - float *score_data; - float *boxes_data; - score_data = (input_ram == SRAM) ? (float *)sram_score : (float *)workspace; - boxes_data = (input_ram == SRAM) ? 
(float *)sram_boxes : (float *)input_boxes; - - if (output_mode == 0) { - if (data_type_input == CNRT_FLOAT32) { - nms_detection_ux(exit_flag, output_box_num, (uint32_t *)output, - score_data, boxes_data, input_ram, input_num_boxes, - max_output_size, iou_threshold, confidence_threshold, - offset, output_mode, algo, cdma_addr); - } else { - nms_detection_ux(exit_flag, output_box_num, (uint32_t *)output, - (half *)score_data, (half *)boxes_data, input_ram, - input_num_boxes, max_output_size, iou_threshold, - confidence_threshold, offset, output_mode, algo, - cdma_addr); - } - } else { - if (data_type_input == CNRT_FLOAT32) { - nms_detection_ux(exit_flag, output_box_num, (float *)output, score_data, - boxes_data, input_ram, input_num_boxes, max_output_size, - iou_threshold, confidence_threshold, offset, output_mode, - algo, cdma_addr); - } else { - nms_detection_ux(exit_flag, output_box_num, (half *)output, - (half *)score_data, (half *)boxes_data, input_ram, - input_num_boxes, max_output_size, iou_threshold, - confidence_threshold, offset, output_mode, algo, - cdma_addr); - } - } - ((uint32_t *)result_num)[0] = output_box_num; -} - -void KernelNms(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, - const cnrtDataType_t data_type_input, const void *boxes_ptr, - const void *scores_ptr, const int input_num_boxes, - const int max_output_boxes, const float iou_threshold, - const float offset, void *workspace_ptr, void *output_size_ptr, - void *output_ptr) { - switch (k_type) { - default: { return; } - case CNRT_FUNC_TYPE_BLOCK: - case CNRT_FUNC_TYPE_UNION1: { - MLUUnion1KernelNMS<<>>( - (void *)boxes_ptr, (void *)scores_ptr, input_num_boxes, - max_output_boxes, iou_threshold, /*confidence_threshold=*/0.0, - /*output_mode=*/0, workspace_ptr, output_size_ptr, output_ptr, - data_type_input, offset, /*algo=*/1); - }; break; - case CNRT_FUNC_TYPE_UNION2: - case CNRT_FUNC_TYPE_UNION4: - case CNRT_FUNC_TYPE_UNION8: - case CNRT_FUNC_TYPE_UNION16: { - MLUUionXKernelNMS<<>>( - (void *)boxes_ptr, (void *)scores_ptr, input_num_boxes, - max_output_boxes, iou_threshold, /*confidence_threshold=*/0.0, offset, - data_type_input, /*output_mode=*/0, /*algo=*/1, workspace_ptr, - output_size_ptr, output_ptr); - }; break; - } -} diff --git a/mmcv/ops/csrc/common/mlu/nms_utils.hpp b/mmcv/ops/csrc/common/mlu/nms_utils.hpp deleted file mode 100644 index 61f5ba95df..0000000000 --- a/mmcv/ops/csrc/common/mlu/nms_utils.hpp +++ /dev/null @@ -1,553 +0,0 @@ -/************************************************************************* - * Copyright (C) [2019-2022] by Cambricon, Inc. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS - * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
- *************************************************************************/ -#ifndef NMS_UTILS_HPP_ -#define NMS_UTILS_HPP_ -#include "common_mlu_helper.hpp" - -#define NMS_SIZE (64) -#define NMS_UP(x, y) (x / y + (int)(x % y > 0)) * y -#define NMS_DOWN(x, y) (x / y) * y -#define INFO_NUM (5) // 5 means x1, x2, y1, y2 and score -#define MEMORY_CORE (0x80) -#define REDUCE_NUM \ - (7) // score, x1, y1, x2, y2, max_index (reserve 2 num for half-type input) - -__mlu_func__ void pvLock() { -#if __BANG_ARCH__ == 270 - if (coreId != MEMORY_CORE) { - __bang_lock(0, 0); - } -#endif -} - -__mlu_func__ void pvUnlock() { -#if __BANG_ARCH__ == 270 - if (coreId != MEMORY_CORE) { - __bang_unlock(0, 0); - } -#endif -} - -template -static __mlu_func__ void computeReluN(T *nram_dst, T *nram_src, void *nram_tmp, - const int deal_num, - const T threshold = 0) { - if (threshold < 0) { - return; - } - if (threshold) { -#if __BANG_ARCH__ >= 300 - __bang_relun(nram_dst, nram_src, deal_num, threshold); -#else - int align_num = NFU_ALIGN_SIZE / sizeof(T); - T *nram_aux_a = (T *)nram_tmp; - T *nram_aux_b = nram_aux_a + deal_num; - T *nram_zero = nram_aux_b + align_num; - __bang_write_value(nram_aux_b, align_num, threshold); - __bang_write_zero(nram_zero, align_num); - __bang_cycle_lt((T *)nram_aux_a, nram_src, (T *)nram_aux_b, deal_num, - align_num); - __bang_mul(nram_dst, nram_src, (T *)nram_aux_a, deal_num); - __bang_cycle_eq((T *)nram_aux_a, (T *)nram_aux_a, (T *)nram_zero, deal_num, - align_num); - __bang_cycle_mul((T *)nram_aux_a, (T *)nram_aux_a, (T *)nram_aux_b, - deal_num, align_num); - __bang_add(nram_dst, nram_dst, (T *)nram_aux_a, deal_num); - __bang_cycle_gt((T *)nram_aux_a, nram_dst, (T *)nram_zero, deal_num, - align_num); - __bang_mul(nram_dst, nram_dst, (T *)nram_aux_a, deal_num); -#endif - } else { -#if __BANG_ARCH__ >= 300 - __bang_relu(nram_dst, nram_src, deal_num); -#else - __bang_active_relu(nram_dst, nram_src, deal_num); -#endif - } -} - -__mlu_func__ void getComputeParamsBlockOrU1( - const int input_dwidth, const int input_box_num, const int limit, - const int core_limit, int &input_offset, int &max_seg_pad, int &repeat, - int &remain, int &remain_pad, int &max_seg_iou_compute, - int &repeat_iou_compute, int &remain_iou_compute, - int &remain_pad_iou_compute) { - int avg_core = input_box_num / core_limit; - int rem = input_box_num % core_limit; - int len_core = avg_core + (coreId < rem ? 1 : 0); - input_offset = avg_core * coreId + (coreId <= rem ? coreId : rem); - max_seg_pad = NMS_DOWN(limit, NMS_SIZE); - repeat = len_core / max_seg_pad; - remain = len_core % max_seg_pad; - remain_pad = NMS_UP(remain, NMS_SIZE); - - // if datatype is fp16, we should cvt to fp32 when compute iou - max_seg_iou_compute = NMS_DOWN(max_seg_pad / (4 / input_dwidth), NMS_SIZE); - repeat_iou_compute = len_core / max_seg_iou_compute; - remain_iou_compute = len_core % max_seg_iou_compute; - remain_pad_iou_compute = NMS_UP(remain_iou_compute, NMS_SIZE); -} - -__mlu_func__ void getComputeParamsUx( - const int input_dwidth, const int input_num_boxes, const int limit, - int &input_offset, int &max_seg_pad, int &repeat, int &remain, - int &remain_pad, int &max_seg_iou_compute, int &repeat_iou_compute, - int &remain_iou_compute, int &remain_pad_iou_compute) { - // data split - int avg_cluster = input_num_boxes / clusterDim; - int rem_cluster = input_num_boxes % clusterDim; - int len_cluster = avg_cluster + (clusterId < rem_cluster); - int cluster_offset = avg_cluster * clusterId + - (clusterId <= rem_cluster ? 
clusterId : rem_cluster); - - int avg_core = len_cluster / coreDim; - int rem_core = len_cluster % coreDim; - int len_core = avg_core + (coreId < rem_core); - int core_offset = - avg_core * coreId + (coreId <= rem_core ? coreId : rem_core); - input_offset = cluster_offset + core_offset; - - max_seg_pad = NMS_DOWN(limit, NMS_SIZE); - - // core 0 of each cluster calculate the max score index - int max_index_len_core = avg_cluster + (clusterId < rem_cluster); - repeat = max_index_len_core / max_seg_pad; - remain = max_index_len_core % max_seg_pad; - remain_pad = NMS_UP(remain, NMS_SIZE); - // if datatype is fp16, we should cvt to fp32 when compute iou - max_seg_iou_compute = - NMS_DOWN(max_seg_pad / (sizeof(float) / input_dwidth), NMS_SIZE); - repeat_iou_compute = len_core / max_seg_iou_compute; - remain_iou_compute = len_core % max_seg_iou_compute; - remain_pad_iou_compute = NMS_UP(remain_iou_compute, NMS_SIZE); -} - -template -__mlu_func__ void findGlobalMaxBox(IN_DT *max_box, IN_DT *sram, - IN_DT *inter_x1) { - // copy all partial max to the sram of cluster 0 - if (clusterId != 0) { - __memcpy(sram + REDUCE_NUM * clusterId, sram, REDUCE_NUM * sizeof(IN_DT), - SRAM2SRAM, 0); - } - __sync_all(); - - // reduce between clusters to get the global max box - if (clusterId == 0) { - if (coreId == 0) { - __bang_write_zero(inter_x1, NMS_SIZE); - __memcpy(inter_x1, sram, sizeof(IN_DT), SRAM2NRAM, sizeof(IN_DT), - REDUCE_NUM * sizeof(IN_DT), clusterDim - 1); - __bang_max(max_box, inter_x1, NMS_SIZE); - int max_cluster = (sizeof(IN_DT) == sizeof(half)) - ? ((uint16_t *)max_box)[1] - : ((uint32_t *)max_box)[1]; - __memcpy(max_box, sram + max_cluster * REDUCE_NUM, - REDUCE_NUM * sizeof(IN_DT), SRAM2NRAM); - __memcpy(sram, max_box, REDUCE_NUM * sizeof(IN_DT), NRAM2SRAM); - } - __sync_cluster(); - if (coreId == 0x80 && clusterDim > 1) { - // broadcast global max box to each cluster's sram - for (int cluster_idx = 1; cluster_idx < clusterDim; ++cluster_idx) { - __memcpy(sram, sram, REDUCE_NUM * sizeof(IN_DT), SRAM2SRAM, - cluster_idx); - } - } - __sync_cluster(); - } - __sync_all(); - - // copy the global max box to max_box - __memcpy(max_box, sram, REDUCE_NUM * sizeof(IN_DT), SRAM2NRAM); -} - -template -__mlu_func__ void findCoreMaxBox( - IN_DT *input_score_ptr, IN_DT *score, IN_DT *inter_x1, IN_DT *max_box, - const IN_DT *input_x1_ptr, const IN_DT *input_y1_ptr, - const IN_DT *input_x2_ptr, const IN_DT *input_y2_ptr, - const mluMemcpyDirection_t load_dir, const int input_offset, - const int repeat, const int remain, const int remain_pad, - const int max_seg_pad, int &max_index) { - if (coreId != 0x80) { - for (int i = 0; i <= repeat; i++) { - if (i == repeat && remain == 0) { - break; - } - int seg_len = 0; // the length every nms compute - int cpy_len = 0; // the length every nms memcpy - i == repeat ? seg_len = remain_pad : seg_len = max_seg_pad; - i == repeat ? 
cpy_len = remain : cpy_len = max_seg_pad; - /******NMS LOAD START******/ - __bang_write_zero(score, seg_len); - __memcpy(score, input_score_ptr + input_offset + i * max_seg_pad, - cpy_len * sizeof(IN_DT), load_dir, cpy_len * sizeof(IN_DT), - cpy_len * sizeof(IN_DT), 0); - - /******NMS LOAD END******/ - - __bang_max(inter_x1, score, seg_len); - if (inter_x1[0] > max_box[0]) { - max_box[0] = inter_x1[0]; - if (sizeof(IN_DT) == sizeof(half)) { - max_index = ((uint16_t *)inter_x1)[1] + input_offset + - i * max_seg_pad; // offset start from head of input_data - } else if (sizeof(IN_DT) == sizeof(float)) { - max_index = ((uint32_t *)inter_x1)[1] + input_offset + - i * max_seg_pad; // offset start from head of input_data - } - } - } // for repeat - // the max box's x1, y1, x2, y2 on every core - max_box[1] = input_x1_ptr[max_index]; - max_box[2] = input_y1_ptr[max_index]; - max_box[3] = input_x2_ptr[max_index]; - max_box[4] = input_y2_ptr[max_index]; - ((uint32_t *)(max_box + 5))[0] = max_index; - } -} - -template -__mlu_func__ void findClusterMaxBox(IN_DT *sram, IN_DT *max_box, - IN_DT *inter_x1, IN_DT *input_data_score, - const int core_limit) { - // find the max with sram - // copy every core's box info to sram, form: score---x1---y1---x2---y2--- - __memcpy(sram + REDUCE_NUM * coreId, max_box, REDUCE_NUM * sizeof(IN_DT), - NRAM2SRAM); // int32_t datatype - __sync_cluster(); - - // copy score from sram to nram and find the max - __bang_write_zero(inter_x1, 64); - __memcpy(inter_x1, sram, sizeof(IN_DT), SRAM2NRAM, sizeof(IN_DT), - REDUCE_NUM * sizeof(IN_DT), coreDim - 1); - __bang_max(max_box, inter_x1, 64); - int max_core = sizeof(IN_DT) == sizeof(half) ? ((uint16_t *)max_box)[1] - : ((uint32_t *)max_box)[1]; - // copy the max box to max_box - __memcpy(max_box, sram + max_core * REDUCE_NUM, REDUCE_NUM * sizeof(IN_DT), - SRAM2NRAM); -} - -/*****************************************************************************/ -/*******************************CALCULATE MAX AREA****************************/ -/*****************************************************************************/ - -template -__mlu_func__ void calMaxArea(IN_DT *max_box, const int algo, float offset, - float &max_area) { - if (algo == 0 || offset == 0.0) { - max_area = ((float)max_box[3] - (float)max_box[1]) * - ((float)max_box[4] - (float)max_box[2]); - } else { - max_area = ((float)max_box[3] - (float)max_box[1] + offset) * - ((float)max_box[4] - (float)max_box[2] + offset); - } -} - -template -__mlu_func__ void calMaxArea(IN_DT *max_box, const int algo, float offset, - float &max_area, float &max_box_x1, - float &max_box_y1, float &max_box_x2, - float &max_box_y2) { - // the case of random inf will break the requirement of x1<=x2, y1<=y2 - // so exchange it if it happens. 
- max_box_x1 = float(max_box[1]); - max_box_x2 = float(max_box[3]); - if (max_box[1] > max_box[3]) { - max_box_x1 = float(max_box[3]); - max_box_x2 = float(max_box[1]); - } - max_box_y1 = float(max_box[2]); - max_box_y2 = float(max_box[4]); - if (max_box[2] > max_box[4]) { - max_box_y1 = float(max_box[4]); - max_box_y2 = float(max_box[2]); - } - if (algo == 0 || offset == 0.0) { - max_area = (max_box_x2 - max_box_x1) * (max_box_y2 - max_box_y1); - } else { - max_area = - (max_box_x2 - max_box_x1 + offset) * (max_box_y2 - max_box_y1 + offset); - } -} - -/***********************************************************************/ -/*******************************STORE RESULT****************************/ -/***********************************************************************/ -template -__mlu_func__ void storeResult(IN_DT *max_box, OUT_DT *nram_save, - OUT_DT *&output_dram, const int keep, - const int nram_save_limit_count, - const int max_output_size, - const float thresh_score, const int output_mode, - int &nram_save_count, uint32_t &output_box_num) { - /******NMS STORE START******/ - // store to nram - if (float(max_box[0]) > thresh_score) { - OUT_DT *save_ptr; - int save_offset = 0; - int save_str_num = 0; - save_ptr = nram_save; - save_offset = nram_save_count; - save_str_num = nram_save_limit_count; - if (clusterId == 0 && coreId == 0) { - if (output_mode == 0) { // index1, index2, ... - save_ptr[save_offset] = ((uint32_t *)(max_box + INFO_NUM))[0]; - } else if (output_mode == 1) { // score, x1, y1, x2, y2 - __memcpy(save_ptr + save_offset * INFO_NUM, max_box, - INFO_NUM * sizeof(IN_DT), NRAM2NRAM, INFO_NUM * sizeof(IN_DT), - INFO_NUM * sizeof(IN_DT), 0); - } else if (output_mode == 2) { // score---, x1---, y1---, x2---, y2--- - __memcpy(save_ptr + save_offset, max_box, 1 * sizeof(IN_DT), NRAM2NRAM, - save_str_num * sizeof(IN_DT), 1 * sizeof(IN_DT), 4); - } - } - nram_save_count++; - output_box_num++; - } - - // store to sram/gdram - if (output_box_num != 0) { - if ((nram_save_count == nram_save_limit_count) || - (float(max_box[0]) <= thresh_score) || keep == max_output_size - 1) { - if (nram_save_count != 0) { - if (clusterId == 0 && coreId == 0) { - if (output_mode == 0) { // index1, index2, ... 
- pvLock(); - __memcpy(output_dram, nram_save, nram_save_count * sizeof(uint32_t), - NRAM2GDRAM); - pvUnlock(); - output_dram += nram_save_count; - } else if (output_mode == 1) { // score, x1, y1, x2, y2 - pvLock(); - __memcpy(output_dram, nram_save, - nram_save_count * INFO_NUM * sizeof(IN_DT), NRAM2GDRAM); - pvUnlock(); - output_dram += nram_save_count * INFO_NUM; - } else if (output_mode == - 2) { // score---, x1---, y1---, x2---, y2--- - pvLock(); - __memcpy(output_dram, nram_save, nram_save_count * sizeof(IN_DT), - NRAM2GDRAM, max_output_size * sizeof(IN_DT), - nram_save_limit_count * sizeof(IN_DT), 4); - pvUnlock(); - output_dram += nram_save_count; - } - nram_save_count = 0; - } - } - } // if move data nram->sram/gdram - } // if dst -} - -template -__mlu_func__ void scoreUpdate( - IN_DT *input_score_ptr, const mluMemcpyDirection_t load_dir, - const mluMemcpyDirection_t store_dir, const IN_DT *input_x1_ptr, - const IN_DT *input_y1_ptr, const IN_DT *input_x2_ptr, - const IN_DT *input_y2_ptr, IN_DT *x1, IN_DT *y1, IN_DT *x2, IN_DT *y2, - IN_DT *score, IN_DT *inter_x1, IN_DT *inter_y1, IN_DT *inter_x2, - IN_DT *inter_y2, IN_DT *max_box, const float max_box_x1, - const float max_box_y1, const float max_box_x2, const float max_box_y2, - OUT_DT *nram_save, int repeat_iou_compute, int remain_iou_compute, - int remain_pad_iou_compute, int max_seg_iou_compute, int max_seg_pad, - const float thresh_iou, const float div_thresh_iou, const int input_offset, - const float offset, const float max_area, const int input_num_boxes, - const int algo) { - for (int i = 0; i <= repeat_iou_compute; i++) { - if (i == repeat_iou_compute && remain_iou_compute == 0) { - break; - } - int seg_len = (i == repeat_iou_compute) ? remain_pad_iou_compute - : max_seg_iou_compute; - int cpy_len = - (i == repeat_iou_compute) ? 
remain_iou_compute : max_seg_iou_compute; - /******NMS LOAD START******/ - int dt_offset = 0; - if (sizeof(IN_DT) == sizeof(float)) { - __memcpy(score, input_score_ptr + input_offset + i * max_seg_pad, - cpy_len * sizeof(IN_DT), load_dir, cpy_len * sizeof(IN_DT), - cpy_len * sizeof(IN_DT), 0); - dt_offset = 0; - } else if (sizeof(IN_DT) == sizeof(half)) { - __memcpy(x1, input_score_ptr + input_offset + i * max_seg_iou_compute, - cpy_len * sizeof(IN_DT), load_dir, cpy_len * sizeof(IN_DT), - cpy_len * sizeof(IN_DT), 0); - __bang_half2float((float *)score, (half *)x1, seg_len); - dt_offset = max_seg_iou_compute; - } -#if __BANG_ARCH__ >= 300 - __memcpy(inter_x1 + dt_offset, - input_x1_ptr + input_offset + i * max_seg_iou_compute, - cpy_len * sizeof(IN_DT), load_dir, max_seg_pad * sizeof(IN_DT), - input_num_boxes * sizeof(IN_DT), 3); - - if (sizeof(IN_DT) == sizeof(half)) { - __bang_half2float((float *)inter_x1, - (half *)inter_x1 + max_seg_iou_compute, seg_len); - __bang_half2float((float *)inter_y1, - (half *)inter_y1 + max_seg_iou_compute, seg_len); - __bang_half2float((float *)inter_x2, - (half *)inter_x2 + max_seg_iou_compute, seg_len); - __bang_half2float((float *)inter_y2, - (half *)inter_y2 + max_seg_iou_compute, seg_len); - } - // box transfer - __bang_minequal((float *)x1, (float *)inter_x1, (float *)inter_x2, seg_len); - __bang_maxequal((float *)x2, (float *)inter_x1, (float *)inter_x2, seg_len); - __bang_minequal((float *)y1, (float *)inter_y1, (float *)inter_y2, seg_len); - __bang_maxequal((float *)y2, (float *)inter_y1, (float *)inter_y2, seg_len); - // 1、 compute IOU - // get the area_I - __bang_maxeq_scalar((float *)inter_x1, (float *)x1, max_box_x1, - seg_len); // inter_x1 - __bang_mineq_scalar((float *)inter_x2, (float *)x2, max_box_x2, - seg_len); // inter_x2 - __bang_sub((float *)inter_x1, (float *)inter_x2, (float *)inter_x1, - seg_len); - if (algo == 1 && offset != 0.0) { - __bang_add_scalar((float *)inter_x1, (float *)inter_x1, offset, seg_len); - } - computeReluN((float *)inter_x1, (float *)inter_x1, NULL, - seg_len); // inter_w - __bang_maxeq_scalar((float *)inter_y1, (float *)y1, float(max_box_y1), - seg_len); // inter_y1 - __bang_mineq_scalar((float *)inter_y2, (float *)y2, float(max_box_y2), - seg_len); // inter_y2 - __bang_sub((float *)inter_y1, (float *)inter_y2, (float *)inter_y1, - seg_len); - if (algo == 1 && offset != 0.0) { - __bang_add_scalar((float *)inter_y1, (float *)inter_y1, offset, seg_len); - } - computeReluN((float *)inter_y1, (float *)inter_y1, NULL, - seg_len); // inter_h - __bang_mul((float *)inter_x1, (float *)inter_x1, (float *)inter_y1, - seg_len); // area_I - // get the area of input_box: area = (x2 - x1) * (y2 - y1); - if (algo == 1 && offset != 0.0) { - __bang_fusion(FUSION_FSA, (float *)inter_y1, (float *)x2, (float *)x1, - offset, seg_len, seg_len); - __bang_fusion(FUSION_FSA, (float *)inter_y2, (float *)y2, (float *)y1, - offset, seg_len, seg_len); - __bang_mul((float *)inter_x2, (float *)inter_y1, (float *)inter_y2, - seg_len); // area - } else { - __bang_sub((float *)inter_y1, (float *)x2, (float *)x1, seg_len); - __bang_fusion(FUSION_FSM, (float *)inter_x2, (float *)y2, (float *)y1, - (float *)inter_y1, seg_len, seg_len); - } - // get the area_U: area + max_area - area_I - __bang_fusion(FUSION_FAS, (float *)inter_x2, (float *)inter_x2, max_area, - (float *)inter_x1, seg_len, seg_len); - // 2、 select the box - // if IOU greater than thres, set the score to zero, abort it: area_U > - // area_I * (1 / thresh)? 
- if (thresh_iou > 0.0) { - __bang_mul_scalar((float *)inter_x1, (float *)inter_x1, div_thresh_iou, - seg_len); - } else { - __bang_mul_scalar((float *)inter_x2, (float *)inter_x2, thresh_iou, - seg_len); - } - // process for nan - __bang_lt((float *)inter_x1, (float *)inter_x2, (float *)inter_x1, seg_len); - __bang_not((float *)inter_x1, (float *)inter_x1, seg_len); - __bang_mul((float *)score, (float *)score, (float *)inter_x1, seg_len); -/******NMS COMPUTE END******/ -#else - __memcpy(x1 + dt_offset, - input_x1_ptr + input_offset + i * max_seg_iou_compute, - cpy_len * sizeof(IN_DT), load_dir, max_seg_pad * sizeof(IN_DT), - input_num_boxes * sizeof(IN_DT), 3); - if (sizeof(IN_DT) == sizeof(half)) { - __bang_half2float((float *)x1, (half *)x1 + max_seg_iou_compute, seg_len); - __bang_half2float((float *)y1, (half *)y1 + max_seg_iou_compute, seg_len); - __bang_half2float((float *)x2, (half *)x2 + max_seg_iou_compute, seg_len); - __bang_half2float((float *)y2, (half *)y2 + max_seg_iou_compute, seg_len); - } - // 1、 compute IOU - // get the area_I - __bang_write_value((float *)inter_y1, seg_len, - float(max_box[1])); // max_x1 - __bang_maxequal((float *)inter_x1, (float *)x1, (float *)inter_y1, - seg_len); // inter_x1 - __bang_write_value((float *)inter_y2, seg_len, - float(max_box[3])); // max_x2 - __bang_minequal((float *)inter_x2, (float *)x2, (float *)inter_y2, - seg_len); // inter_x2 - __bang_sub((float *)inter_x1, (float *)inter_x2, (float *)inter_x1, - seg_len); - if (algo == 1 && offset != 0.0) { - __bang_add_scalar((float *)inter_x1, (float *)inter_x1, offset, seg_len); - } - computeReluN((float *)inter_x1, (float *)inter_x1, NULL, - seg_len); // inter_w - __bang_write_value((float *)inter_x2, seg_len, - float(max_box[2])); // max_y1 - __bang_maxequal((float *)inter_y1, (float *)y1, (float *)inter_x2, - seg_len); // inter_y1 - __bang_write_value((float *)inter_x2, seg_len, - float(max_box[4])); // max_y2 - __bang_minequal((float *)inter_y2, (float *)y2, (float *)inter_x2, - seg_len); // inter_y2 - __bang_sub((float *)inter_y1, (float *)inter_y2, (float *)inter_y1, - seg_len); - if (algo == 1 && offset != 0.0) { - __bang_add_scalar((float *)inter_y1, (float *)inter_y1, offset, seg_len); - } - computeReluN((float *)inter_y1, (float *)inter_y1, NULL, - seg_len); // inter_h - __bang_mul((float *)inter_x1, (float *)inter_x1, (float *)inter_y1, - seg_len); // area_I - // get the area of input_box: area = (x2 - x1) * (y2 - y1); - __bang_sub((float *)inter_y1, (float *)x2, (float *)x1, seg_len); - __bang_sub((float *)inter_y2, (float *)y2, (float *)y1, seg_len); - if (algo == 1 && offset != 0.0) { - __bang_add_scalar((float *)inter_y1, (float *)inter_y1, offset, seg_len); - __bang_add_scalar((float *)inter_y2, (float *)inter_y2, offset, seg_len); - } - __bang_mul((float *)inter_x2, (float *)inter_y1, (float *)inter_y2, - seg_len); // area - // get the area_U: area + max_area - area_I - __bang_add_scalar((float *)inter_x2, (float *)inter_x2, float(max_area), - seg_len); - __bang_sub((float *)inter_x2, (float *)inter_x2, (float *)inter_x1, - seg_len); // area_U - // 2、 select the box - // if IOU greater than thresh, set the score to zero, abort it: area_U > - // area_I * (1 / thresh)? 
- if (thresh_iou > 0.0) { - __bang_mul_scalar((float *)inter_x1, (float *)inter_x1, div_thresh_iou, - seg_len); - } else { - __bang_mul_scalar((float *)inter_x2, (float *)inter_x2, thresh_iou, - seg_len); - } - __bang_ge((float *)inter_x1, (float *)inter_x2, (float *)inter_x1, seg_len); - __bang_mul((float *)score, (float *)score, (float *)inter_x1, seg_len); -/******NMS COMPUTE END******/ -#endif - // update the score - if (sizeof(IN_DT) == sizeof(half)) { - convertFloat2half((half *)score, (float *)score, seg_len); - } - pvLock(); - __memcpy(input_score_ptr + input_offset + i * max_seg_iou_compute, score, - cpy_len * sizeof(IN_DT), store_dir, cpy_len * sizeof(IN_DT), - cpy_len * sizeof(IN_DT), 0); - pvUnlock(); - } -} - -#endif // NMS_UTILS_HPP_ diff --git a/mmcv/ops/csrc/common/mlu/roi_align_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/roi_align_mlu_kernel.mlu deleted file mode 100644 index c99176ab20..0000000000 --- a/mmcv/ops/csrc/common/mlu/roi_align_mlu_kernel.mlu +++ /dev/null @@ -1,493 +0,0 @@ -/************************************************************************* - * Copyright (C) 2021 Cambricon. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS - * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - *************************************************************************/ -#include "common_mlu_helper.hpp" - -#define ROI_OFFSET 5 - -__nram__ char buffer[MAX_NRAM_SIZE]; - -namespace forward { -template -__mlu_func__ void bilinearInterpolate(const int input_height, - const int input_width, T y, T x, T *w1, - T *w2, T *w3, T *w4, int *x_low, - int *x_high, int *y_low, int *y_high, - bool *empty) { - // deal with cases that inverse elements are of feature map boundary - if (y < -1.0 || y > input_height || x < -1.0 || x > input_width) { - *empty = true; - return; - } - - if (y <= 0) y = 0; - if (x <= 0) x = 0; - - int y_low_ = int(y); - int x_low_ = int(x); - - if (y_low_ >= input_height - 1) { - *y_high = y_low_ = input_height - 1; - y = (T)y_low_; - } else { - *y_high = y_low_ + 1; - } - - if (x_low_ >= input_width - 1) { - *x_high = x_low_ = input_width - 1; - x = T(x_low_); - } else { - *x_high = x_low_ + 1; - } - - *y_low = y_low_; - *x_low = x_low_; - - T ly = y - y_low_; - T lx = x - x_low_; - T hy = 1.0 - ly; - T hx = 1.0 - lx; - *w1 = hy * hx, *w2 = hy * lx, *w3 = ly * hx, *w4 = ly * lx; - return; -} - -template -__mlu_func__ void computeChannel(T *input_core, T *nram_in, T *output_core, - T *nram_out, const int roi_bin_grid_h, - const int roi_bin_grid_w, const T roi_start_h, - const T roi_start_w, const int ph, - const int pw, const T bin_size_h, - const T bin_size_w, const float count, - const int input_height, const int input_width, - const int channels, const int cyc_num, - const int max_elements) { - int cyc_channel = max_elements; - - for (int i = 0; i < cyc_num; i++) { - int real_channel = - (i == cyc_num - 1) ? channels - i * cyc_channel : cyc_channel; - int align_channel = PAD_UP(real_channel, NFU_ALIGN_SIZE / sizeof(T)); - __bang_write_zero(nram_out, align_channel); - uint32_t real_size = real_channel * sizeof(T); - - int iy, ix; - for (iy = 0; iy < roi_bin_grid_h; iy++) { - // 1. 
compute the coordinates of the y axis in the current roi_bin_grid_h
-      T y = roi_start_h + ph * bin_size_h +
-            (T)(iy + 0.5) * bin_size_h / (T)(roi_bin_grid_h);
-      for (ix = 0; ix < roi_bin_grid_w; ix++) {
-        // 2. compute the coordinates of the x axis in the current
-        // roi_bin_grid_w
-        T x = roi_start_w + pw * bin_size_w +
-              (T)(ix + 0.5) * bin_size_w / (T)(roi_bin_grid_w);
-
-        // 3. compute the four weights (w1, w2, w3 and w4), the height
-        // indices (y_low and y_high) and width indices (x_low and x_high)
-        // of the input feature map in the current roi bin grid, and the
-        // flag (empty) which shows if x, y are out of input feature map ranges
-        T w1, w2, w3, w4;
-        int x_low, x_high, y_low, y_high;
-        bool empty = false;
-
-        bilinearInterpolate(input_height, input_width, y, x, &w1, &w2, &w3, &w4,
-                            &x_low, &x_high, &y_low, &y_high, &empty);
-
-        // 4. compute interpolation of the current roi bin grid
-        // tmp_cyc1, tmp_cyc2, tmp_cyc3 and tmp_cyc4 store the input values
-        // to compute the interpolation, and then are reused to compute
-        // the argmax_x and argmax_y.
-        T *tmp_cyc1 = nram_in + cyc_channel;
-        T *tmp_cyc2 = nram_in + cyc_channel * 2;
-        T *tmp_cyc3 = nram_in + cyc_channel * 3;
-        T *tmp_cyc4 = nram_in + cyc_channel * 4;
-
-        if (empty) {  // x, y out of range: the sample contributes zero
-          __bang_write_zero(nram_in, align_channel);
-        } else {
-          __bang_write_zero(nram_in, align_channel);
-          uint32_t offset1 = (y_low * input_width + x_low) * channels;
-          uint32_t offset2 = (y_low * input_width + x_high) * channels;
-          uint32_t offset3 = (y_high * input_width + x_low) * channels;
-          uint32_t offset4 = (y_high * input_width + x_high) * channels;
-          T *input1 = (T *)input_core + offset1 + i * cyc_channel;
-          T *input2 = (T *)input_core + offset2 + i * cyc_channel;
-          T *input3 = (T *)input_core + offset3 + i * cyc_channel;
-          T *input4 = (T *)input_core + offset4 + i * cyc_channel;
-
-          // load the four pixels (p1, p2, p3 and p4) of input feature map to
-          // compute interpolation
-          __memcpy(tmp_cyc1, input1, real_size, GDRAM2NRAM);
-          __memcpy(tmp_cyc2, input2, real_size, GDRAM2NRAM);
-          __memcpy(tmp_cyc3, input3, real_size, GDRAM2NRAM);
-          __memcpy(tmp_cyc4, input4, real_size, GDRAM2NRAM);
-
-          // interpolation value = w1 * p1 + w2 * p2 + w3 * p3 + w4 * p4
-          __bang_mul_scalar(tmp_cyc1, tmp_cyc1, w1, align_channel);
-          __bang_mul_scalar(tmp_cyc2, tmp_cyc2, w2, align_channel);
-          __bang_mul_scalar(tmp_cyc3, tmp_cyc3, w3, align_channel);
-          __bang_mul_scalar(tmp_cyc4, tmp_cyc4, w4, align_channel);
-
-          __bang_add(nram_in, tmp_cyc1, nram_in, align_channel);
-          __bang_add(nram_in, tmp_cyc2, nram_in, align_channel);
-          __bang_add(nram_in, tmp_cyc3, nram_in, align_channel);
-          __bang_add(nram_in, tmp_cyc4, nram_in, align_channel);
-        }
-        // 5. compute sum value and corresponding coordinates of x axis and y
-        // axis. Update the sum value.
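        // Each accepted sample contributed w1*p1 + w2*p2 + w3*p3 + w4*p4 to
        // nram_in above; the division by count is deferred until both grid
        // loops finish, so averaging the bin costs one extra mul_scalar per
        // channel block instead of one per sample.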
- __bang_add(nram_out, nram_in, nram_out, align_channel); - } // loop_roi_grid_w - } // loop_roi_grid_h - T count_value = (T)(1.0 / count); - __bang_mul_scalar(nram_out, nram_out, count_value, align_channel); - __memcpy(output_core + i * cyc_channel, nram_out, real_size, NRAM2GDRAM); - } // loop_cyc_num -} - -template -__mlu_func__ void roialignForwardAvg( - T *input, T *rois, T *output, const bool aligned, const int channels, - const int pooled_height, const int pooled_width, const int input_height, - const int input_width, const int sampling_ratio, const T spatial_scale, - const int num_rois) { - // find limit for channel, the nram space is divided to 6 parts that are - // input, 4 weights to compute the interpolation (w1, w2, w3, w4), output - - // max_elements : 300 : float datatype : 27296, half datatype : 54592 - // max_elements : 200 : float datatype : 16384, half datatype : 32768 - int max_elements = (PAD_DOWN(MAX_NRAM_SIZE / 6, NFU_ALIGN_SIZE)) / sizeof(T); - int cyc_num = channels / max_elements + (int)(channels % max_elements != 0); - T offset = aligned ? (T)0.5 : (T)0.0; - int task_num = num_rois * pooled_height * pooled_width; - T *nram_out = (T *)buffer; - T *nram_in = nram_out + max_elements; - if (task_num < taskDim) { - if (taskId >= task_num) { - return; - } - } - - for (int bin_idx = taskId; bin_idx < task_num; bin_idx = bin_idx + taskDim) { - if (bin_idx >= task_num) { - return; - } - - // (n,ph.pw) is a c in the pooled output - int pw = bin_idx % pooled_width; - int ph = (bin_idx / pooled_width) % pooled_height; - int n = bin_idx / pooled_width / pooled_height; - - T *roi_id_tmp = rois + n * ROI_OFFSET; - // 1. compute width and height of roi region. - int batch_idx = (int)roi_id_tmp[0]; - T roi_x1 = roi_id_tmp[1]; - T roi_y1 = roi_id_tmp[2]; - T roi_x2 = roi_id_tmp[3]; - T roi_y2 = roi_id_tmp[4]; - T roi_start_w = roi_x1 * spatial_scale - offset; - T roi_start_h = roi_y1 * spatial_scale - offset; - T roi_end_w = roi_x2 * spatial_scale - offset; - T roi_end_h = roi_y2 * spatial_scale - offset; - T roi_width = roi_end_w - roi_start_w; - T roi_height = roi_end_h - roi_start_h; - - if (!aligned) { - roi_width = roi_width > (T)(1.0) ? roi_width : (T)(1.0); - roi_height = roi_height > (T)(1.0) ? roi_height : (T)(1.0); - } - - // 2. compute float-type width and height of roi bin region. - T bin_size_w = (T)roi_width / (T)pooled_width; - T bin_size_h = (T)roi_height / (T)pooled_height; - - // 3. compute int-type width and height of roi bin region. - int roi_bin_grid_h, roi_bin_grid_w; - roi_bin_grid_h = (sampling_ratio > 0) - ? sampling_ratio - : int(ceilf(roi_height / pooled_height)); - roi_bin_grid_w = (sampling_ratio > 0) - ? sampling_ratio - : int(ceilf(roi_width / pooled_width)); - float count = (float)((roi_bin_grid_h * roi_bin_grid_w) > 1 - ? roi_bin_grid_h * roi_bin_grid_w - : 1.0); - T *input_core = input + batch_idx * channels * input_width * input_height; - T *output_core = output + bin_idx * channels; - // 4. compute avg value and corresponding coordinates of x axis and y axis. 
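    // computeChannel averages roi_bin_grid_h * roi_bin_grid_w bilinear
    // samples for this output bin, tiling the channel dimension into cyc_num
    // blocks of at most max_elements values so each block fits the six-way
    // NRAM split carved out of `buffer`; the 16384/27296 float figures quoted
    // above are presumably PAD_DOWN(MAX_NRAM_SIZE / 6, NFU_ALIGN_SIZE) /
    // sizeof(float) evaluated for the 200- and 300-series NRAM sizes.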
- computeChannel(input_core, nram_in, output_core, nram_out, roi_bin_grid_h, - roi_bin_grid_w, roi_start_h, roi_start_w, ph, pw, bin_size_h, - bin_size_w, count, input_height, input_width, channels, - cyc_num, max_elements); - } -} - -__mlu_global__ void MLUUnion1KernelRoiAlignAvg( - const void *input, const void *rois, const int channels, const bool aligned, - const int pooled_height, const int pooled_width, const int input_height, - const int input_width, const int sampling_ratio, const float spatial_scale, - const int num_rois, const cnrtDataType_t data_type, void *output) { - // make sure that memcore is not used - if (coreId == 0x80) { - return; - } - - switch (data_type) { - case CNRT_FLOAT16: { - roialignForwardAvg((half *)input, (half *)rois, (half *)output, aligned, - channels, pooled_height, pooled_width, input_height, - input_width, sampling_ratio, (half)spatial_scale, - num_rois); - }; break; - case CNRT_FLOAT32: { - roialignForwardAvg((float *)input, (float *)rois, (float *)output, - aligned, channels, pooled_height, pooled_width, - input_height, input_width, sampling_ratio, - (float)spatial_scale, num_rois); - }; break; - default: - break; - } - - return; -} -} // namespace forward - -namespace backward { -__mlu_func__ void bilinearInterpolateGradient(int height, int width, float y, - float x, float *w1, float *w2, - float *w3, float *w4, int *x_low, - int *x_high, int *y_low, - int *y_high) { - if (y < -1.0 || y > height || x < -1.0 || x > width) { - *w1 = 0.0, *w2 = 0.0, *w3 = 0.0, *w4 = 0.0; - *x_low = -1, *x_high = -1, *y_low = -1, *y_high = -1; - return; - } - if (y <= 0) { - y = 0; - } - if (x <= 0) { - x = 0; - } - *y_low = (int)y; - *x_low = (int)x; - if (*y_low >= height - 1) { - *y_high = height - 1, *y_low = height - 1; - y = (float)(*y_low); - } else { - *y_high = *y_low + 1; - } - if (*x_low >= width - 1) { - *x_high = width - 1, *x_low = width - 1; - x = (float)(*x_low); - } else { - *x_high = *x_low + 1; - } - float ly = y - *y_low, lx = x - *x_low; - float hy = 1.0 - ly, hx = 1.0 - lx; - *w1 = hy * hx, *w2 = hy * lx, *w3 = ly * hx, *w4 = ly * lx; - return; -} - -template -__mlu_func__ void unionRoiAlignBp( - T *grads, T *boxes, T *grads_image, const int boxes_num, const int hi, - const int wi, const int c, const int no, const int ho, const int wo, - const float spatial_scale, const int sampling_ratio, const bool aligned) { - int c_align = PAD_UP(c, NFU_ALIGN_SIZE / sizeof(T)); - int deal_all = boxes_num * hi * wi; - int deal_this_core = deal_all / taskDim + (int)(taskId < deal_all % taskDim); - for (int i = 0; i < deal_this_core; ++i) { - int bhw_id = i * taskDim + taskId; - int box_id = bhw_id / (hi * wi); - int ih = (bhw_id / wi) % hi; - int iw = bhw_id % wi; - T *box = boxes + box_id * 5; - int image_id = (int)box[0]; - T *image_offset = grads_image + image_id * ho * wo * c; - T *grads_ = grads + box_id * hi * wi * c + ih * wi * c + iw * c; - - float offset = aligned ? 0.5 : 0.0; - float x1 = box[1] * spatial_scale - offset; - float y1 = box[2] * spatial_scale - offset; - float x2 = box[3] * spatial_scale - offset; - float y2 = box[4] * spatial_scale - offset; - float roi_width = x2 - x1; - float roi_height = y2 - y1; - if (!aligned) { - roi_width = (roi_width > 1.0) ? roi_width : 1.0; - roi_height = (roi_height > 1.0) ? roi_height : 1.0; - } - float bin_size_h = roi_height / hi; - float bin_size_w = roi_width / wi; - - int roi_grid_h = - (sampling_ratio > 0) ? sampling_ratio : std::ceil(roi_height / hi); - int roi_grid_w = - (sampling_ratio > 0) ? 
sampling_ratio : std::ceil(roi_width / wi); - const T count = roi_grid_h * roi_grid_w; - if (c_align * sizeof(T) * 2 <= MAX_NRAM_SIZE) { - for (int iy = 0; iy < roi_grid_h; ++iy) { - const float y = - y1 + ih * bin_size_h + (iy + 0.5) * bin_size_h / roi_grid_h; - for (int ix = 0; ix < roi_grid_w; ++ix) { - const float x = - x1 + iw * bin_size_w + (ix + 0.5) * bin_size_w / roi_grid_w; - float w1, w2, w3, w4; - int x_low, x_high, y_low, y_high; - bilinearInterpolateGradient(ho, wo, y, x, &w1, &w2, &w3, &w4, &x_low, - &x_high, &y_low, &y_high); - if (x_low >= 0 && y_low >= 0) { - __memcpy(buffer, grads_, c * sizeof(T), GDRAM2NRAM); - __bang_mul_scalar((T *)buffer + c_align, (T *)buffer, (T)w1, - c_align); - __bang_mul_scalar((T *)buffer + c_align, (T *)buffer + c_align, - 1 / count, c_align); - __bang_atomic_add((T *)buffer + c_align, - image_offset + y_low * wo * c + x_low * c, - (T *)buffer + c_align, c); - __bang_mul_scalar((T *)buffer + c_align, (T *)buffer, (T)w2, - c_align); - __bang_mul_scalar((T *)buffer + c_align, (T *)buffer + c_align, - 1 / count, c_align); - __bang_atomic_add((T *)buffer + c_align, - image_offset + y_low * wo * c + x_high * c, - (T *)buffer + c_align, c); - __bang_mul_scalar((T *)buffer + c_align, (T *)buffer, (T)w3, - c_align); - __bang_mul_scalar((T *)buffer + c_align, (T *)buffer + c_align, - 1 / count, c_align); - __bang_atomic_add((T *)buffer + c_align, - image_offset + y_high * wo * c + x_low * c, - (T *)buffer + c_align, c); - __bang_mul_scalar((T *)buffer + c_align, (T *)buffer, (T)w4, - c_align); - __bang_mul_scalar((T *)buffer + c_align, (T *)buffer + c_align, - 1 / count, c_align); - __bang_atomic_add((T *)buffer + c_align, - image_offset + y_high * wo * c + x_high * c, - (T *)buffer + c_align, c); - } // x_low && y_low - } // ix - } // iy - } else { - for (int iy = 0; iy < roi_grid_h; ++iy) { - const float y = - y1 + ih * bin_size_h + (iy + 0.5) * bin_size_h / roi_grid_h; - for (int ix = 0; ix < roi_grid_w; ++ix) { - const float x = - x1 + iw * bin_size_w + (ix + 0.5) * bin_size_w / roi_grid_w; - float w1, w2, w3, w4; - int x_low, x_high, y_low, y_high; - bilinearInterpolateGradient(ho, wo, y, x, &w1, &w2, &w3, &w4, &x_low, - &x_high, &y_low, &y_high); - if (x_low >= 0 && y_low >= 0) { - int deal_once = - PAD_DOWN(MAX_NRAM_SIZE / 2, NFU_ALIGN_SIZE) / sizeof(T); - int c_repeat = c / deal_once + (int)(c % deal_once != 0); - for (int i = 0; i < c_repeat; ++i) { - int deal_c = deal_once; - int align_c = deal_once; - if (i == c_repeat - 1) { - deal_c = c - i * deal_once; - align_c = c_align - i * deal_once; - } - __memcpy(buffer, grads_ + i * deal_once, deal_c * sizeof(T), - GDRAM2NRAM); - __bang_mul_scalar((T *)buffer + align_c, (T *)buffer, (T)w1, - align_c); - __bang_mul_scalar((T *)buffer + align_c, (T *)buffer + align_c, - 1 / count, align_c); - __bang_atomic_add( - (T *)buffer + align_c, - image_offset + y_low * wo * c + x_low * c + i * deal_once, - (T *)buffer + align_c, deal_c); - __bang_mul_scalar((T *)buffer + align_c, (T *)buffer, (T)w2, - align_c); - __bang_mul_scalar((T *)buffer + align_c, (T *)buffer + align_c, - 1 / count, align_c); - __bang_atomic_add( - (T *)buffer + align_c, - image_offset + y_low * wo * c + x_high * c + i * deal_once, - (T *)buffer + align_c, deal_c); - __bang_mul_scalar((T *)buffer + align_c, (T *)buffer, (T)w3, - align_c); - __bang_mul_scalar((T *)buffer + align_c, (T *)buffer + align_c, - 1 / count, align_c); - __bang_atomic_add( - (T *)buffer + align_c, - image_offset + y_high * wo * c + x_low * c + i * deal_once, - 
(T *)buffer + align_c, deal_c); - __bang_mul_scalar((T *)buffer + align_c, (T *)buffer, (T)w4, - align_c); - __bang_mul_scalar((T *)buffer + align_c, (T *)buffer + align_c, - 1 / count, align_c); - __bang_atomic_add( - (T *)buffer + align_c, - image_offset + y_high * wo * c + x_high * c + i * deal_once, - (T *)buffer + align_c, deal_c); - } // for c_repeat - } // x_low >= 0 && y_low >= 0 - } // ix - } // iy - } // if c - } // i -} - -__mlu_global__ void MLUUnion1KernelRoiAlignBackward( - const void *grads, const void *boxes, void *grads_image, - const cnrtDataType_t dtype, const int boxes_num, const int hi, const int wi, - const int c, const int no, const int ho, const int wo, - const float spatial_scale, const int sampling_ratio, const bool aligned) { - // make sure that memcore is not used - if (coreId == 0x80) { - return; - } - switch (dtype) { - case CNRT_FLOAT16: { - unionRoiAlignBp((half *)grads, (half *)boxes, (half *)grads_image, - boxes_num, hi, wi, c, no, ho, wo, spatial_scale, - sampling_ratio, aligned); - }; break; - case CNRT_FLOAT32: { - unionRoiAlignBp((float *)grads, (float *)boxes, (float *)grads_image, - boxes_num, hi, wi, c, no, ho, wo, spatial_scale, - sampling_ratio, aligned); - }; break; - default: { return; } - } -} -} // namespace backward - -void KernelRoiAlign(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, - cnrtQueue_t queue, const cnrtDataType_t d_type, - const void *input, const void *rois, const int channels, - const bool aligned, const int pooled_height, - const int pooled_width, const int input_height, - const int input_width, const int sampling_ratio, - const float spatial_scale, const int num_rois, - void *output) { - forward::MLUUnion1KernelRoiAlignAvg<<>>( - input, rois, channels, aligned, pooled_height, pooled_width, input_height, - input_width, sampling_ratio, spatial_scale, num_rois, d_type, output); -} - -void KernelRoiAlignBackward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, - cnrtQueue_t queue, const cnrtDataType_t dtype, - const void *grads, const void *boxes, - void *grads_image, const int boxes_num, - const int hi, const int wi, const int c, - const int no, const int ho, const int wo, - const float spatial_scale, const int sampling_ratio, - const bool aligned) { - backward::MLUUnion1KernelRoiAlignBackward<<>>( - grads, boxes, grads_image, dtype, boxes_num, hi, wi, c, no, ho, wo, - spatial_scale, sampling_ratio, aligned); -} diff --git a/mmcv/ops/csrc/pytorch/mlu/iou3d_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/iou3d_mlu.cpp index 5348d16e01..993aa5e410 100644 --- a/mmcv/ops/csrc/pytorch/mlu/iou3d_mlu.cpp +++ b/mmcv/ops/csrc/pytorch/mlu/iou3d_mlu.cpp @@ -10,114 +10,30 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
*************************************************************************/ -#include "pytorch_device_registry.hpp" -#include "pytorch_mlu_helper.hpp" - -void KernelIou3d(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, - const cnrtDataType_t data_type_input, const void *boxes_dram, - const int input_box_num, const float iou_threshold, - void *workspace, void *output_size, void *output); - -int selectType(uint32_t use_job, int box_num_per_core) { - // the box_num_per_core should be at least 256, otherwise the real IO - // bandwidth would be very low - while (box_num_per_core < 256 && use_job >= 4) { - box_num_per_core *= 2; - use_job /= 2; - } - return use_job; -} -static cnnlStatus_t policyFunc(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type, - int &core_num_per_class, - const int input_box_num) { - uint32_t core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); - uint32_t job_limit = getJobLimitCapability(); - uint32_t core_number = job_limit; - - int box_num_per_core = (input_box_num + core_number - 1) / core_number; - int use_job = selectType(job_limit, box_num_per_core); - // initiate k_type as Union1 - k_dim->x = core_dim; - k_dim->y = 1; - k_dim->z = 1; - *k_type = CNRT_FUNC_TYPE_UNION1; - switch (job_limit) { - case CN_KERNEL_CLASS_BLOCK: - case CN_KERNEL_CLASS_UNION: - case CN_KERNEL_CLASS_UNION2: - case CN_KERNEL_CLASS_UNION4: - case CN_KERNEL_CLASS_UNION8: - case CN_KERNEL_CLASS_UNION16: { - if (use_job < 4) { - k_dim->x = 1; - *k_type = CNRT_FUNC_TYPE_BLOCK; - } else if (use_job == 4) { - k_dim->x = core_dim; - *k_type = CNRT_FUNC_TYPE_UNION1; - } else { - k_dim->x = use_job; - *k_type = (cnrtFunctionType_t)use_job; - } - }; break; - default: - LOG(WARNING) << "[cnnlNms_v2]: got unsupported job limit number." - << " Use default CN_KERNEL_CLASS_UNION1 with UNION1 task."; - } - return CNNL_STATUS_SUCCESS; -} +#include "mlu_common_helper.h" void IoU3DNMS3DMLUKernelLauncher(Tensor boxes, Tensor &keep, Tensor &keep_num, float iou_threshold) { - // dimension parameters check - TORCH_CHECK(boxes.dim() == 2, "boxes should be a 2d tensor, got ", - boxes.dim(), "D"); - TORCH_CHECK(boxes.size(1) == 7, - "boxes should have 7 elements in dimension 1, got ", - boxes.size(1)); - - // data type check - TORCH_CHECK( - boxes.scalar_type() == at::kFloat || boxes.scalar_type() == at::kHalf, - "data type of boxes should be Float or Half, got ", boxes.scalar_type()); - if (boxes.numel() == 0) { return; } - const size_t max_input_num = 2147483648; // 2^31, 2G num - TORCH_CHECK(boxes.numel() < max_input_num, - "boxes.numel() should be less than 2147483648, got ", - boxes.numel()); - int input_box_num = boxes.size(0); - - cnrtDataType_t data_type_input = torch_mlu::toCnrtDtype(boxes.dtype()); - cnrtDim3_t k_dim; - cnrtJobType_t k_type; - - int core_num_per_class; - policyFunc(&k_dim, &k_type, core_num_per_class, input_box_num); - // transpose boxes (n, 7) to (7, n) for better performance - auto boxes_t = boxes.transpose(0, 1); - auto boxes_ = torch_mlu::cnnl::ops::cnnl_contiguous(boxes_t); - - auto output = at::empty({input_box_num}, boxes.options().dtype(at::kLong)); + int input_box_num = boxes.size(0); + auto boxes_ = torch_mlu::cnnl::ops::cnnl_contiguous(boxes); + auto output = keep.to(boxes.options().dtype(at::kInt)); auto output_size = at::empty({1}, boxes.options().dtype(at::kInt)); - // workspace - const int info_num = 7; // x, y,z, dx, dy, dz,angle - size_t space_size = 0; - if (boxes.scalar_type() == at::kHalf) { - space_size = input_box_num * sizeof(int16_t) * info_num + - 
input_box_num * sizeof(float) + sizeof(float); - } else { - space_size = input_box_num * sizeof(float) * (info_num + 1) + sizeof(float); - } + MluOpTensorDescriptor boxes_desc, output_desc; + boxes_desc.set(boxes_); + output_desc.set(output); - auto workspace = at::empty(space_size, boxes.options().dtype(at::kByte)); + // workspace + size_t workspace_size = 0; + auto handle = mluOpGetCurrentHandle(); + mluOpGetNmsWorkspaceSize(handle, boxes_desc.desc(), NULL, &workspace_size); + auto workspace = at::empty(workspace_size, boxes.options().dtype(at::kByte)); // get compute queue - auto queue = torch_mlu::getCurQueue(); - auto boxes_impl = torch_mlu::getMluTensorImpl(boxes_); auto boxes_ptr = boxes_impl->cnnlMalloc(); auto workspace_impl = torch_mlu::getMluTensorImpl(workspace); @@ -127,11 +43,29 @@ void IoU3DNMS3DMLUKernelLauncher(Tensor boxes, Tensor &keep, Tensor &keep_num, auto output_size_impl = torch_mlu::getMluTensorImpl(keep_num); auto output_size_ptr = output_size_impl->cnnlMalloc(); - uint32_t core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); - CNLOG(INFO) << "Launch Kernel KernelIou3d<<>>"; - KernelIou3d(k_dim, k_type, queue, data_type_input, boxes_ptr, input_box_num, - iou_threshold, workspace_ptr, output_size_ptr, output_ptr); + // nms desc + mluOpNmsDescriptor_t nms_desc; + const mluOpNmsBoxPointMode_t box_mode = (mluOpNmsBoxPointMode_t)0; + const mluOpNmsOutputMode_t output_mode = (mluOpNmsOutputMode_t)0; + const mluOpNmsAlgo_t algo = (mluOpNmsAlgo_t)0; + const mluOpNmsMethodMode_t method_mode = (mluOpNmsMethodMode_t)0; + const float soft_nms_sigma = 0.0; + const float confidence_threshold = 0.0; + const int input_layout = 0; + const bool pad_to_max_output_size = false; + const int max_output_size = input_box_num; + const float offset = 0.0; + + mluOpCreateNmsDescriptor(&nms_desc); + mluOpSetNmsDescriptor(nms_desc, box_mode, output_mode, algo, method_mode, + iou_threshold, soft_nms_sigma, max_output_size, + confidence_threshold, offset, input_layout, + pad_to_max_output_size); + + mluOpNms(handle, nms_desc, boxes_desc.desc(), boxes_ptr, NULL, NULL, + workspace_ptr, workspace_size, output_desc.desc(), output_ptr, + output_size_ptr); + mluOpDestroyNmsDescriptor(nms_desc); } void iou3d_nms3d_forward_mlu(const Tensor boxes, Tensor &keep, Tensor &keep_num, diff --git a/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h b/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h index 436f055f04..37e125aacb 100644 --- a/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h +++ b/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h @@ -18,8 +18,8 @@ #include "pytorch_device_registry.hpp" #define MLUOP_MAJOR 0 -#define MLUOP_MINOR 5 -#define MLUOP_PATCHLEVEL 302 +#define MLUOP_MINOR 6 +#define MLUOP_PATCHLEVEL 0 mluOpDataType_t getMluOpDataType(const caffe2::TypeMeta& data_type); mluOpTensorLayout_t getMluOpSuggestLayout(const at::Tensor& input); diff --git a/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp index f8e884d971..ead293d1b5 100644 --- a/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp +++ b/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp @@ -9,495 +9,117 @@ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. *************************************************************************/ +#include "mlu_common_helper.h" #include "pytorch_device_registry.hpp" #include "pytorch_mlu_helper.hpp" -#define MIN(a, b) (((a) < (b)) ? 
(a) : (b)) - -typedef enum { - MS_DEFORM_ATTN_FORWARD_INVALID = 0, /*!< Index is invalid. */ - MS_DEFORM_ATTN_FORWARD_DEFAULT = - 1, /*!< MLUKernelMsDeformAttnForwardDefault */ - MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL = - 2, /*!< MLUKernelMsDeformAttnForwardSmallChannel */ -} MsDeformAttnForwardPolicy; - -void KernelMsDeformAttnForwardDefault( - cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, - const cnrtDataType_t d_type, const char* data_value_gdram, - const char* data_spatial_shapes_gdram, - const char* data_level_start_index_gdram, - const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram, - const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, - const int32_t channels, const int32_t num_levels, const int32_t num_queries, - const int32_t num_points, char* data_col_gdram); -void KernelMsDeformAttnForwardSmallChannel( - cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, - const cnrtDataType_t d_type, const char* data_value_gdram, - const char* data_spatial_shapes_gdram, - const char* data_level_start_index_gdram, - const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram, - const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, - const int32_t channels, const int32_t num_levels, const int32_t num_queries, - const int32_t num_points, char* data_col_gdram); - -typedef enum { - MS_DEFORM_ATTN_BACKWARD_DEFAULT = 0, - MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL = 1, -} MsDeformAttnBackwardKernelPolicy; - -MsDeformAttnBackwardKernelPolicy msDeformAttnBackwardPolicyFunc( - const int32_t channels, const int32_t num_levels, const int32_t num_points, - const int32_t num_heads) { - const int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore); - const int num_hlp = num_heads * num_levels * num_points; - int num_per_time_theory = (nram_size - num_levels * sizeof(float) - - 3 * num_levels * sizeof(int32_t)) / - sizeof(float) / (8 * PAD_UP(channels, 32) + 28) / - PAD_UP((num_hlp), 32); - if (num_per_time_theory >= 1) { - return MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL; - } - return MS_DEFORM_ATTN_BACKWARD_DEFAULT; -} - -void KernelMsDeformAttnBackwardDefaultKernel( - cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, - const cnrtDataType_t d_type, const float* data_value, - const int32_t* spatial_shapes, const int32_t* data_level_start_index, - const float* data_sampling_loc, const float* data_attn_weight, - const float* grad_output, const int32_t batch_size, const int32_t num_keys, - const int32_t num_heads, const int32_t channels, const int32_t num_levels, - const int32_t num_queries, const int32_t num_points, float* grad_value, - float* grad_sampling_loc, float* grad_attn_weight); - -void KernelMsDeformAttnBackwardSmallChannelsKernel( - cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, - const cnrtDataType_t d_type, const float* data_value, - const int32_t* spatial_shapes, const int32_t* data_level_start_index, - const float* data_sampling_loc, const float* data_attn_weight, - const float* grad_output, const int32_t batch, const int32_t spatial_size, - const int32_t num_heads, const int32_t channels, const int32_t num_levels, - const int32_t num_query, const int32_t num_points, float* grad_value, - float* grad_sampling_loc, float* grad_attn_weight); - -// policy function -MsDeformAttnForwardPolicy msDeformAttnForwardPolicyFunc( - cnrtDim3_t* k_dim, cnrtFunctionType_t* k_type, const int32_t batch_size, - const int32_t num_keys, const int32_t num_heads, const int32_t 
channels, - const int32_t num_levels, const int32_t num_queries, - const int32_t num_points) { - k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); - k_dim->y = - MIN((batch_size * num_queries * num_heads + k_dim->x - 1) / k_dim->x, - torch_mlu::getDeviceAttr(cnrtAttrClusterCount)); - k_dim->z = 1; -#if __BANG_ARCH__ == 520 - *k_type = CNRT_FUNC_TYPE_BLOCK; -#else - *k_type = CNRT_FUNC_TYPE_UNION1; -#endif - - int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore); - if (num_levels * num_points * 3 * sizeof(int32_t) > nram_size) { - return MS_DEFORM_ATTN_FORWARD_DEFAULT; - } else if (channels > nram_size / 12 / sizeof(float) || channels > 96 || - channels < 16) { - return MS_DEFORM_ATTN_FORWARD_DEFAULT; - } else { - return MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL; - } -} - -// policy function for backward -static void policyFuncBackward(const int32_t batch_size, - const int32_t num_queries, - const int32_t num_heads, - const int32_t num_levels, - cnrtFunctionType_t* k_type, cnrtDim3_t* k_dim) { - size_t cluster_limit = torch_mlu::getDeviceAttr(cnrtAttrClusterCount); - size_t core_limit = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); - k_dim->x = core_limit; - int32_t total_num = batch_size * num_queries * num_heads * num_levels; - size_t total_num_align = CEIL_ALIGN(total_num, core_limit); - k_dim->y = (total_num_align / core_limit) > cluster_limit - ? cluster_limit - : (total_num_align / core_limit); - k_dim->z = 1; - *k_type = CNRT_FUNC_TYPE_UNION1; -} - -Tensor ms_deform_attn_mlu_forward(const Tensor& value, - const Tensor& spatial_shapes, - const Tensor& level_start_index, - const Tensor& sampling_loc, - const Tensor& attn_weight, - const int im2col_step) { - // check contiguous - AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); - AT_ASSERTM(spatial_shapes.is_contiguous(), - "spatial_shapes tensor has to be contiguous"); - AT_ASSERTM(level_start_index.is_contiguous(), - "level_start_index tensor has to be contiguous"); - AT_ASSERTM(sampling_loc.is_contiguous(), - "sampling_loc tensor has to be contiguous"); - AT_ASSERTM(attn_weight.is_contiguous(), - "attn_weight tensor has to be contiguous"); - - // check datatype - TORCH_CHECK((value.scalar_type() == at::kFloat), - "value type should be Float, got ", value.scalar_type(), "."); - TORCH_CHECK((spatial_shapes.scalar_type() == at::kInt || - spatial_shapes.scalar_type() == at::kLong), - "spatial_shapes type should be Int, got ", - spatial_shapes.scalar_type(), "."); - TORCH_CHECK((level_start_index.scalar_type() == at::kInt || - level_start_index.scalar_type() == at::kLong), - "level_start_index type should be Int, got ", - level_start_index.scalar_type(), "."); - TORCH_CHECK((sampling_loc.scalar_type() == at::kFloat), - "sampling_loc type should be Float, got ", - sampling_loc.scalar_type(), "."); - TORCH_CHECK((attn_weight.scalar_type() == at::kFloat), - "attn_weight type should be Float, got ", - attn_weight.scalar_type(), "."); - - // check shape - TORCH_CHECK(value.dim() == 4, "value should be a 4d tensor, got ", - value.dim(), "D."); - TORCH_CHECK(spatial_shapes.dim() == 2, - "spatial_shapes should be a 2d tensor, got ", - spatial_shapes.dim(), "D."); - TORCH_CHECK(level_start_index.dim() == 1, - "level_start_index should be a 1d tensor, got ", - level_start_index.dim(), "D."); - TORCH_CHECK(sampling_loc.dim() == 6, - "sampling_loc should be a 6d tensor, got ", sampling_loc.dim(), - "D."); - TORCH_CHECK(attn_weight.dim() == 5, "attn_weight should be a 5d tensor, got ", - 
attn_weight.dim(), "D."); - +/************************************************************************* + * This MACRO contains operations of simple tensor to mlu-tensor. + * _contiguous, _desc, _impl, _ptr will be automatically generated in + * this MACRO. + *************************************************************************/ +#define INITIAL_MLU_PARAM_WITH_TENSOR(NAME) \ + auto NAME##_contigous = torch_mlu::cnnl::ops::cnnl_contiguous( \ + NAME, NAME.suggest_memory_format()); \ + MluOpTensorDescriptor NAME##_desc; \ + NAME##_desc.set(NAME##_contigous); \ + auto NAME##_impl = torch_mlu::getMluTensorImpl(NAME##_contigous); \ + auto NAME##_ptr = NAME##_impl->cnnlMalloc(); + +Tensor MsDeformAttnForwardLauncher(const Tensor& value, + const Tensor& spatial_shapes, + const Tensor& level_start_index, + const Tensor& sampling_loc, + const Tensor& attn_weight, + const int im2col_step) { + auto handle = mluOpGetCurrentHandle(); const int batch_size = value.size(0); - const int num_keys = value.size(1); const int num_heads = value.size(2); const int channels = value.size(3); - const int num_levels = spatial_shapes.size(0); const int num_queries = sampling_loc.size(1); - const int num_points = sampling_loc.size(4); - - TORCH_CHECK(spatial_shapes.size(1) == 2, - "the 2nd dimensions of spatial_shapes should be 2, got ", - spatial_shapes.size(1), "."); - TORCH_CHECK(sampling_loc.size(5) == 2, - "the 6th dimensions of sampling_loc should be 2, got ", - sampling_loc.size(5), "."); - TORCH_CHECK((sampling_loc.size(0) == batch_size), - "the 1st dimensions of sampling_loc should be batch_size, ", - "but now the 1st dimension of sampling_loc is ", - sampling_loc.size(0), ", and batch_size is ", batch_size, "."); - TORCH_CHECK((attn_weight.size(0) == batch_size), - "the 1st dimensions of attn_weight should be batch_size, ", - "but now the 1st dimension of attn_weight is ", - attn_weight.size(0), ", and batch_size is ", batch_size, "."); - TORCH_CHECK((sampling_loc.size(2) == num_heads), - "the 3rd dimensions of sampling_loc should be num_heads, ", - "but now the 3rd dimension of sampling_loc is ", - sampling_loc.size(2), ", and num_heads is ", num_heads, "."); - TORCH_CHECK((attn_weight.size(2) == num_heads), - "the 3rd dimensions of attn_weight should be num_heads, ", - "but now the 3rd dimension of attn_weight is ", - attn_weight.size(2), ", and num_heads is ", num_heads, "."); - TORCH_CHECK((level_start_index.size(0) == num_levels), - "the 1st dimensions of level_start_index should be num_levels, ", - "but now the 1st dimension of level_start_index is ", - level_start_index.size(0), ", and num_levels is ", num_levels, - "."); - TORCH_CHECK((sampling_loc.size(3) == num_levels), - "the 4th dimensions of sampling_loc should be num_levels, ", - "but now the 4th dimension of sampling_loc is ", - sampling_loc.size(3), ", and num_levels is ", num_levels, "."); - TORCH_CHECK((attn_weight.size(3) == num_levels), - "the 4th dimensions of attn_weight should be num_levels, ", - "but now the 4th dimension of attn_weight is ", - attn_weight.size(3), ", and num_levels is ", num_levels, "."); - TORCH_CHECK((attn_weight.size(1) == num_queries), - "the 2nd dimensions of attn_weight should be num_queries, ", - "but now the 2nd dimension of attn_weight is ", - attn_weight.size(1), ", and num_queries is ", num_queries, "."); - TORCH_CHECK((attn_weight.size(4) == num_points), - "the 5th dimensions of attn_weight should be num_points, ", - "but now the 5th dimension of attn_weight is ", - attn_weight.size(4), ", and 
num_points is ", num_points, "."); - auto output = at::zeros({batch_size, num_queries, num_heads, channels}, value.options()); - - // large tensor check - const size_t max_input_size = 2147483648; - TORCH_CHECK(value.numel() < max_input_size, - "value element num should be less than 2^31, got ", value.numel(), - "."); - TORCH_CHECK(sampling_loc.numel() < max_input_size, - "sampling_loc element num should be less than 2^31, got ", - sampling_loc.numel(), "."); - TORCH_CHECK(output.numel() < max_input_size, - "output element num should be less than 2^31, got ", - output.numel(), "."); - - // check zero element - TORCH_CHECK(batch_size != 0, "batch_size should not be zero"); - TORCH_CHECK(num_heads != 0, "num_heads should not be zero"); - TORCH_CHECK(channels != 0, "channels should not be zero"); - TORCH_CHECK(num_queries != 0, "num_queries should not be zero"); - - if (num_keys == 0 || num_levels == 0 || num_points == 0) { - return output; - } - - // calculate task dimension - cnrtDim3_t k_dim; - cnrtFunctionType_t k_type; - MsDeformAttnForwardPolicy policy = msDeformAttnForwardPolicyFunc( - &k_dim, &k_type, batch_size, num_keys, num_heads, channels, num_levels, - num_queries, num_points); - - // get compute queue - auto queue = torch_mlu::getCurQueue(); - - auto spatial_shapes_ = spatial_shapes.to(at::kInt); - auto level_start_index_ = level_start_index.to(at::kInt); - - // get ptr of tensors - auto value_impl = torch_mlu::getMluTensorImpl(value); - auto value_ptr = value_impl->cnnlMalloc(); - auto spatial_shapes_impl = torch_mlu::getMluTensorImpl(spatial_shapes_); - auto spatial_shapes_ptr = spatial_shapes_impl->cnnlMalloc(); - auto level_start_index_impl = torch_mlu::getMluTensorImpl(level_start_index_); - auto level_start_index_ptr = level_start_index_impl->cnnlMalloc(); - auto sampling_loc_impl = torch_mlu::getMluTensorImpl(sampling_loc); - auto sampling_loc_ptr = sampling_loc_impl->cnnlMalloc(); - auto attn_weight_impl = torch_mlu::getMluTensorImpl(attn_weight); - auto attn_weight_ptr = attn_weight_impl->cnnlMalloc(); - auto output_impl = torch_mlu::getMluTensorImpl(output); - auto output_ptr = output_impl->cnnlMalloc(); - - // get compute dtype of input - cnrtDataType_t data_type = torch_mlu::toCnrtDtype(value.dtype()); - - // launch kernel - switch (policy) { - default: { - VLOG(5) << "MsDeformAttnForward Policy not supported"; - }; break; - case MS_DEFORM_ATTN_FORWARD_DEFAULT: { - CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForwardDefault<<<" - << k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>"; - KernelMsDeformAttnForwardDefault( - k_dim, k_type, queue, data_type, (char*)value_ptr, - (char*)spatial_shapes_ptr, (char*)level_start_index_ptr, - (char*)sampling_loc_ptr, (char*)attn_weight_ptr, batch_size, num_keys, - num_heads, channels, num_levels, num_queries, num_points, - (char*)output_ptr); - break; - } - case MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL: { - CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForwardSmallChannel<<<" - << k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>"; - KernelMsDeformAttnForwardSmallChannel( - k_dim, k_type, queue, data_type, (char*)value_ptr, - (char*)spatial_shapes_ptr, (char*)level_start_index_ptr, - (char*)sampling_loc_ptr, (char*)attn_weight_ptr, batch_size, num_keys, - num_heads, channels, num_levels, num_queries, num_points, - (char*)output_ptr); - break; - } - } + auto spatial_shapes_int = spatial_shapes.to(at::kInt); + auto level_start_index_int = level_start_index.to(at::kInt); + INITIAL_MLU_PARAM_WITH_TENSOR(output); + 
INITIAL_MLU_PARAM_WITH_TENSOR(value); + INITIAL_MLU_PARAM_WITH_TENSOR(spatial_shapes_int); + INITIAL_MLU_PARAM_WITH_TENSOR(level_start_index_int); + INITIAL_MLU_PARAM_WITH_TENSOR(sampling_loc); + INITIAL_MLU_PARAM_WITH_TENSOR(attn_weight); + + mluOpMsDeformAttnForward( + handle, value_desc.desc(), value_ptr, spatial_shapes_int_desc.desc(), + spatial_shapes_int_ptr, level_start_index_int_desc.desc(), + level_start_index_int_ptr, sampling_loc_desc.desc(), sampling_loc_ptr, + attn_weight_desc.desc(), attn_weight_ptr, im2col_step, output_desc.desc(), + output_ptr); output = output.view({batch_size, num_queries, num_heads * channels}); return output; } -void ms_deform_attn_mlu_backward( +void MsDeformAttnBackwardLauncher( const Tensor& value, const Tensor& spatial_shapes, const Tensor& level_start_index, const Tensor& sampling_loc, const Tensor& attn_weight, const Tensor& grad_output, Tensor& grad_value, Tensor& grad_sampling_loc, Tensor& grad_attn_weight, const int im2col_step) { - // check contiguous - AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); - AT_ASSERTM(spatial_shapes.is_contiguous(), - "spatial_shapes tensor has to be contiguous"); - AT_ASSERTM(level_start_index.is_contiguous(), - "level_start_index tensor has to be contiguous"); - AT_ASSERTM(sampling_loc.is_contiguous(), - "sampling_loc tensor has to be contiguous"); - AT_ASSERTM(attn_weight.is_contiguous(), - "attn_weight tensor has to be contiguous"); - AT_ASSERTM(grad_output.is_contiguous(), - "grad_output tensor has to be contiguous"); - - // check datatype - TORCH_CHECK((value.scalar_type() == at::kFloat), - "value type should be Float, got ", value.scalar_type(), "."); - TORCH_CHECK((spatial_shapes.scalar_type() == at::kInt || - spatial_shapes.scalar_type() == at::kLong), - "spatial_shapes type should be Int, got ", - spatial_shapes.scalar_type(), "."); - TORCH_CHECK((level_start_index.scalar_type() == at::kInt || - level_start_index.scalar_type() == at::kLong), - "level_start_index type should be Int, got ", - level_start_index.scalar_type(), "."); - TORCH_CHECK((sampling_loc.scalar_type() == at::kFloat), - "sampling_loc type should be Float, got ", - sampling_loc.scalar_type(), "."); - TORCH_CHECK((attn_weight.scalar_type() == at::kFloat), - "attn_weight type should be Float, got ", - attn_weight.scalar_type(), "."); - TORCH_CHECK((grad_output.scalar_type() == at::kFloat), - "grad_output type should be Float, got ", - grad_output.scalar_type(), "."); - + auto handle = mluOpGetCurrentHandle(); + auto spatial_shapes_int = spatial_shapes.to(at::kInt); + auto level_start_index_int = level_start_index.to(at::kInt); const int batch_size = value.size(0); - const int num_keys = value.size(1); const int num_heads = value.size(2); const int channels = value.size(3); - const int num_levels = spatial_shapes.size(0); const int num_queries = sampling_loc.size(1); - const int num_points = sampling_loc.size(4); - // Check shape. 
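  // Expected layouts, as the checks below spell out:
  //   spatial_shapes     (num_levels, 2)
  //   level_start_index  (num_levels)
  //   sampling_loc       (batch, num_queries, num_heads, num_levels, num_points, 2)
  //   attn_weight        (batch, num_queries, num_heads, num_levels, num_points)
  //   grad_output        (batch, num_queries, num_heads * channels)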
- TORCH_CHECK(spatial_shapes.size(1) == 2, - "the 2nd dimensions of spatial_shapes should be 2, got ", - spatial_shapes.size(1), "."); - - TORCH_CHECK((level_start_index.size(0) == num_levels), - "the 1st dimensions of level_start_index should be num_levels, ", - "but now the 1st dimension of level_start_index is ", - level_start_index.size(0), ", and num_levels is ", num_levels, - "."); - - TORCH_CHECK((sampling_loc.size(0) == batch_size), - "the 1st dimensions of sampling_loc should be batch_size, ", - "but now the 1st dimension of sampling_loc is ", - sampling_loc.size(0), ", and batch_size is ", batch_size, "."); - TORCH_CHECK((sampling_loc.size(2) == num_heads), - "the 3rd dimensions of sampling_loc should be num_heads, ", - "but now the 3rd dimension of sampling_loc is ", - sampling_loc.size(2), ", and num_heads is ", num_heads, "."); - TORCH_CHECK((sampling_loc.size(3) == num_levels), - "the 4th dimensions of sampling_loc should be num_levels, ", - "but now the 4th dimension of sampling_loc is ", - sampling_loc.size(3), ", and num_levels is ", num_levels, "."); - TORCH_CHECK(sampling_loc.size(5) == 2, - "the 6th dimensions of sampling_loc should be 2, got ", - sampling_loc.size(5), "."); - - TORCH_CHECK((attn_weight.size(0) == batch_size), - "the 1st dimensions of attn_weight should be batch_size, ", - "but now the 1st dimension of attn_weight is ", - attn_weight.size(0), ", and batch_size is ", batch_size, "."); - TORCH_CHECK((attn_weight.size(1) == num_queries), - "the 2nd dimensions of attn_weight should be num_queries, ", - "but now the 2nd dimension of attn_weight is ", - attn_weight.size(1), ", and num_queries is ", num_queries, "."); - - TORCH_CHECK((attn_weight.size(2) == num_heads), - "the 3rd dimensions of attn_weight should be num_heads, ", - "but now the 3rd dimension of attn_weight is ", - attn_weight.size(2), ", and num_heads is ", num_heads, "."); - TORCH_CHECK((attn_weight.size(3) == num_levels), - "the 4th dimensions of attn_weight should be num_levels, ", - "but now the 4th dimension of attn_weight is ", - attn_weight.size(3), ", and num_levels is ", num_levels, "."); - TORCH_CHECK((attn_weight.size(4) == num_points), - "the 5th dimensions of attn_weight should be num_points, ", - "but now the 5th dimension of attn_weight is ", - attn_weight.size(4), ", and num_points is ", num_points, "."); - - TORCH_CHECK((grad_output.size(0) == batch_size), - "the 1st dimensions of grad_output should be batch_size, ", - "but now the 1st dimension of grad_output is ", - grad_output.size(0), ", and batch_size is ", batch_size, "."); - TORCH_CHECK((grad_output.size(1) == num_queries), - "the 2nd dimensions of grad_output should be num_queries, ", - "but now the 2nd dimension of grad_output is ", - grad_output.size(1), ", and num_queries is ", num_queries, "."); - TORCH_CHECK( - (grad_output.size(2) == num_heads * channels), - "the 3rd dimensions of grad_output should be num_heads * channels, ", - "but now the 3rd dimension of grad_output is ", grad_output.size(2), - ", and num_heads * channels is ", num_heads * channels, "."); - - // check zero element - TORCH_CHECK(batch_size != 0, "The batch_size is zero."); - TORCH_CHECK(channels != 0, "The channels is zero."); - TORCH_CHECK(num_keys != 0, "The num_keys is zero."); - TORCH_CHECK(num_heads != 0, "The num_heads is zero."); - TORCH_CHECK(num_queries != 0, "The num_queries is zero."); - if (num_levels == 0 || num_points == 0) { - return; - } - - // calculate task dimension - cnrtDim3_t k_dim; - cnrtFunctionType_t k_type; - 
policyFuncBackward(batch_size, num_queries, num_heads, num_levels, &k_type, - &k_dim); - // get compute queue - auto queue = torch_mlu::getCurQueue(); - - // get ptr of tensors - auto value_impl = torch_mlu::getMluTensorImpl(value); - auto value_ptr = value_impl->cnnlMalloc(); - auto spatial_shapes_impl = torch_mlu::getMluTensorImpl(spatial_shapes); - auto spatial_shapes_ptr = spatial_shapes_impl->cnnlMalloc(); - auto level_start_index_impl = torch_mlu::getMluTensorImpl(level_start_index); - auto level_start_index_ptr = level_start_index_impl->cnnlMalloc(); - auto sampling_loc_impl = torch_mlu::getMluTensorImpl(sampling_loc); - auto sampling_loc_ptr = sampling_loc_impl->cnnlMalloc(); - auto attn_weight_impl = torch_mlu::getMluTensorImpl(attn_weight); - auto attn_weight_ptr = attn_weight_impl->cnnlMalloc(); - auto grad_output_impl = torch_mlu::getMluTensorImpl(grad_output); - auto grad_output_ptr = grad_output_impl->cnnlMalloc(); - auto grad_value_impl = torch_mlu::getMluTensorImpl(grad_value); - auto grad_value_ptr = grad_value_impl->cnnlMalloc(); - auto grad_sampling_loc_impl = torch_mlu::getMluTensorImpl(grad_sampling_loc); - auto grad_sampling_loc_ptr = grad_sampling_loc_impl->cnnlMalloc(); - auto grad_attn_weight_impl = torch_mlu::getMluTensorImpl(grad_attn_weight); - auto grad_attn_weight_ptr = grad_attn_weight_impl->cnnlMalloc(); + auto grad_output_dim4 = + grad_output.view({batch_size, num_queries, num_heads, channels}); + // auto grad_output_dim4 = grad_output.view({batch_size, num_queries, + // num_heads, channels}).detach(); + INITIAL_MLU_PARAM_WITH_TENSOR(value); + INITIAL_MLU_PARAM_WITH_TENSOR(spatial_shapes_int); + INITIAL_MLU_PARAM_WITH_TENSOR(level_start_index_int); + INITIAL_MLU_PARAM_WITH_TENSOR(sampling_loc); + INITIAL_MLU_PARAM_WITH_TENSOR(attn_weight); + INITIAL_MLU_PARAM_WITH_TENSOR(grad_output_dim4); + // INITIAL_MLU_PARAM_WITH_TENSOR(grad_output); + INITIAL_MLU_PARAM_WITH_TENSOR(grad_value); + INITIAL_MLU_PARAM_WITH_TENSOR(grad_sampling_loc); + INITIAL_MLU_PARAM_WITH_TENSOR(grad_attn_weight); + + mluOpMsDeformAttnBackward( + handle, value_desc.desc(), value_ptr, spatial_shapes_int_desc.desc(), + spatial_shapes_int_ptr, level_start_index_int_desc.desc(), + level_start_index_int_ptr, sampling_loc_desc.desc(), sampling_loc_ptr, + attn_weight_desc.desc(), attn_weight_ptr, grad_output_dim4_desc.desc(), + grad_output_dim4_ptr, im2col_step, grad_value_desc.desc(), grad_value_ptr, + grad_sampling_loc_desc.desc(), grad_sampling_loc_ptr, + grad_attn_weight_desc.desc(), grad_attn_weight_ptr); + + return; +} - // get comput dtype of input - cnrtDataType_t data_type = torch_mlu::toCnrtDtype(value.dtype()); +Tensor ms_deform_attn_mlu_forward(const Tensor& value, + const Tensor& spatial_shapes, + const Tensor& level_start_index, + const Tensor& sampling_loc, + const Tensor& attn_weight, + const int im2col_step) { + return MsDeformAttnForwardLauncher(value, spatial_shapes, level_start_index, + sampling_loc, attn_weight, im2col_step); +} - // launch kernel - CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnBackward<<<" << k_dim.x - << ", " << k_dim.y << ", " << k_dim.z << ">>>"; - MsDeformAttnBackwardKernelPolicy kernelPolicy = - msDeformAttnBackwardPolicyFunc(channels, num_levels, num_points, - num_heads); - switch (kernelPolicy) { - default: { - VLOG(5) << "NotImplemented."; - } break; - case MS_DEFORM_ATTN_BACKWARD_DEFAULT: { - KernelMsDeformAttnBackwardDefaultKernel( - k_dim, k_type, queue, data_type, (float*)value_ptr, - (int32_t*)spatial_shapes_ptr, 
(int32_t*)level_start_index_ptr, - (float*)sampling_loc_ptr, (float*)attn_weight_ptr, - (float*)grad_output_ptr, batch_size, num_keys, num_heads, channels, - num_levels, num_queries, num_points, (float*)grad_value_ptr, - (float*)grad_sampling_loc_ptr, (float*)grad_attn_weight_ptr); - } break; - case MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL: { - KernelMsDeformAttnBackwardSmallChannelsKernel( - k_dim, k_type, queue, data_type, (float*)value_ptr, - (int32_t*)spatial_shapes_ptr, (int32_t*)level_start_index_ptr, - (float*)sampling_loc_ptr, (float*)attn_weight_ptr, - (float*)grad_output_ptr, batch_size, num_keys, num_heads, channels, - num_levels, num_queries, num_points, (float*)grad_value_ptr, - (float*)grad_sampling_loc_ptr, (float*)grad_attn_weight_ptr); - } break; - } +void ms_deform_attn_mlu_backward( + const Tensor& value, const Tensor& spatial_shapes, + const Tensor& level_start_index, const Tensor& sampling_loc, + const Tensor& attn_weight, const Tensor& grad_output, Tensor& grad_value, + Tensor& grad_sampling_loc, Tensor& grad_attn_weight, + const int im2col_step) { + return MsDeformAttnBackwardLauncher(value, spatial_shapes, level_start_index, + sampling_loc, attn_weight, grad_output, + grad_value, grad_sampling_loc, + grad_attn_weight, im2col_step); } Tensor ms_deform_attn_impl_forward(const Tensor& value, @@ -515,5 +137,6 @@ void ms_deform_attn_impl_backward( REGISTER_DEVICE_IMPL(ms_deform_attn_impl_forward, MLU, ms_deform_attn_mlu_forward); + REGISTER_DEVICE_IMPL(ms_deform_attn_impl_backward, MLU, ms_deform_attn_mlu_backward); diff --git a/mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp index e2f4322a02..eff6793f2d 100644 --- a/mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp +++ b/mmcv/ops/csrc/pytorch/mlu/nms_mlu.cpp @@ -10,123 +10,35 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
*************************************************************************/ -#include "pytorch_device_registry.hpp" -#include "pytorch_mlu_helper.hpp" - -void KernelNms(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, - const cnrtDataType_t data_type_input, const void *boxes_ptr, - const void *scores_ptr, const int input_num_boxes, - const int max_output_boxes, const float iou_threshold, - const float offset, void *workspace_ptr, void *output_size_ptr, - void *output_ptr); - -int selectUnionType(uint32_t use_job, int box_num_per_core) { - // the box_num_per_core should be at least 256, otherwise the real IO - // bandwidth would be very low - while (box_num_per_core < 256 && use_job >= 4) { - box_num_per_core *= 2; - use_job /= 2; - } - return use_job; -} - -static cnnlStatus_t policyFunc(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type, - int &core_num_per_class, - const int input_box_num) { - uint32_t core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); - uint32_t cluster_number = torch_mlu::getDeviceAttr(cnrtAttrClusterCount); - uint32_t job_limit = getJobLimitCapability(); - uint32_t core_number = job_limit; - - int box_num_per_core = (input_box_num + core_number - 1) / core_number; - int use_job = selectUnionType(job_limit, box_num_per_core); - // initiate k_type as Union1 - k_dim->x = core_dim; - k_dim->y = 1; - k_dim->z = 1; - *k_type = CNRT_FUNC_TYPE_UNION1; - switch (job_limit) { - case CN_KERNEL_CLASS_BLOCK: - case CN_KERNEL_CLASS_UNION: - case CN_KERNEL_CLASS_UNION2: - case CN_KERNEL_CLASS_UNION4: - case CN_KERNEL_CLASS_UNION8: - case CN_KERNEL_CLASS_UNION16: { - if (use_job < 4) { - k_dim->x = 1; - *k_type = CNRT_FUNC_TYPE_BLOCK; - } else if (use_job == 4) { - k_dim->x = core_dim; - *k_type = CNRT_FUNC_TYPE_UNION1; - } else { - k_dim->x = use_job; - *k_type = (cnrtFunctionType_t)use_job; - } - }; break; - default: - LOG(WARNING) << "[cnnlNms_v2]: got unsupported job limit number." 
- << " Use default CN_KERNEL_CLASS_UNION1 with UNION1 task."; - } - return CNNL_STATUS_SUCCESS; -} +#include "mlu_common_helper.h" Tensor NMSMLUKernelLauncher(Tensor boxes, Tensor scores, float iou_threshold, int offset) { - // dimension parameters check - TORCH_CHECK(boxes.dim() == 2, "boxes should be a 2d tensor, got ", - boxes.dim(), "D"); - TORCH_CHECK(boxes.size(1) == 4, - "boxes should have 4 elements in dimension 1, got ", - boxes.size(1)); - TORCH_CHECK(scores.dim() == 1, "scores should be a 1d tensor, got ", - scores.dim(), "D"); - - // data type check - TORCH_CHECK(boxes.scalar_type() == scores.scalar_type(), - "boxes should have the same type as scores"); - TORCH_CHECK( - boxes.scalar_type() == at::kFloat || boxes.scalar_type() == at::kHalf, - "data type of boxes should be Float or Half, got ", boxes.scalar_type()); - if (boxes.numel() == 0) { return at::empty({0}, boxes.options().dtype(at::kLong)); } - int input_num_boxes = boxes.size(0); int max_output_boxes = boxes.size(0); - cnrtDataType_t data_type_input = torch_mlu::toCnrtDtype(boxes.dtype()); - cnrtDim3_t k_dim; - cnrtJobType_t k_type; - - int core_num_per_class; - policyFunc(&k_dim, &k_type, core_num_per_class, input_num_boxes); - // transpose boxes (n, 4) to (4, n) for better performance - auto boxes_t = boxes.transpose(0, 1); - auto boxes_ = torch_mlu::cnnl::ops::cnnl_contiguous(boxes_t); + auto boxes_ = torch_mlu::cnnl::ops::cnnl_contiguous(boxes); auto scores_ = torch_mlu::cnnl::ops::cnnl_contiguous(scores); - auto output = at::empty({max_output_boxes}, boxes.options().dtype(at::kLong)); + auto output = at::empty({max_output_boxes}, boxes.options().dtype(at::kInt)); auto output_size = at::empty({1}, scores.options().dtype(at::kInt)); + MluOpTensorDescriptor boxes_desc, scores_desc, output_desc; + boxes_desc.set(boxes_); + scores_desc.set(scores_); + output_desc.set(output); + // workspace - const int info_num = 5; // x1, x2, y1, y2 and score - size_t space_size = 0; - if (boxes.scalar_type() == at::kHalf) { - space_size = input_num_boxes * sizeof(int16_t) * info_num + sizeof(float); - } else { - space_size = input_num_boxes * sizeof(float) * info_num + sizeof(float); - } -#if __BANG_ARCH__ > 370 - int cluster_num = getCoreNumOfJobLimitCapability() / - torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); - space_size += cluster_number * sizeof(float) * 7; -#endif - auto workspace = at::empty(space_size, boxes.options().dtype(at::kByte)); + size_t workspace_size = 0; + auto handle = mluOpGetCurrentHandle(); + mluOpGetNmsWorkspaceSize(handle, boxes_desc.desc(), scores_desc.desc(), + &workspace_size); + auto workspace = at::empty(workspace_size, boxes.options().dtype(at::kByte)); // get compute queue - auto queue = torch_mlu::getCurQueue(); - auto boxes_impl = torch_mlu::getMluTensorImpl(boxes_); auto boxes_ptr = boxes_impl->cnnlMalloc(); auto scores_impl = torch_mlu::getMluTensorImpl(scores_); @@ -138,14 +50,31 @@ Tensor NMSMLUKernelLauncher(Tensor boxes, Tensor scores, float iou_threshold, auto output_size_impl = torch_mlu::getMluTensorImpl(output_size); auto output_size_ptr = output_size_impl->cnnlMalloc(); - uint32_t core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); - CNLOG(INFO) << "Launch Kernel MLUUnionX NMS<<>>"; - KernelNms(k_dim, k_type, queue, data_type_input, boxes_ptr, scores_ptr, - input_num_boxes, max_output_boxes, iou_threshold, offset, - workspace_ptr, output_size_ptr, output_ptr); + // nms desc + mluOpNmsDescriptor_t nms_desc; + const mluOpNmsBoxPointMode_t box_mode = 
diff --git a/mmcv/ops/csrc/pytorch/mlu/roi_align_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/roi_align_mlu.cpp
index 361bba25f6..ff6e5b1500 100644
--- a/mmcv/ops/csrc/pytorch/mlu/roi_align_mlu.cpp
+++ b/mmcv/ops/csrc/pytorch/mlu/roi_align_mlu.cpp
@@ -9,26 +9,7 @@
  * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
  * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
  *************************************************************************/
-#include "pytorch_device_registry.hpp"
-#include "pytorch_mlu_helper.hpp"
-
-void KernelRoiAlign(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
-                    cnrtQueue_t queue, const cnrtDataType_t d_type,
-                    const void *input, const void *rois, const int channels,
-                    const bool aligned, const int pooled_height,
-                    const int pooled_width, const int input_height,
-                    const int input_width, const int sampling_ratio,
-                    const float spatial_scale, const int num_rois,
-                    void *output);
-
-void KernelRoiAlignBackward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
-                            cnrtQueue_t queue, const cnrtDataType_t dtype,
-                            const void *grads, const void *boxes,
-                            void *grads_image, const int boxes_num,
-                            const int hi, const int wi, const int c,
-                            const int no, const int ho, const int wo,
-                            const float spatial_scale, const int sampling_ratio,
-                            const bool aligned);
+#include "mlu_common_helper.h"
 
 void ROIAlignForwardMLUKernelLauncher(Tensor input, Tensor rois, Tensor output,
                                       Tensor argmax_y, Tensor argmax_x,
@@ -36,17 +17,7 @@ void ROIAlignForwardMLUKernelLauncher(Tensor input, Tensor rois, Tensor output,
                                       float spatial_scale, int sampling_ratio,
                                       int pool_mode, bool aligned) {
   // params check
-  TORCH_CHECK(
-      input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf,
-      "input type should be Float or Half, got ", input.scalar_type());
-  TORCH_CHECK(rois.scalar_type() == input.scalar_type(),
-              "rois should have the same type as input");
-  TORCH_CHECK(input.dim() == 4, "input should be a 4d tensor, got ",
-              input.dim(), "D");
-  TORCH_CHECK(rois.dim() == 2, "rois should be a 2d tensor, got ", rois.dim(),
-              "D");
   TORCH_CHECK(pool_mode == 1, "pool_mode only supports 'avg' currently");
-
   auto memory_format =
       torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim());
   auto input_tensor =
@@ -57,52 +28,56 @@ void ROIAlignForwardMLUKernelLauncher(Tensor input, Tensor rois, Tensor output,
   int height = input.size(2);
   int width = input.size(3);
-  if (output.numel() == 0) {
-    output = at::zeros({num_rois, channels, aligned_height, aligned_width},
-                       input.options());
-    return;
-  }
-
-  at::Tensor output_tmp =
+  auto output_contiguous =
       at::empty({num_rois, channels, aligned_height, aligned_width},
                 input.options(), memory_format);
 
-  // get tensor impl
   auto self_impl = torch_mlu::getMluTensorImpl(input_tensor);
   auto rois_impl = torch_mlu::getMluTensorImpl(rois);
-  auto output_impl = torch_mlu::getMluTensorImpl(output_tmp);
+  auto output_impl = torch_mlu::getMluTensorImpl(output_contiguous);
 
-  // get compute queue
-  auto queue = torch_mlu::getCurQueue();
+  MluOpTensorDescriptor input_desc, rois_desc, argmax_y_desc, argmax_x_desc,
+      output_desc;
+  input_desc.set_with_layout(input_tensor, MLUOP_LAYOUT_NHWC);
+  rois_desc.set_with_layout(rois, MLUOP_LAYOUT_ARRAY);
+  output_desc.set_with_layout(output_contiguous, MLUOP_LAYOUT_NHWC);
 
   // get the mlu ptr
   auto self_ptr = self_impl->cnnlMalloc();
   auto rois_ptr = rois_impl->cnnlMalloc();
   auto output_ptr = output_impl->cnnlMalloc();
 
-  cnrtJobType_t k_type = CNRT_FUNC_TYPE_UNION1;
-  cnrtDim3_t k_dim;
-  k_dim.x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
-  k_dim.y = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
-  k_dim.z = 1;
-  cnrtDataType_t data_type = torch_mlu::toCnrtDtype(input.dtype());
-
-  KernelRoiAlign(k_dim, k_type, queue, data_type, self_ptr, rois_ptr, channels,
-                 aligned, aligned_height, aligned_width, height, width,
-                 sampling_ratio, spatial_scale, num_rois, output_ptr);
-
-  output.copy_(output_tmp);
-}
-
-static int nearestPower2(int x) {
-  x--;
-  x |= x >> 1;
-  x |= x >> 2;
-  x |= x >> 4;
-  x |= x >> 8;
-  x |= x >> 16;
-  x++;
-  return x;
+  mluOpRoiAlignForwardDescriptor_t roialign_desc;
+  mluOpCreateRoiAlignForwardDescriptor(&roialign_desc);
+  mluOpSetRoiAlignForwardDescriptor_v2(roialign_desc, aligned_height,
+                                       aligned_width, sampling_ratio,
+                                       spatial_scale, pool_mode, aligned);
+
+  auto handle = mluOpGetCurrentHandle();
+  if (pool_mode == 0) {
+    auto argmax_y_contiguous =
+        torch_mlu::cnnl::ops::cnnl_contiguous(argmax_y, memory_format);
+    auto argmax_x_contiguous =
+        torch_mlu::cnnl::ops::cnnl_contiguous(argmax_x, memory_format);
+    auto argmax_x_impl = torch_mlu::getMluTensorImpl(argmax_x_contiguous);
+    auto argmax_y_impl = torch_mlu::getMluTensorImpl(argmax_y_contiguous);
+    auto argmax_x_ptr = argmax_x_impl->cnnlMalloc();
+    auto argmax_y_ptr = argmax_y_impl->cnnlMalloc();
+    argmax_y_desc.set_with_layout(argmax_y_contiguous, MLUOP_LAYOUT_NHWC);
+    argmax_x_desc.set_with_layout(argmax_x_contiguous, MLUOP_LAYOUT_NHWC);
+    mluOpRoiAlignForward_v2(handle, roialign_desc, input_desc.desc(), self_ptr,
+                            rois_desc.desc(), rois_ptr, output_desc.desc(),
+                            output_ptr, argmax_x_desc.desc(), argmax_x_ptr,
+                            argmax_y_desc.desc(), argmax_y_ptr);
+    argmax_x.copy_(argmax_x_contiguous);
+    argmax_y.copy_(argmax_y_contiguous);
+  } else {
+    mluOpRoiAlignForward_v2(handle, roialign_desc, input_desc.desc(), self_ptr,
+                            rois_desc.desc(), rois_ptr, output_desc.desc(),
+                            output_ptr, NULL, NULL, NULL, NULL);
+  }
+  mluOpDestroyRoiAlignForwardDescriptor(roialign_desc);
+  output.copy_(output_contiguous);
 }
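The forward launcher works around the layout mismatch between the NCHW tensors of the public API and the NHWC layout the MLU kernels expect: cnnl_contiguous produces channels-last copies, the op writes a channels-last output, and copy_ converts back. The round trip condensed into one hypothetical helper (identifiers follow the launcher above; this is a sketch, not code from the patch):

    #include "mlu_common_helper.h"

    // Run an NHWC-only MLU op against NCHW tensors from the public API.
    void run_nhwc_op(const at::Tensor &input /* NCHW */, at::Tensor &output) {
      auto memory_format =
          torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim());
      // Channels-last copies for the kernel.
      auto input_nhwc =
          torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format);
      auto output_nhwc =
          at::empty(output.sizes(), input.options(), memory_format);
      // ... launch the mluOp* kernel on input_nhwc / output_nhwc here ...
      // copy_ performs the layout conversion back to the caller's tensor.
      output.copy_(output_nhwc);
    }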
 
 void ROIAlignBackwardMLUKernelLauncher(Tensor grad, Tensor rois,
@@ -112,17 +87,7 @@ void ROIAlignBackwardMLUKernelLauncher(Tensor grad, Tensor rois,
                                        int sampling_ratio, int pool_mode,
                                        bool aligned) {
   // params check
-  TORCH_CHECK(
-      grad.scalar_type() == at::kFloat || grad.scalar_type() == at::kHalf,
-      "grad type should be Float or Half, got ", grad.scalar_type());
-  TORCH_CHECK(rois.scalar_type() == grad.scalar_type(),
-              "rois should have the same type as grad");
-  TORCH_CHECK(grad.dim() == 4, "grad should be a 4d tensor, got ", grad.dim(),
-              "D");
-  TORCH_CHECK(rois.dim() == 2, "rois should be a 2d tensor, got ", rois.dim(),
-              "D");
   TORCH_CHECK(pool_mode == 1, "pool_mode only supports 'avg' currently");
-
   int batch_size = grad_input.size(0);
   int channels = grad_input.size(1);
   int height = grad_input.size(2);
@@ -148,26 +113,40 @@ void ROIAlignBackwardMLUKernelLauncher(Tensor grad, Tensor rois,
   auto grad_input_impl = torch_mlu::getMluTensorImpl(grad_input_);
   auto rois_impl = torch_mlu::getMluTensorImpl(rois);
 
-  // get compute queue
-  auto queue = torch_mlu::getCurQueue();
-
   // get the mlu ptr
   auto grad_ptr = grad_impl->cnnlMalloc();
   auto rois_ptr = rois_impl->cnnlMalloc();
   auto grad_input_ptr = grad_input_impl->cnnlMalloc();
 
-  cnrtJobType_t k_type = CNRT_FUNC_TYPE_UNION1;
-  int need_core = nearestPower2(boxes_num);
-  int union_number = torch_mlu::getDeviceAttr(cnrtAttrClusterCount);
-  uint32_t dim_x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
-  uint32_t dim_y = (need_core - 1) / dim_x + 1;
-  dim_y = (dim_y > union_number) ? union_number : dim_y;
-  cnrtDim3_t k_dim = {dim_x, dim_y, 1};
-  cnrtDataType_t k_dtype = torch_mlu::toCnrtDtype(grad.dtype());
-
-  KernelRoiAlignBackward(k_dim, k_type, queue, k_dtype, grad_ptr, rois_ptr,
-                         grad_input_ptr, boxes_num, hi, wi, c, no, ho, wo,
-                         spatial_scale, sampling_ratio, aligned);
+  MluOpTensorDescriptor grads_desc, rois_desc, argmax_y_desc, argmax_x_desc,
+      grad_input_desc;
+  grads_desc.set_with_layout(grad_, MLUOP_LAYOUT_NHWC);
+  rois_desc.set_with_layout(rois, MLUOP_LAYOUT_ARRAY);
+  grad_input_desc.set_with_layout(grad_input_, MLUOP_LAYOUT_NHWC);
+
+  auto handle = mluOpGetCurrentHandle();
+  if (pool_mode == 0) {
+    auto argmax_y_contiguous =
+        torch_mlu::cnnl::ops::cnnl_contiguous(argmax_y, memory_format);
+    auto argmax_x_contiguous =
+        torch_mlu::cnnl::ops::cnnl_contiguous(argmax_x, memory_format);
+    auto argmax_x_impl = torch_mlu::getMluTensorImpl(argmax_x_contiguous);
+    auto argmax_y_impl = torch_mlu::getMluTensorImpl(argmax_y_contiguous);
+    auto argmax_x_ptr = argmax_x_impl->cnnlMalloc();
+    auto argmax_y_ptr = argmax_y_impl->cnnlMalloc();
+    argmax_y_desc.set_with_layout(argmax_y_contiguous, MLUOP_LAYOUT_NHWC);
+    argmax_x_desc.set_with_layout(argmax_x_contiguous, MLUOP_LAYOUT_NHWC);
+    mluOpRoiAlignBackward_v2(handle, grads_desc.desc(), grad_ptr,
+                             rois_desc.desc(), rois_ptr, argmax_x_desc.desc(),
+                             argmax_x_ptr, argmax_y_desc.desc(), argmax_y_ptr,
+                             spatial_scale, sampling_ratio, aligned, pool_mode,
+                             grad_input_desc.desc(), grad_input_ptr);
+  } else {
+    mluOpRoiAlignBackward_v2(handle, grads_desc.desc(), grad_ptr,
+                             rois_desc.desc(), rois_ptr, NULL, NULL, NULL,
+                             NULL, spatial_scale, sampling_ratio, aligned,
+                             pool_mode, grad_input_desc.desc(),
+                             grad_input_ptr);
+  }
   grad_input.copy_(grad_input_);
 }
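Both RoiAlign launchers describe the 4-D feature maps with an explicit NHWC layout and the roi table as a layout-free array. A minimal sketch of that descriptor setup, assuming the MluOpTensorDescriptor helper from mlu_common_helper.h behaves as used above:

    #include "mlu_common_helper.h"

    // 4-D feature maps are declared NHWC so the op reads dims as
    // (N, H, W, C); the roi table has no layout semantics and is
    // declared as a plain array.
    void describe_roi_align_inputs(const at::Tensor &feat_nhwc,
                                   const at::Tensor &rois) {
      MluOpTensorDescriptor feat_desc, rois_desc;
      feat_desc.set_with_layout(feat_nhwc, MLUOP_LAYOUT_NHWC);
      rois_desc.set_with_layout(rois, MLUOP_LAYOUT_ARRAY);
    }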
 
diff --git a/mmcv/ops/csrc/pytorch/mlu/voxelization_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/voxelization_mlu.cpp
index c3d31bc0e5..2ffd751ade 100644
--- a/mmcv/ops/csrc/pytorch/mlu/voxelization_mlu.cpp
+++ b/mmcv/ops/csrc/pytorch/mlu/voxelization_mlu.cpp
@@ -9,238 +9,69 @@
  * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
  * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
  *************************************************************************/
-#include "pytorch_device_registry.hpp"
-#include "pytorch_mlu_helper.hpp"
+#include "mlu_common_helper.h"
 
-#define MIN(a, b) (((a) < (b)) ? (a) : (b))
-
-void KernelDynamicVoxelize(
-    cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue,
-    const void *points, void *coors, const float voxel_x, const float voxel_y,
-    const float voxel_z, const float coors_x_min, const float coors_y_min,
-    const float coors_z_min, const float coors_x_max, const float coors_y_max,
-    const float coors_z_max, const int32_t grid_x, const int32_t grid_y,
-    const int32_t grid_z, const int32_t num_points, const int32_t num_features);
-
-void KernelPoint2Voxel(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
-                       cnrtQueue_t queue, void *coors, void *point_to_pointidx,
-                       void *point_to_voxelidx, const int32_t num_points,
-                       const int32_t max_points);
-
-void KernelCalcPointsPerVoxel(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
-                              cnrtQueue_t queue, void *point_to_pointidx,
-                              void *point_to_voxelidx, void *coor_to_voxelidx,
-                              void *num_points_per_voxel, void *voxel_num,
-                              const int32_t max_voxels,
-                              const int32_t num_points);
-
-void KernelAssignVoxelsCoors(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
-                             cnrtQueue_t queue, const void *points,
-                             void *temp_coors, void *point_to_voxelidx,
-                             void *coor_to_voxelidx, void *voxels, void *coors,
-                             const int32_t max_points, const int32_t num_points,
-                             const int32_t num_features);
-
-// policy function
-static void policyFuncDefault(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type,
-                              const int num_points) {
-  k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster);
-  k_dim->y = MIN((num_points + k_dim->x - 1) / k_dim->x,
-                 torch_mlu::getDeviceAttr(cnrtAttrClusterCount));
-  k_dim->z = 1;
-  *k_type = CNRT_FUNC_TYPE_UNION1;
-}
-
-// policy function
-static void policyFuncCalcPointsPerVoxel(cnrtDim3_t *k_dim,
-                                         cnrtFunctionType_t *k_type,
-                                         const int num_points) {
-  k_dim->x = 1;
-  k_dim->y = 1;
-  k_dim->z = 1;
-  *k_type = CNRT_FUNC_TYPE_BLOCK;
-}
+/*************************************************************************
+ * This macro converts a plain ATen tensor into its MLU counterparts:
+ * NAME##_contiguous, NAME##_desc, NAME##_impl and NAME##_ptr are
+ * generated automatically.
+ *************************************************************************/
+#define INITIAL_MLU_PARAM_WITH_TENSOR(NAME)                          \
+  auto NAME##_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(    \
+      NAME, NAME.suggest_memory_format());                           \
+  MluOpTensorDescriptor NAME##_desc;                                 \
+  NAME##_desc.set(NAME##_contiguous);                                \
+  auto NAME##_impl = torch_mlu::getMluTensorImpl(NAME##_contiguous); \
+  auto NAME##_ptr = NAME##_impl->cnnlMalloc();
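For reference, one invocation of the macro expands mechanically via token pasting; for a tensor named points it is equivalent to the following fragment (shown outside the diff for illustration):

    // INITIAL_MLU_PARAM_WITH_TENSOR(points) expands to:
    auto points_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous(
        points, points.suggest_memory_format());
    MluOpTensorDescriptor points_desc;
    points_desc.set(points_contiguous);
    auto points_impl = torch_mlu::getMluTensorImpl(points_contiguous);
    auto points_ptr = points_impl->cnnlMalloc();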
 
 int HardVoxelizeForwardMLUKernelLauncher(
     const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
     at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
     const std::vector<float> coors_range, const int max_points,
     const int max_voxels, const int NDim = 3) {
-  // check datatype
-  TORCH_CHECK(points.scalar_type() == at::kFloat,
-              "points type should be Float, got ", points.scalar_type(), ".");
-  TORCH_CHECK(voxels.scalar_type() == at::kFloat,
-              "voxels type should be Float, got ", voxels.scalar_type(), ".");
-  TORCH_CHECK(coors.scalar_type() == at::kInt,
-              "coors type should be Float, got ", coors.scalar_type(), ".");
-  TORCH_CHECK(num_points_per_voxel.scalar_type() == at::kInt,
-              "num_points_per_voxel type should be Float, got ",
-              num_points_per_voxel.scalar_type(), ".");
-
-  // check shape
-  TORCH_CHECK(points.dim() == 2, "points should be a 2d tensor, got ",
-              points.dim(), "D.");
-  TORCH_CHECK(voxels.dim() == 3, "voxels should be a 3d tensor, got ",
-              voxels.dim(), "D.");
-  TORCH_CHECK(coors.dim() == 2, "coors should be a 2d tensor, got ",
-              coors.dim(), "D.");
-  TORCH_CHECK(num_points_per_voxel.dim() == 1,
-              "num_points_per_voxel should be a 1d tensor, got ",
-              num_points_per_voxel.dim(), "D.");
-
-  const int num_points = points.size(0);
-  const int num_features = points.size(1);
-
-  TORCH_CHECK(points.size(0) == num_points,
-              "the 1st dimensions of points should be num_points, got ",
-              points.size(0), ".");
-  TORCH_CHECK(points.size(1) == num_features,
-              "the 2nd dimensions of points should be num_features, got ",
-              points.size(1), ".");
-  TORCH_CHECK(voxels.size(0) == max_voxels,
-              "the 1st dimensions of voxels should be max_voxels, got ",
-              voxels.size(0), ".");
-  TORCH_CHECK(voxels.size(1) == max_points,
-              "the 2nd dimensions of voxels should be max_points, got ",
-              voxels.size(1), ".");
-  TORCH_CHECK(voxels.size(2) == num_features,
-              "the 3rd dimensions of voxels should be num_features, got ",
-              voxels.size(2), ".");
-  TORCH_CHECK(coors.size(0) == max_voxels,
-              "the 1st dimensions of coors should be max_voxels, got ",
-              coors.size(0), ".");
-  TORCH_CHECK(coors.size(1) == 3,
-              "the 2nd dimensions of coors should be 3, got ", coors.size(1),
-              ".");
-  TORCH_CHECK(num_points_per_voxel.size(0) == max_voxels,
-              "the 1st dimensions of num_points_per_voxel should be 3, got ",
-              num_points_per_voxel.size(0), ".");
-
-  // large tensor check
-  const size_t max_input_size = 2147483648;
-  TORCH_CHECK(points.numel() < max_input_size,
-              "points element num should be less than 2^31, got ",
-              points.numel(), ".");
-  TORCH_CHECK(voxels.numel() < max_input_size,
-              "voxels element num should be less than 2^31, got ",
-              voxels.numel(), ".");
-  TORCH_CHECK(coors.numel() < max_input_size,
-              "coors element num should be less than 2^31, got ",
-              coors.numel(), ".");
-
-  // check zero element
-  if (max_points == 0 || max_voxels == 0) {
-    return 0;
-  }
-
-  // get compute queue
-  auto queue = torch_mlu::getCurQueue();
-
-  // get ptr of tensors
-  auto points_ = points.contiguous();
-  auto points_impl = torch_mlu::getMluTensorImpl(points_);
-  auto points_ptr = points_impl->cnnlMalloc();
-  auto voxels_ = voxels.contiguous();
-  auto voxels_impl = torch_mlu::getMluTensorImpl(voxels_);
-  auto voxels_ptr = voxels_impl->cnnlMalloc();
-  auto coors_ = coors.contiguous();
-  auto coors_impl = torch_mlu::getMluTensorImpl(coors_);
-  auto coors_ptr = coors_impl->cnnlMalloc();
-  auto num_points_per_voxel_ = num_points_per_voxel.contiguous();
-  auto num_points_per_voxel_impl =
-      torch_mlu::getMluTensorImpl(num_points_per_voxel_);
-  auto num_points_per_voxel_ptr = num_points_per_voxel_impl->cnnlMalloc();
-
-  // calculate task dimension
-  cnrtDim3_t k_dim;
-  cnrtFunctionType_t k_type;
-  policyFuncDefault(&k_dim, &k_type, num_points);
-
-  // 1. link point to corresponding voxel coors
-  const float voxel_x = voxel_size[0];
-  const float voxel_y = voxel_size[1];
-  const float voxel_z = voxel_size[2];
-  const float coors_x_min = coors_range[0];
-  const float coors_y_min = coors_range[1];
-  const float coors_z_min = coors_range[2];
-  const float coors_x_max = coors_range[3];
-  const float coors_y_max = coors_range[4];
-  const float coors_z_max = coors_range[5];
-
-  const int grid_x = round((coors_x_max - coors_x_min) / voxel_x);
-  const int grid_y = round((coors_y_max - coors_y_min) / voxel_y);
-  const int grid_z = round((coors_z_max - coors_z_min) / voxel_z);
-
-  auto temp_coors =
-      at::zeros({NDim, num_points}, points.options().dtype(at::kInt))
-          .contiguous();
-  auto temp_coors_impl = torch_mlu::getMluTensorImpl(temp_coors);
-  auto temp_coors_ptr = temp_coors_impl->cnnlMalloc();
-
-  KernelDynamicVoxelize(k_dim, k_type, queue, points_ptr, temp_coors_ptr,
-                        voxel_x, voxel_y, voxel_z, coors_x_min, coors_y_min,
-                        coors_z_min, coors_x_max, coors_y_max, coors_z_max,
-                        grid_x, grid_y, grid_z, num_points, num_features);
-
-  // 2. map point to the idx of the corresponding voxel, find duplicate coor
-  auto point_to_pointidx =
-      at::zeros({num_points}, points.options().dtype(at::kInt)).contiguous();
-  auto point_to_pointidx_impl = torch_mlu::getMluTensorImpl(point_to_pointidx);
-  auto point_to_pointidx_ptr = point_to_pointidx_impl->cnnlMalloc();
-  auto point_to_voxelidx =
-      at::zeros({num_points}, points.options().dtype(at::kInt)).contiguous();
-  auto point_to_voxelidx_impl = torch_mlu::getMluTensorImpl(point_to_voxelidx);
-  auto point_to_voxelidx_ptr = point_to_voxelidx_impl->cnnlMalloc();
-
-  KernelPoint2Voxel(k_dim, k_type, queue, temp_coors_ptr,
-                    point_to_pointidx_ptr, point_to_voxelidx_ptr, num_points,
-                    max_points);
-
-  // calculate task dimension
-  cnrtDim3_t k_dim_calc_points_per_voxel;
-  cnrtFunctionType_t k_type_calc_points_per_voxel;
-  policyFuncCalcPointsPerVoxel(&k_dim_calc_points_per_voxel,
-                               &k_type_calc_points_per_voxel, num_points);
-
-  // 3. determine voxel num and voxel's coor index
-  auto coor_to_voxelidx =
-      at::zeros({num_points}, points.options().dtype(at::kInt)).contiguous();
-  auto coor_to_voxelidx_impl = torch_mlu::getMluTensorImpl(coor_to_voxelidx);
-  auto coor_to_voxelidx_ptr = coor_to_voxelidx_impl->cnnlMalloc();
-  auto voxel_num =
-      at::zeros({1}, points.options().dtype(at::kInt)).contiguous();
-  auto voxel_num_impl = torch_mlu::getMluTensorImpl(voxel_num);
-  auto voxel_num_ptr = voxel_num_impl->cnnlMalloc();
-
-  KernelCalcPointsPerVoxel(
-      k_dim_calc_points_per_voxel, k_type_calc_points_per_voxel, queue,
-      point_to_pointidx_ptr, point_to_voxelidx_ptr, coor_to_voxelidx_ptr,
-      num_points_per_voxel_ptr, voxel_num_ptr, max_voxels, num_points);
-
-  // 4. copy point features and coors of each voxels to voxels
-  KernelAssignVoxelsCoors(k_dim, k_type, queue, points_ptr, temp_coors_ptr,
-                          point_to_voxelidx_ptr, coor_to_voxelidx_ptr,
-                          voxels_ptr, coors_ptr, max_points, num_points,
-                          num_features);
-
-  auto voxel_num_cpu = voxel_num.to(at::kCPU);
+  std::vector<float> _voxel_size(voxel_size.begin(), voxel_size.end());
+  std::vector<float> _coors_range(coors_range.begin(), coors_range.end());
+  auto opts = torch::TensorOptions().dtype(torch::kFloat32);
+  auto voxel_size_tensor =
+      torch::from_blob(_voxel_size.data(), {int64_t(_voxel_size.size())}, opts)
+          .clone()
+          .to(at::kMLU);
+  auto coors_range_tensor =
+      torch::from_blob(_coors_range.data(), {int64_t(_coors_range.size())},
+                       opts)
+          .clone()
+          .to(at::kMLU);
+  INITIAL_MLU_PARAM_WITH_TENSOR(points);
+  INITIAL_MLU_PARAM_WITH_TENSOR(voxels);
+  INITIAL_MLU_PARAM_WITH_TENSOR(coors);
+  INITIAL_MLU_PARAM_WITH_TENSOR(num_points_per_voxel);
+  INITIAL_MLU_PARAM_WITH_TENSOR(voxel_size_tensor);
+  INITIAL_MLU_PARAM_WITH_TENSOR(coors_range_tensor);
+
+  auto voxel_num_tensor =
+      at::empty({1}, points.options().dtype(torch::kInt32));
+  INITIAL_MLU_PARAM_WITH_TENSOR(voxel_num_tensor);
+
+  size_t workspace_size;
+  auto handle = mluOpGetCurrentHandle();
+  mluOpGetVoxelizationWorkspaceSize(
+      handle, points_desc.desc(), voxel_size_tensor_desc.desc(),
+      coors_range_tensor_desc.desc(), max_points, max_voxels, NDim, true,
+      voxels_desc.desc(), coors_desc.desc(), num_points_per_voxel_desc.desc(),
+      voxel_num_tensor_desc.desc(), &workspace_size);
+  auto workspace_tensor =
+      at::empty(workspace_size, points.options().dtype(at::kByte));
+  INITIAL_MLU_PARAM_WITH_TENSOR(workspace_tensor);
+
+  mluOpVoxelization(handle, points_desc.desc(), points_ptr,
+                    voxel_size_tensor_desc.desc(), voxel_size_tensor_ptr,
+                    coors_range_tensor_desc.desc(), coors_range_tensor_ptr,
+                    max_points, max_voxels, NDim, true, workspace_tensor_ptr,
+                    workspace_size, voxels_desc.desc(), voxels_ptr,
+                    coors_desc.desc(), coors_ptr,
+                    num_points_per_voxel_desc.desc(), num_points_per_voxel_ptr,
+                    voxel_num_tensor_desc.desc(), voxel_num_tensor_ptr);
+  auto voxel_num_cpu = voxel_num_tensor.to(at::kCPU);
   int voxel_num_int = voxel_num_cpu.data_ptr<int>()[0];
-
   return voxel_num_int;
 }
 
@@ -254,7 +85,7 @@ int hard_voxelize_forward_mlu(const at::Tensor &points, at::Tensor &voxels,
   return HardVoxelizeForwardMLUKernelLauncher(
       points, voxels, coors, num_points_per_voxel, voxel_size, coors_range,
       max_points, max_voxels, NDim);
-};
+}
 
 int hard_voxelize_forward_impl(const at::Tensor &points, at::Tensor &voxels,
                                at::Tensor &coors,
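HardVoxelizeForwardMLUKernelLauncher now receives voxel_size and coors_range as host-side std::vector values, so they are staged through from_blob, clone, and a device transfer: from_blob aliases the host buffer, clone() detaches the tensor from that buffer's lifetime, and .to(at::kMLU) uploads it. The pattern in isolation (a sketch; host_vec_to_mlu is a hypothetical helper name, not part of the patch):

    #include <torch/torch.h>
    #include <vector>

    at::Tensor host_vec_to_mlu(const std::vector<float> &v) {
      auto opts = torch::TensorOptions().dtype(torch::kFloat32);
      return torch::from_blob(const_cast<float *>(v.data()),
                              {static_cast<int64_t>(v.size())}, opts)
          .clone()        // own the storage; safe after v goes out of scope
          .to(at::kMLU);  // upload to the MLU device
    }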
diff --git a/setup.py b/setup.py
index 6040117e6c..a95f460196 100644
--- a/setup.py
+++ b/setup.py
@@ -212,6 +212,7 @@ def get_extensions():
 
     include_dirs = []
     extra_objects = []
+    extra_link_args = []
     is_rocm_pytorch = False
     try:
         from torch.utils.cpp_extension import ROCM_HOME
@@ -325,8 +326,11 @@ def get_mluops_version(file_path):
             './mlu-ops/bangc-ops/kernels/**/*.cpp', recursive=True) + \
             glob.glob(
                 './mlu-ops/bangc-ops/kernels/**/*.mlu', recursive=True)
-        extra_objects = glob.glob(
-            './mlu-ops/bangc-ops/kernels/kernel_wrapper/*.o')
+        extra_link_args = [
+            '-Wl,--whole-archive',
+            './mlu-ops/bangc-ops/kernels/kernel_wrapper/lib/libextops.a',
+            '-Wl,--no-whole-archive'
+        ]
         extension = MLUExtension
         include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common'))
         include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/mlu'))
@@ -393,7 +397,8 @@ def get_mluops_version(file_path):
                 include_dirs=include_dirs,
                 define_macros=define_macros,
                 extra_objects=extra_objects,
-                extra_compile_args=extra_compile_args)
+                extra_compile_args=extra_compile_args,
+                extra_link_args=extra_link_args)
     extensions.append(ext_ops)
     return extensions
 
diff --git a/tests/test_ops/test_roi_align.py b/tests/test_ops/test_roi_align.py
index 9609f37c93..46a5183f30 100644
--- a/tests/test_ops/test_roi_align.py
+++ b/tests/test_ops/test_roi_align.py
@@ -93,6 +93,15 @@ def _test_roialign_allclose(device, dtype):
         x.grad.data.type(torch.float).cpu().numpy(), np_grad, atol=1e-3)
 
 
+@pytest.mark.parametrize('dtype', [
+    torch.float,
+    pytest.param(
+        torch.double,
+        marks=pytest.mark.skipif(
+            IS_MLU_AVAILABLE or IS_NPU_AVAILABLE,
+            reason='MLU and NPU do not support 64-bit floating point')),
+    torch.half
+])
 @pytest.mark.parametrize('device', [
     'cpu',
     pytest.param(
@@ -108,15 +117,6 @@ def _test_roialign_allclose(device, dtype):
         marks=pytest.mark.skipif(
             not IS_NPU_AVAILABLE, reason='requires NPU support'))
 ])
-@pytest.mark.parametrize('dtype', [
-    torch.float,
-    pytest.param(
-        torch.double,
-        marks=pytest.mark.skipif(
-            IS_MLU_AVAILABLE or IS_NPU_AVAILABLE,
-            reason='MLU and NPU do not support for 64-bit floating point')),
-    torch.half
-])
 def test_roialign(device, dtype):
     # check double only
     if dtype is torch.double: