
Try reverting "Fix CUDA kernel index data type ..." #1965

Open · wants to merge 1 commit into base: main
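For context: every hunk below replaces an auto-deduced kernel index variable (which picks up unsigned int from the blockIdx / blockDim / threadIdx built-ins) with a plain int, apparently restoring the signed 32-bit grid-stride indexing these kernels used before the "Fix CUDA kernel index data type ..." change. The following is a minimal standalone sketch of that pattern in the post-revert form; the kernel name, buffers, and launch configuration are illustrative only and are not part of this PR.

#include <cuda_runtime.h>

__global__ void ScaleKernel(const float* in, float* out, int num_elems, float alpha) {
  // The components of blockIdx, blockDim and threadIdx are unsigned int;
  // storing the derived indices in int (rather than auto) keeps the loop
  // counter and the bound num_elems in the same signed 32-bit type.
  const int num_threads = gridDim.x * blockDim.x;
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;

  // Grid-stride loop: thread tid handles elements tid, tid + num_threads, ...
  for (int i = tid; i < num_elems; i += num_threads) {
    out[i] = alpha * in[i];
  }
}

// Illustrative launch: enough blocks to cover num_elems at 128 threads per block.
// ScaleKernel<<<(num_elems + 127) / 128, 128>>>(d_in, d_out, num_elems, 2.0f);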
12 changes: 6 additions & 6 deletions pytorch3d/csrc/compositing/alpha_composite.cu
@@ -33,11 +33,11 @@ __global__ void alphaCompositeCudaForwardKernel(
const int64_t W = points_idx.size(3);

// Get the batch and index
- const auto batch = blockIdx.x;
+ const int batch = blockIdx.x;

const int num_pixels = C * H * W;
- const auto num_threads = gridDim.y * blockDim.x;
- const auto tid = blockIdx.y * blockDim.x + threadIdx.x;
+ const int num_threads = gridDim.y * blockDim.x;
+ const int tid = blockIdx.y * blockDim.x + threadIdx.x;

// Iterate over each feature in each pixel
for (int pid = tid; pid < num_pixels; pid += num_threads) {
@@ -83,11 +83,11 @@ __global__ void alphaCompositeCudaBackwardKernel(
const int64_t W = points_idx.size(3);

// Get the batch and index
- const auto batch = blockIdx.x;
+ const int batch = blockIdx.x;

const int num_pixels = C * H * W;
- const auto num_threads = gridDim.y * blockDim.x;
- const auto tid = blockIdx.y * blockDim.x + threadIdx.x;
+ const int num_threads = gridDim.y * blockDim.x;
+ const int tid = blockIdx.y * blockDim.x + threadIdx.x;

// Parallelize over each feature in each pixel in images of size H * W,
// for each image in the batch of size batch_size
12 changes: 6 additions & 6 deletions pytorch3d/csrc/compositing/norm_weighted_sum.cu
@@ -33,11 +33,11 @@ __global__ void weightedSumNormCudaForwardKernel(
const int64_t W = points_idx.size(3);

// Get the batch and index
- const auto batch = blockIdx.x;
+ const int batch = blockIdx.x;

const int num_pixels = C * H * W;
- const auto num_threads = gridDim.y * blockDim.x;
- const auto tid = blockIdx.y * blockDim.x + threadIdx.x;
+ const int num_threads = gridDim.y * blockDim.x;
+ const int tid = blockIdx.y * blockDim.x + threadIdx.x;

// Parallelize over each feature in each pixel in images of size H * W,
// for each image in the batch of size batch_size
@@ -96,11 +96,11 @@ __global__ void weightedSumNormCudaBackwardKernel(
const int64_t W = points_idx.size(3);

// Get the batch and index
- const auto batch = blockIdx.x;
+ const int batch = blockIdx.x;

const int num_pixels = C * W * H;
- const auto num_threads = gridDim.y * blockDim.x;
- const auto tid = blockIdx.y * blockDim.x + threadIdx.x;
+ const int num_threads = gridDim.y * blockDim.x;
+ const int tid = blockIdx.y * blockDim.x + threadIdx.x;

// Parallelize over each feature in each pixel in images of size H * W,
// for each image in the batch of size batch_size
12 changes: 6 additions & 6 deletions pytorch3d/csrc/compositing/weighted_sum.cu
@@ -31,11 +31,11 @@ __global__ void weightedSumCudaForwardKernel(
const int64_t W = points_idx.size(3);

// Get the batch and index
- const auto batch = blockIdx.x;
+ const int batch = blockIdx.x;

const int num_pixels = C * H * W;
- const auto num_threads = gridDim.y * blockDim.x;
- const auto tid = blockIdx.y * blockDim.x + threadIdx.x;
+ const int num_threads = gridDim.y * blockDim.x;
+ const int tid = blockIdx.y * blockDim.x + threadIdx.x;

// Parallelize over each feature in each pixel in images of size H * W,
// for each image in the batch of size batch_size
@@ -78,11 +78,11 @@ __global__ void weightedSumCudaBackwardKernel(
const int64_t W = points_idx.size(3);

// Get the batch and index
- const auto batch = blockIdx.x;
+ const int batch = blockIdx.x;

const int num_pixels = C * H * W;
- const auto num_threads = gridDim.y * blockDim.x;
- const auto tid = blockIdx.y * blockDim.x + threadIdx.x;
+ const int num_threads = gridDim.y * blockDim.x;
+ const int tid = blockIdx.y * blockDim.x + threadIdx.x;

// Iterate over each pixel to compute the contribution to the
// gradient for the features and weights
6 changes: 3 additions & 3 deletions pytorch3d/csrc/gather_scatter/gather_scatter.cu
@@ -20,22 +20,22 @@ __global__ void GatherScatterCudaKernel(
const size_t V,
const size_t D,
const size_t E) {
- const auto tid = threadIdx.x;
+ const int tid = threadIdx.x;

// Reverse the vertex order if backward.
const int v0_idx = backward ? 1 : 0;
const int v1_idx = backward ? 0 : 1;

// Edges are split evenly across the blocks.
- for (auto e = blockIdx.x; e < E; e += gridDim.x) {
+ for (int e = blockIdx.x; e < E; e += gridDim.x) {
// Get indices of vertices which form the edge.
const int64_t v0 = edges[2 * e + v0_idx];
const int64_t v1 = edges[2 * e + v1_idx];

// Split vertex features evenly across threads.
// This implementation will be quite wasteful when D<128 since there will be
// a lot of threads doing nothing.
- for (auto d = tid; d < D; d += blockDim.x) {
+ for (int d = tid; d < D; d += blockDim.x) {
const float val = input[v1 * D + d];
float* address = output + v0 * D + d;
atomicAdd(address, val);
8 changes: 4 additions & 4 deletions pytorch3d/csrc/interp_face_attrs/interp_face_attrs.cu
@@ -20,8 +20,8 @@ __global__ void InterpFaceAttrsForwardKernel(
const size_t P,
const size_t F,
const size_t D) {
- const auto tid = threadIdx.x + blockIdx.x * blockDim.x;
- const auto num_threads = blockDim.x * gridDim.x;
+ const int tid = threadIdx.x + blockIdx.x * blockDim.x;
+ const int num_threads = blockDim.x * gridDim.x;
for (int pd = tid; pd < P * D; pd += num_threads) {
const int p = pd / D;
const int d = pd % D;
@@ -93,8 +93,8 @@ __global__ void InterpFaceAttrsBackwardKernel(
const size_t P,
const size_t F,
const size_t D) {
- const auto tid = threadIdx.x + blockIdx.x * blockDim.x;
- const auto num_threads = blockDim.x * gridDim.x;
+ const int tid = threadIdx.x + blockIdx.x * blockDim.x;
+ const int num_threads = blockDim.x * gridDim.x;
for (int pd = tid; pd < P * D; pd += num_threads) {
const int p = pd / D;
const int d = pd % D;
18 changes: 9 additions & 9 deletions pytorch3d/csrc/point_mesh/point_mesh_cuda.cu
@@ -110,7 +110,7 @@ __global__ void DistanceForwardKernel(
__syncthreads();

// Perform reduction in shared memory.
- for (auto s = blockDim.x / 2; s > 32; s >>= 1) {
+ for (int s = blockDim.x / 2; s > 32; s >>= 1) {
if (tid < s) {
if (min_dists[tid] > min_dists[tid + s]) {
min_dists[tid] = min_dists[tid + s];
@@ -502,8 +502,8 @@ __global__ void PointFaceArrayForwardKernel(
const float3* tris_f3 = (float3*)tris;

// Parallelize over P * S computations
- const auto num_threads = gridDim.x * blockDim.x;
- const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
+ const int num_threads = gridDim.x * blockDim.x;
+ const int tid = blockIdx.x * blockDim.x + threadIdx.x;

for (int t_i = tid; t_i < P * T; t_i += num_threads) {
const int t = t_i / P; // segment index.
@@ -576,8 +576,8 @@ __global__ void PointFaceArrayBackwardKernel(
const float3* tris_f3 = (float3*)tris;

// Parallelize over P * S computations
- const auto num_threads = gridDim.x * blockDim.x;
- const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
+ const int num_threads = gridDim.x * blockDim.x;
+ const int tid = blockIdx.x * blockDim.x + threadIdx.x;

for (int t_i = tid; t_i < P * T; t_i += num_threads) {
const int t = t_i / P; // triangle index.
@@ -683,8 +683,8 @@ __global__ void PointEdgeArrayForwardKernel(
float3* segms_f3 = (float3*)segms;

// Parallelize over P * S computations
- const auto num_threads = gridDim.x * blockDim.x;
- const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
+ const int num_threads = gridDim.x * blockDim.x;
+ const int tid = blockIdx.x * blockDim.x + threadIdx.x;

for (int t_i = tid; t_i < P * S; t_i += num_threads) {
const int s = t_i / P; // segment index.
@@ -752,8 +752,8 @@ __global__ void PointEdgeArrayBackwardKernel(
float3* segms_f3 = (float3*)segms;

// Parallelize over P * S computations
- const auto num_threads = gridDim.x * blockDim.x;
- const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
+ const int num_threads = gridDim.x * blockDim.x;
+ const int tid = blockIdx.x * blockDim.x + threadIdx.x;

for (int t_i = tid; t_i < P * S; t_i += num_threads) {
const int s = t_i / P; // segment index.
2 changes: 1 addition & 1 deletion pytorch3d/csrc/rasterize_coarse/bitmask.cuh
@@ -25,7 +25,7 @@ class BitMask {

// Use all threads in the current block to clear all bits of this BitMask
__device__ void block_clear() {
- for (auto i = threadIdx.x; i < H * W * D; i += blockDim.x) {
+ for (int i = threadIdx.x; i < H * W * D; i += blockDim.x) {
data[i] = 0;
}
__syncthreads();
14 changes: 7 additions & 7 deletions pytorch3d/csrc/rasterize_coarse/rasterize_coarse.cu
@@ -23,8 +23,8 @@ __global__ void TriangleBoundingBoxKernel(
const float blur_radius,
float* bboxes, // (4, F)
bool* skip_face) { // (F,)
- const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
- const auto num_threads = blockDim.x * gridDim.x;
+ const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+ const int num_threads = blockDim.x * gridDim.x;
const float sqrt_radius = sqrt(blur_radius);
for (int f = tid; f < F; f += num_threads) {
const float v0x = face_verts[f * 9 + 0 * 3 + 0];
@@ -56,8 +56,8 @@ __global__ void PointBoundingBoxKernel(
const int P,
float* bboxes, // (4, P)
bool* skip_points) {
- const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
- const auto num_threads = blockDim.x * gridDim.x;
+ const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+ const int num_threads = blockDim.x * gridDim.x;
for (int p = tid; p < P; p += num_threads) {
const float x = points[p * 3 + 0];
const float y = points[p * 3 + 1];
@@ -113,7 +113,7 @@ __global__ void RasterizeCoarseCudaKernel(
const int chunks_per_batch = 1 + (E - 1) / chunk_size;
const int num_chunks = N * chunks_per_batch;

- for (auto chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
+ for (int chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
const int batch_idx = chunk / chunks_per_batch; // batch index
const int chunk_idx = chunk % chunks_per_batch;
const int elem_chunk_start_idx = chunk_idx * chunk_size;
@@ -123,7 +123,7 @@
const int64_t elem_stop_idx = elem_start_idx + elems_per_batch[batch_idx];

// Have each thread handle a different face within the chunk
- for (auto e = threadIdx.x; e < chunk_size; e += blockDim.x) {
+ for (int e = threadIdx.x; e < chunk_size; e += blockDim.x) {
const int e_idx = elem_chunk_start_idx + e;

// Check that we are still within the same element of the batch
@@ -170,7 +170,7 @@ __global__ void RasterizeCoarseCudaKernel(
// Now we have processed every elem in the current chunk. We need to
// count the number of elems in each bin so we can write the indices
// out to global memory. We have each thread handle a different bin.
- for (auto byx = threadIdx.x; byx < num_bins_y * num_bins_x;
+ for (int byx = threadIdx.x; byx < num_bins_y * num_bins_x;
byx += blockDim.x) {
const int by = byx / num_bins_x;
const int bx = byx % num_bins_x;
12 changes: 6 additions & 6 deletions pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu
@@ -260,8 +260,8 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
float* pix_dists,
float* bary) {
// Simple version: One thread per output pixel
- auto num_threads = gridDim.x * blockDim.x;
- auto tid = blockDim.x * blockIdx.x + threadIdx.x;
+ int num_threads = gridDim.x * blockDim.x;
+ int tid = blockDim.x * blockIdx.x + threadIdx.x;

for (int i = tid; i < N * H * W; i += num_threads) {
// Convert linear index to 3D index
@@ -446,8 +446,8 @@ __global__ void RasterizeMeshesBackwardCudaKernel(

// Parallelize over each pixel in images of
// size H * W, for each image in the batch of size N.
- const auto num_threads = gridDim.x * blockDim.x;
- const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
+ const int num_threads = gridDim.x * blockDim.x;
+ const int tid = blockIdx.x * blockDim.x + threadIdx.x;

for (int t_i = tid; t_i < N * H * W; t_i += num_threads) {
// Convert linear index to 3D index
@@ -650,8 +650,8 @@ __global__ void RasterizeMeshesFineCudaKernel(
) {
// This can be more than H * W if H or W are not divisible by bin_size.
int num_pixels = N * BH * BW * bin_size * bin_size;
- auto num_threads = gridDim.x * blockDim.x;
- auto tid = blockIdx.x * blockDim.x + threadIdx.x;
+ int num_threads = gridDim.x * blockDim.x;
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;

for (int pid = tid; pid < num_pixels; pid += num_threads) {
// Convert linear index into bin and pixel indices. We make the within
12 changes: 6 additions & 6 deletions pytorch3d/csrc/rasterize_points/rasterize_points.cu
@@ -97,8 +97,8 @@ __global__ void RasterizePointsNaiveCudaKernel(
float* zbuf, // (N, H, W, K)
float* pix_dists) { // (N, H, W, K)
// Simple version: One thread per output pixel
- const auto num_threads = gridDim.x * blockDim.x;
- const auto tid = blockDim.x * blockIdx.x + threadIdx.x;
+ const int num_threads = gridDim.x * blockDim.x;
+ const int tid = blockDim.x * blockIdx.x + threadIdx.x;
for (int i = tid; i < N * H * W; i += num_threads) {
// Convert linear index to 3D index
const int n = i / (H * W); // Batch index
@@ -237,8 +237,8 @@ __global__ void RasterizePointsFineCudaKernel(
float* pix_dists) { // (N, H, W, K)
// This can be more than H * W if H or W are not divisible by bin_size.
const int num_pixels = N * BH * BW * bin_size * bin_size;
- const auto num_threads = gridDim.x * blockDim.x;
- const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
+ const int num_threads = gridDim.x * blockDim.x;
+ const int tid = blockIdx.x * blockDim.x + threadIdx.x;

for (int pid = tid; pid < num_pixels; pid += num_threads) {
// Convert linear index into bin and pixel indices. We make the within
@@ -376,8 +376,8 @@ __global__ void RasterizePointsBackwardCudaKernel(
float* grad_points) { // (P, 3)
// Parallelized over each of K points per pixel, for each pixel in images of
// size H * W, for each image in the batch of size N.
- auto num_threads = gridDim.x * blockDim.x;
- auto tid = blockIdx.x * blockDim.x + threadIdx.x;
+ int num_threads = gridDim.x * blockDim.x;
+ int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = tid; i < N * H * W * K; i += num_threads) {
// const int n = i / (H * W * K); // batch index (not needed).
const int yxk = i % (H * W * K);
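For reference only (not part of the diff): the type difference this revert is about can be checked at compile time. threadIdx and blockIdx are uint3 and blockDim and gridDim are dim3, so their .x components are unsigned int and an auto-typed index deduces to unsigned int; the lines above pin the indices to signed int instead. A small hypothetical check, with an illustrative kernel name:

#include <cuda_runtime.h>
#include <type_traits>

__global__ void IndexTypeCheckKernel() {
  // With auto, the deduced type is unsigned int, because all the built-in
  // index/extent variables expose unsigned int components.
  const auto tid_auto = blockIdx.x * blockDim.x + threadIdx.x;
  static_assert(
      std::is_same<decltype(tid_auto), const unsigned int>::value,
      "auto deduces unsigned int here");

  // The form this PR restores: an explicit conversion to signed int
  // (safe for these values, which fit comfortably in int).
  const int tid_int = blockIdx.x * blockDim.x + threadIdx.x;
  (void)tid_auto;
  (void)tid_int;
}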