Skip to content

Commit 9aaba04

Browse files
Steve Bransonfacebook-github-bot
Steve Branson
authored andcommitted
Temporary fix for mesh rasterization bug for traingles partially behind the camera
Summary: A triangle is culled if any vertex in a triangle is behind the camera. This fixes incorrect rendering of triangles that are partially behind the camera, where screen coordinate calculations are strange. It doesn't work for triangles that are partially behind the camera but still intersect with the view frustum. Reviewed By: nikhilaravi Differential Revision: D22856181 fbshipit-source-id: a9cbaa1327d89601b83d0dfd3e4a04f934a4a213
1 parent 57a22e7 commit 9aaba04

File tree

4 files changed

+39
-8
lines changed

4 files changed

+39
-8
lines changed

pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu

+13-4
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,14 @@ __device__ bool CheckPointOutsideBoundingBox(
9090
const float x_max = xlims.y + blur_radius;
9191
const float y_max = ylims.y + blur_radius;
9292

93+
// Faces with at least one vertex behind the camera won't render correctly
94+
// and should be removed or clipped before calling the rasterizer
95+
const bool z_invalid = zlims.x < kEpsilon;
96+
9397
// Check if the current point is oustside the triangle bounding box.
94-
return (pxy.x > x_max || pxy.x < x_min || pxy.y > y_max || pxy.y < y_min);
98+
return (
99+
pxy.x > x_max || pxy.x < x_min || pxy.y > y_max || pxy.y < y_min ||
100+
z_invalid);
95101
}
96102

97103
// This function checks if a pixel given by xy location pxy lies within the
@@ -625,10 +631,13 @@ __global__ void RasterizeMeshesCoarseCudaKernel(
625631
float ymin = FloatMin3(v0.y, v1.y, v2.y) - sqrt(blur_radius);
626632
float xmax = FloatMax3(v0.x, v1.x, v2.x) + sqrt(blur_radius);
627633
float ymax = FloatMax3(v0.y, v1.y, v2.y) + sqrt(blur_radius);
628-
float zmax = FloatMax3(v0.z, v1.z, v2.z);
634+
float zmin = FloatMin3(v0.z, v1.z, v2.z);
629635

630-
if (zmax < 0) {
631-
continue; // Face is behind the camera.
636+
// Faces with at least one vertex behind the camera won't render
637+
// correctly and should be removed or clipped before calling the
638+
// rasterizer
639+
if (zmin < kEpsilon) {
640+
continue;
632641
}
633642

634643
// Brute-force search over all bins; TODO(T54294966) something smarter.

pytorch3d/csrc/rasterize_meshes/rasterize_meshes_cpu.cpp

+11-4
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,12 @@ bool CheckPointOutsideBoundingBox(
6868
float x_max = face_bbox[2] + blur_radius;
6969
float y_max = face_bbox[3] + blur_radius;
7070

71+
// Faces with at least one vertex behind the camera won't render correctly
72+
// and should be removed or clipped before calling the rasterizer
73+
const bool z_invalid = face_bbox[4] < kEpsilon;
74+
7175
// Check if the current point is within the triangle bounding box.
72-
return (px > x_max || px < x_min || py > y_max || py < y_min);
76+
return (px > x_max || px < x_min || py > y_max || py < y_min || z_invalid);
7377
}
7478

7579
// Calculate areas of all faces. Returns a tensor of shape (total_faces, 1)
@@ -468,10 +472,13 @@ torch::Tensor RasterizeMeshesCoarseCpu(
468472
float face_y_min = face_bboxes_a[f][1] - std::sqrt(blur_radius);
469473
float face_x_max = face_bboxes_a[f][2] + std::sqrt(blur_radius);
470474
float face_y_max = face_bboxes_a[f][3] + std::sqrt(blur_radius);
471-
float face_z_max = face_bboxes_a[f][5];
475+
float face_z_min = face_bboxes_a[f][4];
472476

473-
if (face_z_max < 0) {
474-
continue; // Face is behind the camera.
477+
// Faces with at least one vertex behind the camera won't render
478+
// correctly and should be removed or clipped before calling the
479+
// rasterizer
480+
if (face_z_min < kEpsilon) {
481+
continue;
475482
}
476483

477484
// Use a half-open interval so that faces exactly on the

pytorch3d/renderer/mesh/rasterize_meshes.py

+7
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,7 @@ def rasterize_meshes_python(
301301
x_maxs = torch.max(faces_verts[:, :, 0], dim=1, keepdim=True).values
302302
y_mins = torch.min(faces_verts[:, :, 1], dim=1, keepdim=True).values
303303
y_maxs = torch.max(faces_verts[:, :, 1], dim=1, keepdim=True).values
304+
z_mins = torch.min(faces_verts[:, :, 2], dim=1, keepdim=True).values
304305

305306
# Expand by blur radius.
306307
x_mins = x_mins - np.sqrt(blur_radius) - kEpsilon
@@ -351,6 +352,12 @@ def rasterize_meshes_python(
351352
or yf > y_maxs[f]
352353
)
353354

355+
# Faces with at least one vertex behind the camera won't
356+
# render correctly and should be removed or clipped before
357+
# calling the rasterizer
358+
if z_mins[f] < kEpsilon:
359+
continue
360+
354361
# Check if pixel is outside of face bbox.
355362
if outside_bbox:
356363
continue

tests/test_rasterize_meshes.py

+8
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,10 @@ def _compare_impls(
552552
+ (zbuf1 * grad_zbuf).sum()
553553
+ (bary1 * grad_bary).sum()
554554
)
555+
556+
# avoid gradient error if rasterize_meshes_python() culls all triangles
557+
loss1 += grad_var1.sum() * 0.0
558+
555559
loss1.backward()
556560
grad_verts1 = grad_var1.grad.data.clone().cpu()
557561

@@ -563,6 +567,10 @@ def _compare_impls(
563567
+ (zbuf2 * grad_zbuf).sum()
564568
+ (bary2 * grad_bary).sum()
565569
)
570+
571+
# avoid gradient error if rasterize_meshes_python() culls all triangles
572+
loss2 += grad_var2.sum() * 0.0
573+
566574
grad_var1.grad.data.zero_()
567575
loss2.backward()
568576
grad_verts2 = grad_var2.grad.data.clone().cpu()

0 commit comments

Comments
 (0)