Skip to content

[rocm7.0_internal_testing] Prevent static initialization of at::cuda::warp_size() #2293

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/Embedding.cu
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices,

int warp_size = at::cuda::warp_size();
TORCH_INTERNAL_ASSERT(num_threads() % warp_size == 0 &&
num_threads() <= cuda_utils::kCUDABlockReduceMaxThreads,
num_threads() <= cuda_utils::kCUDABlockReduceMaxThreads(),
"BlockReduceSum requires all warps be active");
const int64_t *num_unique_indices_ptr = num_unique_indices.const_data_ptr<int64_t>();
dim3 grid = unique_indices.numel();
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/MultinomialKernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ void renormRows(Tensor& t) {
TORCH_CHECK(props != nullptr);
int numSM = props->multiProcessorCount;
const int64_t maxThreads = std::min(
props->maxThreadsPerBlock, cuda_utils::kCUDABlockReduceMaxThreads);
props->maxThreadsPerBlock, cuda_utils::kCUDABlockReduceMaxThreads());

int warp_size = at::cuda::warp_size();
dim3 grid(rows < numSM * 4 ? rows : numSM * 4);
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/TensorModeKernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ void handle_fused_mode(
constexpr int num_threads = size / 2;
int warp_size = at::cuda::warp_size();
TORCH_INTERNAL_ASSERT(num_threads % warp_size == 0 &&
num_threads <= cuda_utils::kCUDABlockReduceMaxThreads, "");
num_threads <= cuda_utils::kCUDABlockReduceMaxThreads(), "");
const auto memsize =
(sizeof(scalar_t) * size) + (2 * size * sizeof(unsigned int));
compute_mode<scalar_t, size>
Expand Down
8 changes: 6 additions & 2 deletions aten/src/ATen/native/cuda/block_reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@ constexpr int kCUDABlockReduceNumThreads = 512;
// ROCm NOTE: C10_WARP_SIZE should only be used inside device functions,
// and kCUDABlockReduceMaxThreads is a host-side variable.
#ifdef USE_ROCM
static const int kCUDABlockReduceMaxThreads = at::cuda::warp_size() * at::cuda::warp_size();
static int kCUDABlockReduceMaxThreads() {
return at::cuda::warp_size() * at::cuda::warp_size();
}
#else
constexpr int kCUDABlockReduceMaxThreads = C10_WARP_SIZE * C10_WARP_SIZE;
constexpr int kCUDABlockReduceMaxThreads() {
return C10_WARP_SIZE * C10_WARP_SIZE;
}
#endif

// Sums `val` across all threads in a warp.
Expand Down