Skip to content

Commit 872ff8c

Browse files
Amitav Baruahfacebook-github-bot
Amitav Baruah
authored andcommitted
Add background color support to compositors
Summary: Support rendering different color backgrounds for pointclouds for both compositors Reviewed By: nikhilaravi Differential Revision: D23611043 fbshipit-source-id: ab029650d51349340372c5bd66700e6577d48851
1 parent dc40adf commit 872ff8c

File tree

2 files changed

+125
-2
lines changed

2 files changed

+125
-2
lines changed

pytorch3d/renderer/points/compositor.py

+73-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
22

3+
import warnings
4+
from typing import List, Optional, Tuple, Union
5+
36
import torch
47
import torch.nn as nn
58

@@ -16,11 +19,20 @@ class AlphaCompositor(nn.Module):
1619
Accumulate points using alpha compositing.
1720
"""
1821

19-
def __init__(self):
22+
def __init__(
23+
self, background_color: Optional[Union[Tuple, List, torch.Tensor]] = None
24+
):
2025
super().__init__()
26+
self.background_color = background_color
2127

2228
def forward(self, fragments, alphas, ptclds, **kwargs) -> torch.Tensor:
29+
background_color = kwargs.get("background_color", self.background_color)
2330
images = alpha_composite(fragments, alphas, ptclds)
31+
32+
# images are of shape (N, C, H, W)
33+
# check for background color & feature size C (C=4 indicates rgba)
34+
if background_color is not None and images.shape[1] == 4:
35+
return _add_background_color_to_images(fragments, images, background_color)
2436
return images
2537

2638

@@ -29,9 +41,68 @@ class NormWeightedCompositor(nn.Module):
2941
Accumulate points using a normalized weighted sum.
3042
"""
3143

32-
def __init__(self):
44+
def __init__(
45+
self, background_color: Optional[Union[Tuple, List, torch.Tensor]] = None
46+
):
3347
super().__init__()
48+
self.background_color = background_color
3449

3550
def forward(self, fragments, alphas, ptclds, **kwargs) -> torch.Tensor:
51+
background_color = kwargs.get("background_color", self.background_color)
3652
images = norm_weighted_sum(fragments, alphas, ptclds)
53+
54+
# images are of shape (N, C, H, W)
55+
# check for background color & feature size C (C=4 indicates rgba)
56+
if background_color is not None and images.shape[1] == 4:
57+
return _add_background_color_to_images(fragments, images, background_color)
3758
return images
59+
60+
61+
def _add_background_color_to_images(pix_idxs, images, background_color):
62+
"""
63+
Mask pixels in images without corresponding points with a given background_color.
64+
65+
Args:
66+
pix_idxs: int32 Tensor of shape (N, points_per_pixel, image_size, image_size)
67+
giving the indices of the nearest points at each pixel, sorted in z-order.
68+
images: Tensor of shape (N, 4, image_size, image_size) giving the
69+
accumulated features at each point, where 4 refers to a rgba feature.
70+
background_color: Tensor, list, or tuple with 3 or 4 values indicating the rgb/rgba
71+
value for the new background. Values should be in the interval [0,1].
72+
Returns:
73+
images: Tensor of shape (N, 4, image_size, image_size), where pixels with
74+
no nearest points have features set to the background color, and other
75+
pixels with accumulated features have unchanged values.
76+
"""
77+
# Initialize background mask
78+
background_mask = pix_idxs[:, 0] < 0 # (N, image_size, image_size)
79+
80+
# Convert background_color to an appropriate tensor and check shape
81+
if not torch.is_tensor(background_color):
82+
background_color = images.new_tensor(background_color)
83+
84+
background_shape = background_color.shape
85+
86+
if len(background_shape) != 1 or background_shape[0] not in (3, 4):
87+
warnings.warn(
88+
"Background color should be size (3) or (4), but is size %s instead"
89+
% (background_shape,)
90+
)
91+
return images
92+
93+
background_color = background_color.to(images)
94+
95+
# add alpha channel
96+
if background_shape[0] == 3:
97+
alpha = images.new_ones(1)
98+
background_color = torch.cat([background_color, alpha])
99+
100+
num_background_pixels = background_mask.sum()
101+
102+
# permute so that features are the last dimension for masked_scatter to work
103+
masked_images = images.permute(0, 2, 3, 1)[..., :4].masked_scatter(
104+
background_mask[..., None],
105+
background_color[None, :].expand(num_background_pixels, -1),
106+
)
107+
108+
return masked_images.permute(0, 3, 1, 2)

tests/test_render_points.py

+52
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
FoVPerspectiveCameras,
1919
look_at_view_transform,
2020
)
21+
from pytorch3d.renderer.compositing import alpha_composite, norm_weighted_sum
2122
from pytorch3d.renderer.points import (
2223
AlphaCompositor,
2324
NormWeightedCompositor,
@@ -171,3 +172,54 @@ def test_simple_sphere_batched(self):
171172
DATA_DIR / filename
172173
)
173174
self.assertClose(rgb, image_ref)
175+
176+
def test_compositor_background_color(self):
177+
178+
N, H, W, K, C, P = 1, 15, 15, 20, 4, 225
179+
ptclds = torch.randn((C, P))
180+
alphas = torch.rand((N, K, H, W))
181+
pix_idxs = torch.randint(-1, 20, (N, K, H, W)) # 20 < P, large amount of -1
182+
background_color = [0.5, 0, 1]
183+
184+
compositor_funcs = [
185+
(NormWeightedCompositor, norm_weighted_sum),
186+
(AlphaCompositor, alpha_composite),
187+
]
188+
189+
for (compositor_class, composite_func) in compositor_funcs:
190+
191+
compositor = compositor_class(background_color)
192+
193+
# run the forward method to generate masked images
194+
masked_images = compositor.forward(pix_idxs, alphas, ptclds)
195+
196+
# generate unmasked images for testing purposes
197+
images = composite_func(pix_idxs, alphas, ptclds)
198+
199+
is_foreground = pix_idxs[:, 0] >= 0
200+
201+
# make sure foreground values are unchanged
202+
self.assertClose(
203+
torch.masked_select(masked_images, is_foreground[:, None]),
204+
torch.masked_select(images, is_foreground[:, None]),
205+
)
206+
207+
is_background = ~is_foreground[..., None].expand(-1, -1, -1, 4)
208+
209+
# permute masked_images to correctly get rgb values
210+
masked_images = masked_images.permute(0, 2, 3, 1)
211+
for i in range(3):
212+
channel_color = background_color[i]
213+
214+
# check if background colors are properly changed
215+
self.assertTrue(
216+
masked_images[is_background]
217+
.view(-1, 4)[..., i]
218+
.eq(channel_color)
219+
.all()
220+
)
221+
222+
# check background color alpha values
223+
self.assertTrue(
224+
masked_images[is_background].view(-1, 4)[..., 3].eq(1).all()
225+
)

0 commit comments

Comments
 (0)