Commit 65620e7

megluyagao authored and facebook-github-bot committed
Adding support for changing background color
Summary: Adds support to hard_rgb_blend and the hard blending shaders in shader.py (HardPhongShader, HardGouraudShader, and HardFlatShader) for changing the background color on which objects are rendered.

Reviewed By: nikhilaravi

Differential Revision: D21746062

fbshipit-source-id: 08001200f4339d6a69c52405c6b8f4cac9f3f56e
1 parent e3819a4 commit 65620e7
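
For context, a minimal usage sketch of the new parameter (not part of the commit itself; the color value is arbitrary and the import paths mirror those used in the tests): a BlendParams carrying a custom background_color is passed to one of the hard shaders at construction time.

import torch
from pytorch3d.renderer.blending import BlendParams
from pytorch3d.renderer.mesh.shader import HardPhongShader

# BlendParams is a NamedTuple of (sigma, gamma, background_color);
# the hard blending path only reads background_color.
blend_params = BlendParams(background_color=(0.0, 0.0, 0.0))  # black instead of the default white

# blend_params is the new constructor argument; when omitted the shader
# falls back to BlendParams(), i.e. the previous white background.
shader = HardPhongShader(device="cpu", blend_params=blend_params)

The shader is then used with a MeshRenderer as before: pixels with no overlapping face receive background_color, and the alpha channel stays 1.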

File tree

pytorch3d/renderer/blending.py
pytorch3d/renderer/mesh/shader.py
tests/test_blending.py
tests/test_render_meshes.py

4 files changed: +75 -23 lines changed

pytorch3d/renderer/blending.py

+21 -4

@@ -19,7 +19,7 @@ class BlendParams(NamedTuple):
     background_color: Sequence = (1.0, 1.0, 1.0)
 
 
-def hard_rgb_blend(colors, fragments) -> torch.Tensor:
+def hard_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
     """
     Naive blending of top K faces to return an RGBA image
       - **RGB** - choose color of the closest point i.e. K=0
@@ -32,14 +32,31 @@ def hard_rgb_blend(colors, fragments) -> torch.Tensor:
             of the faces (in the packed representation) which
             overlap each pixel in the image. This is used to
             determine the output shape.
+        blend_params: BlendParams instance that contains a background_color
+            field specifying the color for the background
     Returns:
         RGBA pixel_colors: (N, H, W, 4)
     """
     N, H, W, K = fragments.pix_to_face.shape
     device = fragments.pix_to_face.device
-    pixel_colors = torch.ones((N, H, W, 4), dtype=colors.dtype, device=device)
-    pixel_colors[..., :3] = colors[..., 0, :]
-    return pixel_colors
+
+    # Mask for the background.
+    is_background = fragments.pix_to_face[..., 0] < 0  # (N, H, W)
+
+    background_color = colors.new_tensor(blend_params.background_color)  # (3)
+
+    # Find out how much background_color needs to be expanded to be used for masked_scatter.
+    num_background_pixels = is_background.sum()
+
+    # Set background color.
+    pixel_colors = colors[..., 0, :].masked_scatter(
+        is_background[..., None],
+        background_color[None, :].expand(num_background_pixels, -1),
+    )  # (N, H, W, 3)
+
+    # Concat with the alpha channel.
+    alpha = torch.ones((N, H, W, 1), dtype=colors.dtype, device=device)
+    return torch.cat([pixel_colors, alpha], dim=-1)  # (N, H, W, 4)
 
 
 def sigmoid_alpha_blend(colors, fragments, blend_params) -> torch.Tensor:
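
As an aside, the masked_scatter pattern above can be exercised in isolation. A small sketch with made-up shapes and values (none of the names below come from the commit): foreground pixels keep the color of their closest face, background pixels (pix_to_face of the top face is -1) get the background color.

import torch

# Fake rasterizer outputs: K candidate faces per pixel, -1 meaning "no face".
N, H, W, K = 1, 4, 4, 3
colors = torch.rand((N, H, W, K, 3))
pix_to_face = torch.randint(low=-1, high=10, size=(N, H, W, K))
background_color = colors.new_tensor((0.5, 0.5, 1.0))

is_background = pix_to_face[..., 0] < 0  # (N, H, W)
num_background_pixels = is_background.sum()

# masked_scatter fills the True positions of the broadcast mask from a flat
# source, so the background color is repeated once per background pixel.
pixel_colors = colors[..., 0, :].masked_scatter(
    is_background[..., None],
    background_color[None, :].expand(num_background_pixels, -1),
)  # (N, H, W, 3)

alpha = torch.ones((N, H, W, 1), dtype=colors.dtype)
image = torch.cat([pixel_colors, alpha], dim=-1)  # (N, H, W, 4)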

pytorch3d/renderer/mesh/shader.py

+18 -6

@@ -39,13 +39,16 @@ class HardPhongShader(nn.Module):
         shader = HardPhongShader(device=torch.device("cuda:0"))
     """
 
-    def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
+    def __init__(
+        self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
+    ):
         super().__init__()
         self.lights = lights if lights is not None else PointLights(device=device)
         self.materials = (
             materials if materials is not None else Materials(device=device)
         )
         self.cameras = cameras
+        self.blend_params = blend_params if blend_params is not None else BlendParams()
 
     def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
         cameras = kwargs.get("cameras", self.cameras)
@@ -57,6 +60,7 @@ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
         texels = interpolate_vertex_colors(fragments, meshes)
         lights = kwargs.get("lights", self.lights)
         materials = kwargs.get("materials", self.materials)
+        blend_params = kwargs.get("blend_params", self.blend_params)
         colors = phong_shading(
             meshes=meshes,
             fragments=fragments,
@@ -65,7 +69,7 @@ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
             cameras=cameras,
             materials=materials,
         )
-        images = hard_rgb_blend(colors, fragments)
+        images = hard_rgb_blend(colors, fragments, blend_params)
         return images
 
 
@@ -130,13 +134,16 @@ class HardGouraudShader(nn.Module):
         shader = HardGouraudShader(device=torch.device("cuda:0"))
     """
 
-    def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
+    def __init__(
+        self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
+    ):
         super().__init__()
         self.lights = lights if lights is not None else PointLights(device=device)
         self.materials = (
             materials if materials is not None else Materials(device=device)
         )
         self.cameras = cameras
+        self.blend_params = blend_params if blend_params is not None else BlendParams()
 
     def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
         cameras = kwargs.get("cameras", self.cameras)
@@ -146,14 +153,15 @@ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
             raise ValueError(msg)
         lights = kwargs.get("lights", self.lights)
         materials = kwargs.get("materials", self.materials)
+        blend_params = kwargs.get("blend_params", self.blend_params)
         pixel_colors = gouraud_shading(
             meshes=meshes,
             fragments=fragments,
             lights=lights,
             cameras=cameras,
             materials=materials,
         )
-        images = hard_rgb_blend(pixel_colors, fragments)
+        images = hard_rgb_blend(pixel_colors, fragments, blend_params)
         return images
 
 
@@ -266,13 +274,16 @@ class HardFlatShader(nn.Module):
         shader = HardFlatShader(device=torch.device("cuda:0"))
     """
 
-    def __init__(self, device="cpu", cameras=None, lights=None, materials=None):
+    def __init__(
+        self, device="cpu", cameras=None, lights=None, materials=None, blend_params=None
+    ):
         super().__init__()
         self.lights = lights if lights is not None else PointLights(device=device)
         self.materials = (
             materials if materials is not None else Materials(device=device)
         )
         self.cameras = cameras
+        self.blend_params = blend_params if blend_params is not None else BlendParams()
 
     def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
         cameras = kwargs.get("cameras", self.cameras)
@@ -283,6 +294,7 @@ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
         texels = interpolate_vertex_colors(fragments, meshes)
         lights = kwargs.get("lights", self.lights)
         materials = kwargs.get("materials", self.materials)
+        blend_params = kwargs.get("blend_params", self.blend_params)
         colors = flat_shading(
             meshes=meshes,
             fragments=fragments,
@@ -291,7 +303,7 @@ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
             cameras=cameras,
             materials=materials,
         )
-        images = hard_rgb_blend(colors, fragments)
+        images = hard_rgb_blend(colors, fragments, blend_params)
         return images
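
Because forward() now looks blend_params up in kwargs before falling back to the stored value, the same shaded colors can also be composited onto different backgrounds without re-shading. A hedged sketch at the hard_rgb_blend level, with dummy fragments in the style of the updated test below (the Fragments import path is assumed; shapes and colors are arbitrary):

import torch
from pytorch3d.renderer.blending import BlendParams, hard_rgb_blend
from pytorch3d.renderer.mesh.rasterizer import Fragments  # assumed import path

N, H, W, K = 2, 16, 16, 4
pix_to_face = torch.randint(low=-1, high=100, size=(N, H, W, K))
dummy = torch.ones((N, H, W, K))
fragments = Fragments(
    pix_to_face=pix_to_face,
    bary_coords=torch.ones((N, H, W, K, 3)),
    zbuf=dummy,   # unused by hard_rgb_blend
    dists=dummy,  # unused by hard_rgb_blend
)
colors = torch.randn((N, H, W, K, 3))  # stands in for the shader's per-face colors

# Same colors, two backgrounds: the default white and a custom blue.
white_bg = hard_rgb_blend(colors, fragments, BlendParams())
blue_bg = hard_rgb_blend(colors, fragments, BlendParams(background_color=(0.0, 0.0, 1.0)))

At the renderer level the equivalent per-call override is renderer(mesh, blend_params=...), since MeshRenderer forwards keyword arguments to the shader's forward (the tests already use this mechanism for lights).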

tests/test_blending.py

+18 -10

@@ -4,6 +4,7 @@
 
 import numpy as np
 import torch
+from common_testing import TestCaseMixin
 from pytorch3d.renderer.blending import (
     BlendParams,
     hard_rgb_blend,
@@ -128,7 +129,7 @@ def softmax_blend_naive(colors, fragments, blend_params):
     return pixel_colors
 
 
-class TestBlending(unittest.TestCase):
+class TestBlending(TestCaseMixin, unittest.TestCase):
     def setUp(self) -> None:
         torch.manual_seed(42)
 
@@ -156,22 +157,29 @@ def _compare_impls(
 
     def test_hard_rgb_blend(self):
         N, H, W, K = 5, 10, 10, 20
-        pix_to_face = torch.ones((N, H, W, K))
+        pix_to_face = torch.randint(low=-1, high=100, size=(N, H, W, K))
         bary_coords = torch.ones((N, H, W, K, 3))
         fragments = Fragments(
             pix_to_face=pix_to_face,
             bary_coords=bary_coords,
             zbuf=pix_to_face,  # dummy
             dists=pix_to_face,  # dummy
         )
-        colors = bary_coords.clone()
-        top_k = torch.randn((K, 3))
-        colors[..., :, :] = top_k
-        images = hard_rgb_blend(colors, fragments)
-        expected_vals = torch.ones((N, H, W, 4))
-        pix_cols = torch.ones_like(expected_vals[..., :3]) * top_k[0, :]
-        expected_vals[..., :3] = pix_cols
-        self.assertTrue(torch.allclose(images, expected_vals))
+        colors = torch.randn((N, H, W, K, 3))
+        blend_params = BlendParams(1e-4, 1e-4, (0.5, 0.5, 1))
+        images = hard_rgb_blend(colors, fragments, blend_params)
+
+        # Examine if the foreground colors are correct.
+        is_foreground = pix_to_face[..., 0] >= 0
+        self.assertClose(images[is_foreground][:, :3], colors[is_foreground][..., 0, :])
+
+        # Examine if the background colors are correct.
+        for i in range(3):  # i.e. RGB
+            channel_color = blend_params.background_color[i]
+            self.assertTrue(images[~is_foreground][..., i].eq(channel_color).all())
+
+        # Examine the alpha channel is correct
+        self.assertTrue(images[..., 3].eq(1).all())
 
     def test_sigmoid_alpha_blend_manual_gradients(self):
         # Create dummy outputs of rasterization

tests/test_render_meshes.py

+18 -3

@@ -77,6 +77,7 @@ def test_simple_sphere(self, elevated_camera=False):
             image_size=512, blur_radius=0.0, faces_per_pixel=1
         )
         rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
+        blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0))
 
         # Test several shaders
         shaders = {
@@ -85,7 +86,12 @@ def test_simple_sphere(self, elevated_camera=False):
             "flat": HardFlatShader,
         }
         for (name, shader_init) in shaders.items():
-            shader = shader_init(lights=lights, cameras=cameras, materials=materials)
+            shader = shader_init(
+                lights=lights,
+                cameras=cameras,
+                materials=materials,
+                blend_params=blend_params,
+            )
             renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
             images = renderer(sphere_mesh)
             filename = "simple_sphere_light_%s%s.png" % (name, postfix)
@@ -105,7 +111,10 @@ def test_simple_sphere(self, elevated_camera=False):
         ########################################################
         lights.location[..., 2] = -2.0
         phong_shader = HardPhongShader(
-            lights=lights, cameras=cameras, materials=materials
+            lights=lights,
+            cameras=cameras,
+            materials=materials,
+            blend_params=blend_params,
         )
         phong_renderer = MeshRenderer(rasterizer=rasterizer, shader=phong_shader)
         images = phong_renderer(sphere_mesh, lights=lights)
@@ -162,6 +171,7 @@ def test_simple_sphere_batched(self):
         materials = Materials(device=device)
         lights = PointLights(device=device)
         lights.location = torch.tensor([0.0, 0.0, +2.0], device=device)[None]
+        blend_params = BlendParams(1e-4, 1e-4, (0, 0, 0))
 
         # Init renderer
         rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings)
@@ -171,7 +181,12 @@ def test_simple_sphere_batched(self):
             "flat": HardFlatShader,
         }
         for (name, shader_init) in shaders.items():
-            shader = shader_init(lights=lights, cameras=cameras, materials=materials)
+            shader = shader_init(
+                lights=lights,
+                cameras=cameras,
+                materials=materials,
+                blend_params=blend_params,
+            )
             renderer = MeshRenderer(rasterizer=rasterizer, shader=shader)
             images = renderer(sphere_meshes)
             image_ref = load_rgb_image(

0 commit comments
