1
1
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
2
3
+ import warnings
4
+ from typing import List , Optional , Tuple , Union
5
+
3
6
import torch
4
7
import torch .nn as nn
5
8
@@ -16,11 +19,20 @@ class AlphaCompositor(nn.Module):
16
19
Accumulate points using alpha compositing.
17
20
"""
18
21
19
- def __init__ (self ):
22
+ def __init__ (
23
+ self , background_color : Optional [Union [Tuple , List , torch .Tensor ]] = None
24
+ ):
20
25
super ().__init__ ()
26
+ self .background_color = background_color
21
27
22
28
def forward (self , fragments , alphas , ptclds , ** kwargs ) -> torch .Tensor :
29
+ background_color = kwargs .get ("background_color" , self .background_color )
23
30
images = alpha_composite (fragments , alphas , ptclds )
31
+
32
+ # images are of shape (N, C, H, W)
33
+ # check for background color & feature size C (C=4 indicates rgba)
34
+ if background_color is not None and images .shape [1 ] == 4 :
35
+ return _add_background_color_to_images (fragments , images , background_color )
24
36
return images
25
37
26
38
@@ -29,9 +41,68 @@ class NormWeightedCompositor(nn.Module):
29
41
Accumulate points using a normalized weighted sum.
30
42
"""
31
43
32
- def __init__ (self ):
44
+ def __init__ (
45
+ self , background_color : Optional [Union [Tuple , List , torch .Tensor ]] = None
46
+ ):
33
47
super ().__init__ ()
48
+ self .background_color = background_color
34
49
35
50
def forward (self , fragments , alphas , ptclds , ** kwargs ) -> torch .Tensor :
51
+ background_color = kwargs .get ("background_color" , self .background_color )
36
52
images = norm_weighted_sum (fragments , alphas , ptclds )
53
+
54
+ # images are of shape (N, C, H, W)
55
+ # check for background color & feature size C (C=4 indicates rgba)
56
+ if background_color is not None and images .shape [1 ] == 4 :
57
+ return _add_background_color_to_images (fragments , images , background_color )
37
58
return images
59
+
60
+
61
+ def _add_background_color_to_images (pix_idxs , images , background_color ):
62
+ """
63
+ Mask pixels in images without corresponding points with a given background_color.
64
+
65
+ Args:
66
+ pix_idxs: int32 Tensor of shape (N, points_per_pixel, image_size, image_size)
67
+ giving the indices of the nearest points at each pixel, sorted in z-order.
68
+ images: Tensor of shape (N, 4, image_size, image_size) giving the
69
+ accumulated features at each point, where 4 refers to a rgba feature.
70
+ background_color: Tensor, list, or tuple with 3 or 4 values indicating the rgb/rgba
71
+ value for the new background. Values should be in the interval [0,1].
72
+ Returns:
73
+ images: Tensor of shape (N, 4, image_size, image_size), where pixels with
74
+ no nearest points have features set to the background color, and other
75
+ pixels with accumulated features have unchanged values.
76
+ """
77
+ # Initialize background mask
78
+ background_mask = pix_idxs [:, 0 ] < 0 # (N, image_size, image_size)
79
+
80
+ # Convert background_color to an appropriate tensor and check shape
81
+ if not torch .is_tensor (background_color ):
82
+ background_color = images .new_tensor (background_color )
83
+
84
+ background_shape = background_color .shape
85
+
86
+ if len (background_shape ) != 1 or background_shape [0 ] not in (3 , 4 ):
87
+ warnings .warn (
88
+ "Background color should be size (3) or (4), but is size %s instead"
89
+ % (background_shape ,)
90
+ )
91
+ return images
92
+
93
+ background_color = background_color .to (images )
94
+
95
+ # add alpha channel
96
+ if background_shape [0 ] == 3 :
97
+ alpha = images .new_ones (1 )
98
+ background_color = torch .cat ([background_color , alpha ])
99
+
100
+ num_background_pixels = background_mask .sum ()
101
+
102
+ # permute so that features are the last dimension for masked_scatter to work
103
+ masked_images = images .permute (0 , 2 , 3 , 1 )[..., :4 ].masked_scatter (
104
+ background_mask [..., None ],
105
+ background_color [None , :].expand (num_background_pixels , - 1 ),
106
+ )
107
+
108
+ return masked_images .permute (0 , 3 , 1 , 2 )
0 commit comments