Skip to content

Commit d07307a

Browse files
nikhilaravifacebook-github-bot
authored andcommitted
Non square image rasterization for meshes
Summary: There are a couple of options for supporting non square images: 1) NDC stays at [-1, 1] in both directions with the distance calculations all modified by (W/H). There are a lot of distance based calculations (e.g. triangle areas for barycentric coordinates etc) so this requires changes in many places. 2) NDC is scaled by (W/H) so the smallest side has [-1, 1]. In this case none of the distance calculations need to be updated and only the pixel to NDC calculation needs to be modified. I decided to go with option 2 after trying option 1! API Changes: - Image size can now be specified optionally as a tuple TODO: - add a benchmark test for the non square case. Reviewed By: jcjohnson Differential Revision: D24404975 fbshipit-source-id: 545efb67c822d748ec35999b35762bce58db2cf4
1 parent 0216e46 commit d07307a

13 files changed

+774
-115
lines changed

docs/notes/renderer_getting_started.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,19 @@ While we tried to emulate several aspects of OpenGL, there are differences in th
5555

5656
---
5757

58+
### Rasterizing Non Square Images
59+
60+
To rasterize an image where H != W, you can specify the `image_size` in the `RasterizationSettings` as a tuple of (H, W).
61+
62+
The aspect ratio needs special consideration. There are two aspect ratios to be aware of:
63+
- the aspect ratio of each pixel
64+
- the aspect ratio of the output image
65+
In the cameras e.g. `FoVPerspectiveCameras`, the `aspect_ratio` argument can be used to set the pixel aspect ratio. In the rasterizer, we assume square pixels, but variable image aspect ratio (i.e rectangle images).
66+
67+
In most cases you will want to set the camera aspect ratio to 1.0 (i.e. square pixels) and only vary the `image_size` in the `RasterizationSettings`(i.e. the output image dimensions in pixels).
68+
69+
---
70+
5871
### The pulsar backend
5972

