Skip to content

[AUTOGENERATED] [release/2.6] [ROCm] Enable more parallelism for multi-dimensional reductions #2297

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions aten/src/ATen/native/cuda/Reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1115,13 +1115,19 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
int max_threads_per_mp =
at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor;
#ifdef USE_ROCM
// Control the number of threadblocks by adjusting the maximum number of
// threads per multi-processor. These numbers better reflect the maximum
// theoretical achievable threads per MP for the reduction operation.
if (iter.ndim() == 1 || iter.ndim() == 3)
max_threads_per_mp = 512;
if (iter.ndim() == 2)
max_threads_per_mp = 256;
// If the grid consists of a single threadblock, do not change the max threads per
// MP value. This will increase the parallelism across the y dimension of the grid.
bool uses_a_single_block = config.grid().x == config.grid().y == config.grid().z == 1;

if (!uses_a_single_block) {
// Control the number of threadblocks by adjusting the maximum number of
// threads per multi-processor. These numbers better reflect the maximum
// theoretical achievable threads per MP for the reduction operation.
if (iter.ndim() == 1 || iter.ndim() == 3)
max_threads_per_mp = 512;
else if (iter.ndim() == 2)
max_threads_per_mp = 256;
}
#endif
const int blocks_per_sm = max_threads_per_mp / config.num_threads;
const int target_grid_size = num_mp * blocks_per_sm;
Expand Down