@@ -39,13 +39,16 @@ class HardPhongShader(nn.Module):
39
39
shader = HardPhongShader(device=torch.device("cuda:0"))
40
40
"""
41
41
42
- def __init__ (self , device = "cpu" , cameras = None , lights = None , materials = None ):
42
+ def __init__ (
43
+ self , device = "cpu" , cameras = None , lights = None , materials = None , blend_params = None
44
+ ):
43
45
super ().__init__ ()
44
46
self .lights = lights if lights is not None else PointLights (device = device )
45
47
self .materials = (
46
48
materials if materials is not None else Materials (device = device )
47
49
)
48
50
self .cameras = cameras
51
+ self .blend_params = blend_params if blend_params is not None else BlendParams ()
49
52
50
53
def forward (self , fragments , meshes , ** kwargs ) -> torch .Tensor :
51
54
cameras = kwargs .get ("cameras" , self .cameras )
@@ -57,6 +60,7 @@ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
57
60
texels = interpolate_vertex_colors (fragments , meshes )
58
61
lights = kwargs .get ("lights" , self .lights )
59
62
materials = kwargs .get ("materials" , self .materials )
63
+ blend_params = kwargs .get ("blend_params" , self .blend_params )
60
64
colors = phong_shading (
61
65
meshes = meshes ,
62
66
fragments = fragments ,
@@ -65,7 +69,7 @@ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
65
69
cameras = cameras ,
66
70
materials = materials ,
67
71
)
68
- images = hard_rgb_blend (colors , fragments )
72
+ images = hard_rgb_blend (colors , fragments , blend_params )
69
73
return images
70
74
71
75
@@ -130,13 +134,16 @@ class HardGouraudShader(nn.Module):
130
134
shader = HardGouraudShader(device=torch.device("cuda:0"))
131
135
"""
132
136
133
- def __init__ (self , device = "cpu" , cameras = None , lights = None , materials = None ):
137
+ def __init__ (
138
+ self , device = "cpu" , cameras = None , lights = None , materials = None , blend_params = None
139
+ ):
134
140
super ().__init__ ()
135
141
self .lights = lights if lights is not None else PointLights (device = device )
136
142
self .materials = (
137
143
materials if materials is not None else Materials (device = device )
138
144
)
139
145
self .cameras = cameras
146
+ self .blend_params = blend_params if blend_params is not None else BlendParams ()
140
147
141
148
def forward (self , fragments , meshes , ** kwargs ) -> torch .Tensor :
142
149
cameras = kwargs .get ("cameras" , self .cameras )
@@ -146,14 +153,15 @@ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
146
153
raise ValueError (msg )
147
154
lights = kwargs .get ("lights" , self .lights )
148
155
materials = kwargs .get ("materials" , self .materials )
156
+ blend_params = kwargs .get ("blend_params" , self .blend_params )
149
157
pixel_colors = gouraud_shading (
150
158
meshes = meshes ,
151
159
fragments = fragments ,
152
160
lights = lights ,
153
161
cameras = cameras ,
154
162
materials = materials ,
155
163
)
156
- images = hard_rgb_blend (pixel_colors , fragments )
164
+ images = hard_rgb_blend (pixel_colors , fragments , blend_params )
157
165
return images
158
166
159
167
@@ -266,13 +274,16 @@ class HardFlatShader(nn.Module):
266
274
shader = HardFlatShader(device=torch.device("cuda:0"))
267
275
"""
268
276
269
- def __init__ (self , device = "cpu" , cameras = None , lights = None , materials = None ):
277
+ def __init__ (
278
+ self , device = "cpu" , cameras = None , lights = None , materials = None , blend_params = None
279
+ ):
270
280
super ().__init__ ()
271
281
self .lights = lights if lights is not None else PointLights (device = device )
272
282
self .materials = (
273
283
materials if materials is not None else Materials (device = device )
274
284
)
275
285
self .cameras = cameras
286
+ self .blend_params = blend_params if blend_params is not None else BlendParams ()
276
287
277
288
def forward (self , fragments , meshes , ** kwargs ) -> torch .Tensor :
278
289
cameras = kwargs .get ("cameras" , self .cameras )
@@ -283,6 +294,7 @@ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
283
294
texels = interpolate_vertex_colors (fragments , meshes )
284
295
lights = kwargs .get ("lights" , self .lights )
285
296
materials = kwargs .get ("materials" , self .materials )
297
+ blend_params = kwargs .get ("blend_params" , self .blend_params )
286
298
colors = flat_shading (
287
299
meshes = meshes ,
288
300
fragments = fragments ,
@@ -291,7 +303,7 @@ def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
291
303
cameras = cameras ,
292
304
materials = materials ,
293
305
)
294
- images = hard_rgb_blend (colors , fragments )
306
+ images = hard_rgb_blend (colors , fragments , blend_params )
295
307
return images
296
308
297
309
0 commit comments