6073
Since v0.3, [pulsar](https://arxiv.org/abs/2004.07484) can be used as a backend for point-rendering. It has a focus on efficiency, which comes with pros and cons: it is highly optimized and all rendering stages are integrated in the CUDA kernels. This leads to significantly higher speed and better scaling behavior. We use it at Facebook Reality Labs to render and optimize scenes with millions of spheres in resolutions up to 4K. You can find a runtime comparison plot below (settings: `bin_size=None`, `points_per_pixel=5`, `image_size=1024`, `radius=1e-2`, `composite_params.radius=1e-4`; benchmarked on an RTX 2070 GPU).
@@ -75,6 +88,8 @@ For mesh texturing we offer several options (in `pytorch3d/renderer/mesh/texturi
7588

7689
<img src="assets/texturing.jpg" width="1000">
7790

91+
---
92+
7893
### A simple renderer
7994

8095
A renderer in PyTorch3D is composed of a **rasterizer** and a **shader**. Create a renderer in a few simple steps:
@@ -108,6 +123,8 @@ renderer = MeshRenderer(
108123
)
109124
```
110125

126+
---
127+
111128
### A custom shader
112129

113130
Shaders are the most flexible part of the PyTorch3D rendering API. We have created some examples of shaders in `shaders.py` but this is a non exhaustive set.

pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu

Lines changed: 83 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -234,8 +234,8 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
234234
const int xi = W - 1 - pix_idx % W;
235235

236236
// screen coordinates to ndc coordiantes of pixel.
237-
const float xf = PixToNdc(xi, W);
238-
const float yf = PixToNdc(yi, H);
237+
const float xf = PixToNonSquareNdc(xi, W, H);
238+
const float yf = PixToNonSquareNdc(yi, H, W);
239239
const float2 pxy = make_float2(xf, yf);
240240

241241
// For keeping track of the K closest points we want a data structure
@@ -262,6 +262,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
262262
for (int f = face_start_idx; f < face_stop_idx; ++f) {
263263
// Check if the pixel pxy is inside the face bounding box and if it is,
264264
// update q, q_size, q_max_z and q_max_idx in place.
265+
265266
CheckPixelInsideFace(
266267
face_verts,
267268
f,
@@ -280,6 +281,7 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
280281
// TODO: make sorting an option as only top k is needed, not sorted values.
281282
BubbleSort(q, q_size);
282283
int idx = n * H * W * K + pix_idx * K;
284+
283285
for (int k = 0; k < q_size; ++k) {
284286
face_idxs[idx + k] = q[k].idx;
285287
zbuf[idx + k] = q[k].z;
@@ -296,7 +298,7 @@ RasterizeMeshesNaiveCuda(
296298
const at::Tensor& face_verts,
297299
const at::Tensor& mesh_to_faces_packed_first_idx,
298300
const at::Tensor& num_faces_per_mesh,
299-
const int image_size,
301+
const std::tuple<int, int> image_size,
300302
const float blur_radius,
301303
const int num_closest,
302304
const bool perspective_correct,
@@ -332,8 +334,8 @@ RasterizeMeshesNaiveCuda(
332334
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
333335

334336
const int N = num_faces_per_mesh.size(0); // batch size.
335-
const int H = image_size; // Assume square images.
336-
const int W = image_size;
337+
const int H = std::get<0>(image_size);
338+
const int W = std::get<1>(image_size);
337339
const int K = num_closest;
338340

339341
auto long_opts = num_faces_per_mesh.options().dtype(at::kLong);
@@ -405,8 +407,8 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
405407
const int yi = H - 1 - pix_idx / W;
406408
const int xi = W - 1 - pix_idx % W;
407409

408-
const float xf = PixToNdc(xi, W);
409-
const float yf = PixToNdc(yi, H);
410+
const float xf = PixToNonSquareNdc(xi, W, H);
411+
const float yf = PixToNonSquareNdc(yi, H, W);
410412
const float2 pxy = make_float2(xf, yf);
411413

412414
// Loop over all the faces for this pixel.
@@ -589,12 +591,25 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
589591
int* bin_faces) {
590592
extern __shared__ char sbuf[];
591593
const int M = max_faces_per_bin;
592-
const int num_bins = 1 + (W - 1) / bin_size; // Integer divide round up
593-
const float half_pix = 1.0f / W; // Size of half a pixel in NDC units
594+
// Integer divide round up
595+
const int num_bins_x = 1 + (W - 1) / bin_size;
596+
const int num_bins_y = 1 + (H - 1) / bin_size;
597+
598+
// NDC range depends on the ratio of W/H
599+
// The shorter side from (H, W) is given an NDC range of 2.0 and
600+
// the other side is scaled by the ratio of H:W.
601+
const float NDC_x_half_range = NonSquareNdcRange(W, H) / 2.0f;
602+
const float NDC_y_half_range = NonSquareNdcRange(H, W) / 2.0f;
603+
604+
// Size of half a pixel in NDC units is the NDC half range
605+
// divided by the corresponding image dimension
606+
const float half_pix_x = NDC_x_half_range / W;
607+
const float half_pix_y = NDC_y_half_range / H;
608+
594609
// This is a boolean array of shape (num_bins, num_bins, chunk_size)
595610
// stored in shared memory that will track whether each point in the chunk
596611
// falls into each bin of the image.
597-
BitMask binmask((unsigned int*)sbuf, num_bins, num_bins, chunk_size);
612+
BitMask binmask((unsigned int*)sbuf, num_bins_y, num_bins_x, chunk_size);
598613

599614
// Have each block handle a chunk of faces
600615
const int chunks_per_batch = 1 + (F - 1) / chunk_size;
@@ -641,21 +656,24 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
641656
}
642657

643658
// Brute-force search over all bins; TODO(T54294966) something smarter.
644-
for (int by = 0; by < num_bins; ++by) {
659+
for (int by = 0; by < num_bins_y; ++by) {
645660
// Y coordinate of the top and bottom of the bin.
646661
// PixToNdc gives the location of the center of each pixel, so we
647662
// need to add/subtract a half pixel to get the true extent of the bin.
648663
// Reverse ordering of Y axis so that +Y is upwards in the image.
649-
const float bin_y_min = PixToNdc(by * bin_size, H) - half_pix;
650-
const float bin_y_max = PixToNdc((by + 1) * bin_size - 1, H) + half_pix;
664+
const float bin_y_min =
665+
PixToNonSquareNdc(by * bin_size, H, W) - half_pix_y;
666+
const float bin_y_max =
667+
PixToNonSquareNdc((by + 1) * bin_size - 1, H, W) + half_pix_y;
651668
const bool y_overlap = (ymin <= bin_y_max) && (bin_y_min < ymax);
652669

653-
for (int bx = 0; bx < num_bins; ++bx) {
670+
for (int bx = 0; bx < num_bins_x; ++bx) {
654671
// X coordinate of the left and right of the bin.
655672
// Reverse ordering of x axis so that +X is left.
656673
const float bin_x_max =
657-
PixToNdc((bx + 1) * bin_size - 1, W) + half_pix;
658-
const float bin_x_min = PixToNdc(bx * bin_size, W) - half_pix;
674+
PixToNonSquareNdc((bx + 1) * bin_size - 1, W, H) + half_pix_x;
675+
const float bin_x_min =
676+
PixToNonSquareNdc(bx * bin_size, W, H) - half_pix_x;
659677

660678
const bool x_overlap = (xmin <= bin_x_max) && (bin_x_min < xmax);
661679
if (y_overlap && x_overlap) {
@@ -668,12 +686,13 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
668686
// Now we have processed every face in the current chunk. We need to
669687
// count the number of faces in each bin so we can write the indices
670688
// out to global memory. We have each thread handle a different bin.
671-
for (int byx = threadIdx.x; byx < num_bins * num_bins; byx += blockDim.x) {
672-
const int by = byx / num_bins;
673-
const int bx = byx % num_bins;
689+
for (int byx = threadIdx.x; byx < num_bins_y * num_bins_x;
690+
byx += blockDim.x) {
691+
const int by = byx / num_bins_x;
692+
const int bx = byx % num_bins_x;
674693
const int count = binmask.count(by, bx);
675694
const int faces_per_bin_idx =
676-
batch_idx * num_bins * num_bins + by * num_bins + bx;
695+
batch_idx * num_bins_y * num_bins_x + by * num_bins_x + bx;
677696

678697
// This atomically increments the (global) number of faces found
679698
// in the current bin, and gets the previous value of the counter;
@@ -683,8 +702,8 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
683702

684703
// Now loop over the binmask and write the active bits for this bin
685704
// out to bin_faces.
686-
int next_idx = batch_idx * num_bins * num_bins * M + by * num_bins * M +
687-
bx * M + start;
705+
int next_idx = batch_idx * num_bins_y * num_bins_x * M +
706+
by * num_bins_x * M + bx * M + start;
688707
for (int f = 0; f < chunk_size; ++f) {
689708
if (binmask.get(by, bx, f)) {
690709
// TODO(T54296346) find the correct method for handling errors in
@@ -703,7 +722,7 @@ at::Tensor RasterizeMeshesCoarseCuda(
703722
const at::Tensor& face_verts,
704723
const at::Tensor& mesh_to_face_first_idx,
705724
const at::Tensor& num_faces_per_mesh,
706-
const int image_size,
725+
const std::tuple<int, int> image_size,
707726
const float blur_radius,
708727
const int bin_size,
709728
const int max_faces_per_bin) {
@@ -725,29 +744,35 @@ at::Tensor RasterizeMeshesCoarseCuda(
725744
at::cuda::CUDAGuard device_guard(face_verts.device());
726745
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
727746

728-
const int W = image_size;
729-
const int H = image_size;
747+
const int H = std::get<0>(image_size);
748+
const int W = std::get<1>(image_size);
749+
730750
const int F = face_verts.size(0);
731751
const int N = num_faces_per_mesh.size(0);
732-
const int num_bins = 1 + (image_size - 1) / bin_size; // Divide round up.
733752
const int M = max_faces_per_bin;
734753

735-
if (num_bins >= kMaxFacesPerBin) {
754+
// Integer divide round up.
755+
const int num_bins_y = 1 + (H - 1) / bin_size;
756+
const int num_bins_x = 1 + (W - 1) / bin_size;
757+
758+
if (num_bins_y >= kMaxFacesPerBin || num_bins_x >= kMaxFacesPerBin) {
736759
std::stringstream ss;
737-
ss << "Got " << num_bins << "; that's too many!";
760+
ss << "In Coarse Rasterizer got num_bins_y: " << num_bins_y
761+
<< ", num_bins_x: " << num_bins_x << ", "
762+
<< "; that's too many!";
738763
AT_ERROR(ss.str());
739764
}
740765
auto opts = num_faces_per_mesh.options().dtype(at::kInt);
741-
at::Tensor faces_per_bin = at::zeros({N, num_bins, num_bins}, opts);
742-
at::Tensor bin_faces = at::full({N, num_bins, num_bins, M}, -1, opts);
766+
at::Tensor faces_per_bin = at::zeros({N, num_bins_y, num_bins_x}, opts);
767+
at::Tensor bin_faces = at::full({N, num_bins_y, num_bins_x, M}, -1, opts);
743768

744769
if (bin_faces.numel() == 0) {
745770
AT_CUDA_CHECK(cudaGetLastError());
746771
return bin_faces;
747772
}
748773

749774
const int chunk_size = 512;
750-
const size_t shared_size = num_bins * num_bins * chunk_size / 8;
775+
const size_t shared_size = num_bins_y * num_bins_x * chunk_size / 8;
751776
const size_t blocks = 64;
752777
const size_t threads = 512;
753778

@@ -782,7 +807,8 @@ __global__ void RasterizeMeshesFineCudaKernel(
782807
const bool clip_barycentric_coords,
783808
const bool cull_backfaces,
784809
const int N,
785-
const int B,
810+
const int BH,
811+
const int BW,
786812
const int M,
787813
const int H,
788814
const int W,
@@ -793,7 +819,7 @@ __global__ void RasterizeMeshesFineCudaKernel(
793819
float* bary // (N, S, S, K, 3)
794820
) {
795821
// This can be more than S^2 if S % bin_size != 0
796-
int num_pixels = N * B * B * bin_size * bin_size;
822+
int num_pixels = N * BH * BW * bin_size * bin_size;
797823
int num_threads = gridDim.x * blockDim.x;
798824
int tid = blockIdx.x * blockDim.x + threadIdx.x;
799825

@@ -803,20 +829,26 @@ __global__ void RasterizeMeshesFineCudaKernel(
803829
// into the same bin; this should give them coalesced memory reads when
804830
// they read from faces and bin_faces.
805831
int i = pid;
806-
const int n = i / (B * B * bin_size * bin_size);
807-
i %= B * B * bin_size * bin_size;
808-
const int by = i / (B * bin_size * bin_size);
809-
i %= B * bin_size * bin_size;
832+
const int n = i / (BH * BW * bin_size * bin_size);
833+
i %= BH * BW * bin_size * bin_size;
834+
// bin index y
835+
const int by = i / (BW * bin_size * bin_size);
836+
i %= BW * bin_size * bin_size;
837+
// bin index y
810838
const int bx = i / (bin_size * bin_size);
839+
// pixel within the bin
811840
i %= bin_size * bin_size;
841+
842+
// Pixel x, y indices
812843
const int yi = i / bin_size + by * bin_size;
813844
const int xi = i % bin_size + bx * bin_size;
814845

815846
if (yi >= H || xi >= W)
816847
continue;
817848

818-
const float xf = PixToNdc(xi, W);
819-
const float yf = PixToNdc(yi, H);
849+
const float xf = PixToNonSquareNdc(xi, W, H);
850+
const float yf = PixToNonSquareNdc(yi, H, W);
851+
820852
const float2 pxy = make_float2(xf, yf);
821853

822854
// This part looks like the naive rasterization kernel, except we use
@@ -828,7 +860,7 @@ __global__ void RasterizeMeshesFineCudaKernel(
828860
float q_max_z = -1000;
829861
int q_max_idx = -1;
830862
for (int m = 0; m < M; m++) {
831-
const int f = bin_faces[n * B * B * M + by * B * M + bx * M + m];
863+
const int f = bin_faces[n * BH * BW * M + by * BW * M + bx * M + m];
832864
if (f < 0) {
833865
continue; // bin_faces uses -1 as a sentinal value.
834866
}
@@ -858,7 +890,8 @@ __global__ void RasterizeMeshesFineCudaKernel(
858890
// in the image +Y is pointing up and +X is pointing left.
859891
const int yidx = H - 1 - yi;
860892
const int xidx = W - 1 - xi;
861-
const int pix_idx = n * H * W * K + yidx * H * K + xidx * K;
893+
894+
const int pix_idx = n * H * W * K + yidx * W * K + xidx * K;
862895
for (int k = 0; k < q_size; k++) {
863896
face_idxs[pix_idx + k] = q[k].idx;
864897
zbuf[pix_idx + k] = q[k].z;
@@ -874,7 +907,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
874907
RasterizeMeshesFineCuda(
875908
const at::Tensor& face_verts,
876909
const at::Tensor& bin_faces,
877-
const int image_size,
910+
const std::tuple<int, int> image_size,
878911
const float blur_radius,
879912
const int bin_size,
880913
const int faces_per_pixel,
@@ -897,12 +930,15 @@ RasterizeMeshesFineCuda(
897930
at::cuda::CUDAGuard device_guard(face_verts.device());
898931
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
899932

933+
// bin_faces shape (N, BH, BW, M)
900934
const int N = bin_faces.size(0);
901-
const int B = bin_faces.size(1);
935+
const int BH = bin_faces.size(1);
936+
const int BW = bin_faces.size(2);
902937
const int M = bin_faces.size(3);
903938
const int K = faces_per_pixel;
904-
const int H = image_size; // Assume square images only.
905-
const int W = image_size;
939+
940+
const int H = std::get<0>(image_size);
941+
const int W = std::get<1>(image_size);
906942

907943
if (K > kMaxPointsPerPixel) {
908944
AT_ERROR("Must have num_closest <= 150");
@@ -932,7 +968,8 @@ RasterizeMeshesFineCuda(
932968
clip_barycentric_coords,
933969
cull_backfaces,
934970
N,
935-
B,
971+
BH,
972+
BW,
936973
M,
937974
H,
938975
W,

0 commit comments

Comments
 (0)