
[NVIDIA] Fix test_segment_ops unit test failed on V100 #38113


Merged · 1 commit · Dec 17, 2021
5 changes: 3 additions & 2 deletions paddle/fluid/operators/math/segment_pooling.cu
@@ -120,8 +120,9 @@ __global__ void SegmentMeanKernel(const Index* segment_ids, const T* input,
}

template <typename T, typename Index, typename Helper, typename Pool>
__global__ void SegmentOpsKernel(const Index* segment_ids, const T* input,
T* output, Helper h, Pool pool) {
__global__ void __launch_bounds__(1024, 1)
Collaborator:

There are some other kernels in segment_pooling.cu, such as SegmentMeanKernel, that share the same launch config.
Could other kernel functions with the same launch config run into the same problem?

Collaborator Author:

Could other kernel functions with the same launch config run into the same problem?

There might be, but currently no other kernel encounters this problem on V100. NVCC doesn't know the runtime launch config, so it doesn't limit register usage. For example, this kernel can also be run with 128/256/512 threads per block; if NVCC limited register usage for all of those, it could reduce the performance of those configurations.
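
Below is a minimal standalone sketch of the idea (a hypothetical `scale_kernel`, not the Paddle op): `__launch_bounds__(1024, 1)` tells NVCC at compile time the maximum block size the kernel may be launched with, so the compiler caps per-thread register usage and a 1024-thread launch cannot fail with "too many resources requested for launch".

```cuda
// Minimal sketch (hypothetical kernel, not the Paddle op).
// Without the __launch_bounds__ qualifier, NVCC only sees the kernel body
// and may allocate more registers per thread than a 1024-thread block can
// supply, making the launch fail at runtime.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void __launch_bounds__(1024, 1)
    scale_kernel(const float* in, float* out, int n, float alpha) {
  // Grid-stride loop, the same launch pattern as CUDA_KERNEL_LOOP.
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    out[i] = alpha * in[i];
  }
}

int main() {
  const int n = 1 << 20;
  float *in = nullptr, *out = nullptr;
  cudaMalloc(&in, n * sizeof(float));
  cudaMalloc(&out, n * sizeof(float));

  // Launch with the maximum block size the bound was declared for.
  const int block = 1024;
  const int grid = (n + block - 1) / block;
  scale_kernel<<<grid, block>>>(in, out, n, 2.0f);

  // A launch that requested too many resources would be reported here.
  printf("launch status: %s\n", cudaGetErrorString(cudaGetLastError()));

  cudaFree(in);
  cudaFree(out);
  return 0;
}
```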

BTW, in my experience 1024 threads per block gives lower performance than 128/256 threads per block. The CUDA Best Practices Guide also says:

Between 128 and 256 threads per block is a good initial range for experimentation with different block sizes.

However, it would take a large effort to run performance benchmarks and verification on every op that uses this launch config.
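
For reference, a rough sketch of that kind of block-size benchmarking (again with a hypothetical kernel, not an actual Paddle op) could time each candidate block size with CUDA events:

```cuda
// Rough benchmarking sketch: launch the same (hypothetical) kernel with
// 128/256/512/1024 threads per block and compare elapsed times.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void axpy_kernel(const float* x, float* y, int n, float a) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    y[i] = a * x[i] + y[i];
  }
}

int main() {
  const int n = 1 << 24;
  float *x = nullptr, *y = nullptr;
  cudaMalloc(&x, n * sizeof(float));
  cudaMalloc(&y, n * sizeof(float));

  const int block_sizes[] = {128, 256, 512, 1024};
  for (int block : block_sizes) {
    const int grid = (n + block - 1) / block;

    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    cudaEventRecord(start);
    axpy_kernel<<<grid, block>>>(x, y, n, 2.0f);
    cudaEventRecord(stop);
    cudaEventSynchronize(stop);

    float ms = 0.0f;
    cudaEventElapsedTime(&ms, start, stop);
    printf("block=%4d  grid=%d  time=%.3f ms\n", block, grid, ms);

    cudaEventDestroy(start);
    cudaEventDestroy(stop);
  }

  cudaFree(x);
  cudaFree(y);
  return 0;
}
```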

SegmentOpsKernel(const Index* segment_ids, const T* input, T* output,
Helper h, Pool pool) {
CUDA_KERNEL_LOOP(stripe_index, h.total_stripe_count) {
Index segment_offset, dim_index_base, actual_height;
Index inner_dim_size = h.inner_dim_size;