Skip to content

Commit

Permalink
Merge pull request #11 from isl-org/yuanxion/fix-input-numel
Browse files Browse the repository at this point in the history
I'm sorry I did not reply to you for so long, but I have merged your PR!
  • Loading branch information
tatsy authored Jul 16, 2024
2 parents 72b9129 + bc36b08 commit 05a3fab
Showing 1 changed file with 20 additions and 20 deletions.
40 changes: 20 additions & 20 deletions cxx/mcubes_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -331,13 +331,13 @@ __device__ float3 vertexInterp(float isolevel, float3 p1, float3 p2, float valp1
}

__global__ void mcubes_cuda_kernel(
const torch::PackedTensorAccessor32<float, 3, torch::RestrictPtrTraits> vol,
torch::PackedTensorAccessor32<float, 5, torch::RestrictPtrTraits> vertices,
torch::PackedTensorAccessor32<int, 3, torch::RestrictPtrTraits> ntris_in_cells,
const torch::PackedTensorAccessor64<float, 3, torch::RestrictPtrTraits> vol,
torch::PackedTensorAccessor64<float, 5, torch::RestrictPtrTraits> vertices,
torch::PackedTensorAccessor64<int, 3, torch::RestrictPtrTraits> ntris_in_cells,
int3 nGrids,
float threshold,
const torch::PackedTensorAccessor32<int, 1, torch::RestrictPtrTraits> edgeTable,
const torch::PackedTensorAccessor32<int, 2, torch::RestrictPtrTraits> triTable) {
const torch::PackedTensorAccessor64<int, 1, torch::RestrictPtrTraits> edgeTable,
const torch::PackedTensorAccessor64<int, 2, torch::RestrictPtrTraits> triTable) {

const int ix = blockIdx.x * blockDim.x + threadIdx.x;
const int iy = blockIdx.y * blockDim.y + threadIdx.y;
Expand Down Expand Up @@ -436,12 +436,12 @@ __global__ void mcubes_cuda_kernel(
}

__global__ void compaction(
const torch::PackedTensorAccessor32<float, 5, torch::RestrictPtrTraits> vertBuf,
const torch::PackedTensorAccessor32<int, 3, torch::RestrictPtrTraits> ntris,
const torch::PackedTensorAccessor32<int, 3, torch::RestrictPtrTraits> offsets,
const torch::PackedTensorAccessor64<float, 5, torch::RestrictPtrTraits> vertBuf,
const torch::PackedTensorAccessor64<int, 3, torch::RestrictPtrTraits> ntris,
const torch::PackedTensorAccessor64<int, 3, torch::RestrictPtrTraits> offsets,
int3 nGrids,
torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> verts,
torch::PackedTensorAccessor32<int, 2, torch::RestrictPtrTraits> faces) {
torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> verts,
torch::PackedTensorAccessor64<int, 2, torch::RestrictPtrTraits> faces) {

const int ix = blockIdx.x * blockDim.x + threadIdx.x;
const int iy = blockIdx.y * blockDim.y + threadIdx.y;
Expand Down Expand Up @@ -521,13 +521,13 @@ std::vector<torch::Tensor> mcubes_cuda(torch::Tensor vol, float threshold) {
// Kernel call
cudaSetDevice(deviceId);
mcubes_cuda_kernel<<<blocks, threads, 0, stream>>>(
vol.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
vert_buffer.packed_accessor32<float, 5, torch::RestrictPtrTraits>(),
ntris_in_cells.packed_accessor32<int, 3, torch::RestrictPtrTraits>(),
vol.packed_accessor64<float, 3, torch::RestrictPtrTraits>(),
vert_buffer.packed_accessor64<float, 5, torch::RestrictPtrTraits>(),
ntris_in_cells.packed_accessor64<int, 3, torch::RestrictPtrTraits>(),
nGrids,
threshold,
edgeTableTensorCuda.packed_accessor32<int, 1, torch::RestrictPtrTraits>(),
triTableTensorCuda.packed_accessor32<int, 2, torch::RestrictPtrTraits>()
edgeTableTensorCuda.packed_accessor64<int, 1, torch::RestrictPtrTraits>(),
triTableTensorCuda.packed_accessor64<int, 2, torch::RestrictPtrTraits>()
);
cudaDeviceSynchronize();

Expand All @@ -549,12 +549,12 @@ std::vector<torch::Tensor> mcubes_cuda(torch::Tensor vol, float threshold) {

cudaSetDevice(deviceId);
compaction<<<blocks, threads, 0, stream>>>(
vert_buffer.packed_accessor32<float, 5, torch::RestrictPtrTraits>(),
ntris_in_cells.packed_accessor32<int, 3, torch::RestrictPtrTraits>(),
offsets.packed_accessor32<int, 3, torch::RestrictPtrTraits>(),
vert_buffer.packed_accessor64<float, 5, torch::RestrictPtrTraits>(),
ntris_in_cells.packed_accessor64<int, 3, torch::RestrictPtrTraits>(),
offsets.packed_accessor64<int, 3, torch::RestrictPtrTraits>(),
nGrids,
verts.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),
faces.packed_accessor32<int, 2, torch::RestrictPtrTraits>()
verts.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
faces.packed_accessor64<int, 2, torch::RestrictPtrTraits>()
);
cudaDeviceSynchronize();

Expand Down

0 comments on commit 05a3fab

Please sign in to comment.