Skip to content

Commit

Permalink
Optimize the performance of NMSFast in the multiclass_nms3_kernel. (#…
Browse files: browse the repository at this point in the history
  • Loading branch information
zhanghonggeng authored Mar 10, 2025
1 parent 8ad9604 commit e57f38f
Showing 1 changed file with 24 additions and 25 deletions.
49 changes: 24 additions & 25 deletions paddle/phi/kernels/cpu/multiclass_nms3_kernel.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -302,37 +302,36 @@ void NMSFast(const DenseTensor& bbox,
T adaptive_threshold = nms_threshold;
const T* bbox_data = bbox.data<T>();

while (!sorted_indices.empty()) {
const int idx = sorted_indices.front().second;
size_t num_indices = sorted_indices.size();
selected_indices->reserve(num_indices);
for (size_t i = 0; i < num_indices; ++i) {
const int idx = sorted_indices[i].second;
bool keep = true;
for (const auto kept_idx : *selected_indices) {
if (keep) {
T overlap = T(0.);
// 4: [xmin ymin xmax ymax]
if (box_size == 4) {
overlap = JaccardOverlap<T>(bbox_data + idx * box_size,
bbox_data + kept_idx * box_size,
normalized);
}
// 8: [x1 y1 x2 y2 x3 y3 x4 y4] or 16, 24, 32
if (box_size == 8 || box_size == 16 || box_size == 24 ||
box_size == 32) {
overlap = PolyIoU<T>(bbox_data + idx * box_size,
bbox_data + kept_idx * box_size,
box_size,
normalized);
}
keep = overlap <= adaptive_threshold;
} else {
const T* current_bbox = bbox_data + idx * box_size;
size_t selected_size = selected_indices->size();

for (size_t j = 0; j < selected_size; ++j) {
const auto kept_idx = (*selected_indices)[j];
T overlap = T(0.);
const T* kept_bbox = bbox_data + kept_idx * box_size;
// 4: [xmin ymin xmax ymax]
if (box_size == 4) {
overlap = JaccardOverlap<T>(current_bbox, kept_bbox, normalized);
} else if (box_size == 8 || box_size == 16 || box_size == 24 ||
box_size ==
32) { // 8: [x1 y1 x2 y2 x3 y3 x4 y4] or 16, 24, 32
overlap = PolyIoU<T>(current_bbox, kept_bbox, box_size, normalized);
}
keep = overlap <= adaptive_threshold;
if (!keep) {
break;
}
}
if (keep) {
selected_indices->push_back(idx);
}
sorted_indices.erase(sorted_indices.begin());
if (keep && eta < 1 && adaptive_threshold > 0.5) {
adaptive_threshold *= eta;
if (eta < 1 && adaptive_threshold > 0.5) {
adaptive_threshold *= eta;
}
}
}
}
Expand Down

0 comments on commit e57f38f

Please sign in to comment.