Skip to content

Commit 9a50cf8

Browse files
bottlerfacebook-github-bot
authored andcommitted
Fix batching bug from TexturesUV packed ambiguity, other textures tidyup
Summary: faces_uvs_packed and verts_uvs_packed were only used in one place and the definition of the former was ambiguous. This meant that the wrong coordinates could be used for meshes other than the first in the batch. I have therefore removed both functions and build their common result inline. Added a test that a simple batch of two meshes is rendered consistently with the rendering of each alone. This test would have failed before. I hope this fixes facebookresearch#283. Some other small improvements to the textures code. Reviewed By: nikhilaravi Differential Revision: D23161936 fbshipit-source-id: f99b560a46f6b30262e07028b049812bc04350a7
1 parent 9aaba04 commit 9a50cf8

File tree

7 files changed

+106
-65
lines changed

7 files changed

+106
-65
lines changed

pytorch3d/csrc/utils/pytorch3d_cutils.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
#pragma once
44
#include <torch/extension.h>
55

6-
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x "must be a CUDA tensor.")
6+
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor.")
77
#define CHECK_CONTIGUOUS(x) \
8-
TORCH_CHECK(x.is_contiguous(), #x "must be contiguous.")
8+
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous.")
99
#define CHECK_CONTIGUOUS_CUDA(x) \
1010
CHECK_CUDA(x); \
1111
CHECK_CONTIGUOUS(x)

pytorch3d/datasets/shapenet_base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
import torch
88
from pytorch3d.io import load_objs_as_meshes
99
from pytorch3d.renderer import (
10+
FoVPerspectiveCameras,
1011
HardPhongShader,
1112
MeshRasterizer,
1213
MeshRenderer,
13-
FoVPerspectiveCameras,
1414
PointLights,
1515
RasterizationSettings,
1616
TexturesVertex,

pytorch3d/renderer/blending.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,10 @@ def hard_rgb_blend(colors, fragments, blend_params) -> torch.Tensor:
4545
# Mask for the background.
4646
is_background = fragments.pix_to_face[..., 0] < 0 # (N, H, W)
4747

48-
background_color = colors.new_tensor(blend_params.background_color) # (3)
48+
if torch.is_tensor(blend_params.background_color):
49+
background_color = blend_params.background_color
50+
else:
51+
background_color = colors.new_tensor(blend_params.background_color) # (3)
4952

5053
# Find out how much background_color needs to be expanded to be used for masked_scatter.
5154
num_background_pixels = is_background.sum()

pytorch3d/renderer/mesh/textures.py

+13-41
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def _pad_texture_maps(
137137
# This is also useful to have so that inside `Meshes`
138138
# we can allow the input textures to be any texture
139139
# type which is an instance of the base class.
140-
class TexturesBase(object):
140+
class TexturesBase:
141141
def __init__(self):
142142
self._N = 0
143143
self.valid = None
@@ -262,9 +262,6 @@ class attributes for item i. Then, a new
262262
"""
263263
raise NotImplementedError()
264264

265-
def __repr__(self):
266-
return "TexturesBase"
267-
268265

269266
def Textures(
270267
maps: Union[List, torch.Tensor, None] = None,
@@ -385,14 +382,6 @@ def __init__(self, atlas: Union[torch.Tensor, List, None]):
385382
# refer to the __init__ of Meshes.
386383
self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device)
387384

388-
# This is a hack to allow the child classes to also have the same representation
389-
# as the parent. In meshes.py we check that the input textures have the correct
390-
# type. However due to circular imports issues, we can't import the texture
391-
# classes into any files in pytorch3d.structures. Instead we check
392-
# for repr(textures) == "TexturesBase".
393-
def __repr__(self):
394-
return super().__repr__()
395-
396385
def clone(self):
397386
tex = self.__class__(atlas=self.atlas_padded().clone())
398387
if self._atlas_list is not None:
@@ -556,10 +545,7 @@ def __init__(
556545
[(H, W, 3)] or a padded tensor of shape (N, H, W, 3)
557546
faces_uvs: (N, F, 3) tensor giving the index into verts_uvs for each face
558547
verts_uvs: (N, V, 2) tensor giving the uv coordinates per vertex
559-
560-
Note: only the padded and list representation of the textures are stored
561-
and the packed representations is computed on the fly and
562-
not cached.
548+
(a FloatTensor with values between 0 and 1)
563549
"""
564550
super().__init__()
565551
if isinstance(faces_uvs, (list, tuple)):
@@ -611,9 +597,6 @@ def __init__(
611597
"verts_uvs and faces_uvs must have the same batch dimension"
612598
)
613599
if not all(v.device == self.device for v in verts_uvs):
614-
import pdb
615-
616-
pdb.set_trace()
617600
raise ValueError("verts_uvs and faces_uvs must be on the same device")
618601

619602
# These values may be overridden when textures is
@@ -669,9 +652,6 @@ def __init__(
669652

670653
self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device)
671654

672-
def __repr__(self):
673-
return super().__repr__()
674-
675655
def clone(self):
676656
tex = self.__class__(
677657
self.maps_padded().clone(),
@@ -759,12 +739,6 @@ def faces_uvs_list(self) -> List[torch.Tensor]:
759739
)
760740
return self._faces_uvs_list
761741

762-
def faces_uvs_packed(self) -> torch.Tensor:
763-
if self.isempty():
764-
return torch.zeros((self._N, 3), dtype=torch.float32, device=self.device)
765-
faces_uvs_list = self.faces_uvs_list()
766-
return list_to_packed(faces_uvs_list)[0]
767-
768742
def verts_uvs_padded(self) -> torch.Tensor:
769743
if self._verts_uvs_padded is None:
770744
if self.isempty():
@@ -789,12 +763,6 @@ def verts_uvs_list(self) -> List[torch.Tensor]:
789763
)
790764
return self._verts_uvs_list
791765

792-
def verts_uvs_packed(self) -> torch.Tensor:
793-
if self.isempty():
794-
return torch.zeros((self._N, 2), dtype=torch.float32, device=self.device)
795-
verts_uvs_list = self.verts_uvs_list()
796-
return list_to_packed(verts_uvs_list)[0]
797-
798766
# Currently only the padded maps are used.
799767
def maps_padded(self) -> torch.Tensor:
800768
return self._maps_padded
@@ -850,9 +818,15 @@ def sample_textures(self, fragments, **kwargs) -> torch.Tensor:
850818
texels: tensor of shape (N, H, W, K, C) giving the interpolated
851819
texture for each pixel in the rasterized image.
852820
"""
853-
verts_uvs = self.verts_uvs_packed()
854-
faces_uvs = self.faces_uvs_packed()
855-
faces_verts_uvs = verts_uvs[faces_uvs]
821+
if self.isempty():
822+
faces_verts_uvs = torch.zeros(
823+
(self._N, 3, 2), dtype=torch.float32, device=self.device
824+
)
825+
else:
826+
packing_list = [
827+
i[j] for i, j in zip(self.verts_uvs_list(), self.faces_uvs_list())
828+
]
829+
faces_verts_uvs = torch.cat(packing_list)
856830
texture_maps = self.maps_padded()
857831

858832
# pixel_uvs: (N, H, W, K, 2)
@@ -890,6 +864,7 @@ def sample_textures(self, fragments, **kwargs) -> torch.Tensor:
890864
if texture_maps.device != pixel_uvs.device:
891865
texture_maps = texture_maps.to(pixel_uvs.device)
892866
texels = F.grid_sample(texture_maps, pixel_uvs, align_corners=False)
867+
# texels now has shape (NK, C, H_out, W_out)
893868
texels = texels.reshape(N, K, C, H_out, W_out).permute(0, 3, 4, 1, 2)
894869
return texels
895870

@@ -990,9 +965,6 @@ def __init__(
990965
# refer to the __init__ of Meshes.
991966
self.valid = torch.ones((self._N,), dtype=torch.bool, device=self.device)
992967

993-
def __repr__(self):
994-
return super().__repr__()
995-
996968
def clone(self):
997969
tex = self.__class__(self.verts_features_padded().clone())
998970
if self._verts_features_list is not None:
@@ -1048,7 +1020,7 @@ def verts_features_list(self) -> List[torch.Tensor]:
10481020
if self._verts_features_list is None:
10491021
if self.isempty():
10501022
self._verts_features_list = [
1051-
torch.empty((0, 3, 0), dtype=torch.float32, device=self.device)
1023+
torch.empty((0, 3), dtype=torch.float32, device=self.device)
10521024
] * self._N
10531025
else:
10541026
self._verts_features_list = padded_to_list(

pytorch3d/structures/meshes.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -233,9 +233,9 @@ def __init__(self, verts=None, faces=None, textures=None):
233233
Refer to comments above for descriptions of List and Padded representations.
234234
"""
235235
self.device = None
236-
if textures is not None and not repr(textures) == "TexturesBase":
236+
if textures is not None and not hasattr(textures, "sample_textures"):
237237
msg = "Expected textures to be an instance of type TexturesBase; got %r"
238-
raise ValueError(msg % repr(textures))
238+
raise ValueError(msg % type(textures))
239239
self.textures = textures
240240

241241
# Indicates whether the meshes in the list/batch have the same number

tests/test_render_meshes.py

+82-1
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,9 @@
3333
SoftSilhouetteShader,
3434
TexturedSoftPhongShader,
3535
)
36-
from pytorch3d.structures.meshes import Meshes, join_mesh
36+
from pytorch3d.structures.meshes import Meshes, join_mesh, join_meshes_as_batch
3737
from pytorch3d.utils.ico_sphere import ico_sphere
38+
from pytorch3d.utils.torus import torus
3839

3940

4041
# If DEBUG=True, save out images generated in the tests for debugging.
@@ -490,6 +491,86 @@ def test_texture_map(self):
490491

491492
self.assertClose(rgb, image_ref, atol=0.05)
492493

494+
def test_batch_uvs(self):
495+
"""Test that two random tori with TexturesUV render the same as each individually."""
496+
torch.manual_seed(1)
497+
device = torch.device("cuda:0")
498+
plain_torus = torus(r=1, R=4, sides=10, rings=10, device=device)
499+
[verts] = plain_torus.verts_list()
500+
[faces] = plain_torus.faces_list()
501+
nocolor = torch.zeros((100, 100), device=device)
502+
color_gradient = torch.linspace(0, 1, steps=100, device=device)
503+
color_gradient1 = color_gradient[None].expand_as(nocolor)
504+
color_gradient2 = color_gradient[:, None].expand_as(nocolor)
505+
colors1 = torch.stack([nocolor, color_gradient1, color_gradient2], dim=2)
506+
colors2 = torch.stack([color_gradient1, color_gradient2, nocolor], dim=2)
507+
verts_uvs1 = torch.rand(size=(verts.shape[0], 2), device=device)
508+
verts_uvs2 = torch.rand(size=(verts.shape[0], 2), device=device)
509+
510+
textures1 = TexturesUV(
511+
maps=[colors1], faces_uvs=[faces], verts_uvs=[verts_uvs1]
512+
)
513+
textures2 = TexturesUV(
514+
maps=[colors2], faces_uvs=[faces], verts_uvs=[verts_uvs2]
515+
)
516+
mesh1 = Meshes(verts=[verts], faces=[faces], textures=textures1)
517+
mesh2 = Meshes(verts=[verts], faces=[faces], textures=textures2)
518+
mesh_both = join_meshes_as_batch([mesh1, mesh2])
519+
520+
R, T = look_at_view_transform(10, 10, 0)
521+
cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
522+
523+
raster_settings = RasterizationSettings(
524+
image_size=128, blur_radius=0.0, faces_per_pixel=1
525+
)
526+
527+
# Init shader settings
528+
lights = PointLights(device=device)
529+
lights.location = torch.tensor([0.0, 0.0, 2.0], device=device)[None]
530+
531+
blend_params = BlendParams(
532+
sigma=1e-1,
533+
gamma=1e-4,
534+
background_color=torch.tensor([1.0, 1.0, 1.0], device=device),
535+
)
536+
# Init renderer
537+
renderer = MeshRenderer(
538+
rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
539+
shader=HardPhongShader(
540+
device=device, lights=lights, cameras=cameras, blend_params=blend_params
541+
),
542+
)
543+
544+
outputs = []
545+
for meshes in [mesh_both, mesh1, mesh2]:
546+
outputs.append(renderer(meshes))
547+
548+
if DEBUG:
549+
Image.fromarray(
550+
(outputs[0][0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
551+
).save(DATA_DIR / "test_batch_uvs0.png")
552+
Image.fromarray(
553+
(outputs[1][0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
554+
).save(DATA_DIR / "test_batch_uvs1.png")
555+
Image.fromarray(
556+
(outputs[0][1, ..., :3].cpu().numpy() * 255).astype(np.uint8)
557+
).save(DATA_DIR / "test_batch_uvs2.png")
558+
Image.fromarray(
559+
(outputs[2][0, ..., :3].cpu().numpy() * 255).astype(np.uint8)
560+
).save(DATA_DIR / "test_batch_uvs3.png")
561+
562+
diff = torch.abs(outputs[0][0, ..., :3] - outputs[1][0, ..., :3])
563+
Image.fromarray(((diff > 1e-5).cpu().numpy().astype(np.uint8) * 255)).save(
564+
DATA_DIR / "test_batch_uvs01.png"
565+
)
566+
diff = torch.abs(outputs[0][1, ..., :3] - outputs[2][0, ..., :3])
567+
Image.fromarray(((diff > 1e-5).cpu().numpy().astype(np.uint8) * 255)).save(
568+
DATA_DIR / "test_batch_uvs23.png"
569+
)
570+
571+
self.assertClose(outputs[0][0, ..., :3], outputs[1][0, ..., :3], atol=1e-5)
572+
self.assertClose(outputs[0][1, ..., :3], outputs[2][0, ..., :3], atol=1e-5)
573+
493574
def test_joined_spheres(self):
494575
"""
495576
Test a list of Meshes can be joined as a single mesh and

tests/test_texturing.py

+2-17
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ def tryindex(self, index, tex, meshes, source):
2929
basic = basic[None]
3030

3131
if len(basic) == 0:
32-
self.assertEquals(len(from_texture), 0)
33-
self.assertEquals(len(from_meshes), 0)
32+
self.assertEqual(len(from_texture), 0)
33+
self.assertEqual(len(from_meshes), 0)
3434
else:
3535
self.assertClose(basic, from_texture)
3636
self.assertClose(basic, from_meshes)
@@ -608,12 +608,8 @@ def test_extend(self):
608608
[
609609
tex_init.faces_uvs_padded(),
610610
new_tex.faces_uvs_padded(),
611-
tex_init.faces_uvs_packed(),
612-
new_tex.faces_uvs_packed(),
613611
tex_init.verts_uvs_padded(),
614612
new_tex.verts_uvs_padded(),
615-
tex_init.verts_uvs_packed(),
616-
new_tex.verts_uvs_packed(),
617613
tex_init.maps_padded(),
618614
new_tex.maps_padded(),
619615
]
@@ -646,11 +642,9 @@ def test_padded_to_packed(self):
646642
tex1 = tex.clone()
647643
tex1._num_faces_per_mesh = num_faces_per_mesh
648644
tex1._num_verts_per_mesh = num_verts_per_mesh
649-
verts_packed = tex1.verts_uvs_packed()
650645
verts_list = tex1.verts_uvs_list()
651646
verts_padded = tex1.verts_uvs_padded()
652647

653-
faces_packed = tex1.faces_uvs_packed()
654648
faces_list = tex1.faces_uvs_list()
655649
faces_padded = tex1.faces_uvs_padded()
656650

@@ -660,9 +654,7 @@ def test_padded_to_packed(self):
660654
for f1, f2 in zip(verts_list, verts_uvs_list):
661655
self.assertTrue((f1 == f2).all().item())
662656

663-
self.assertTrue(faces_packed.shape == (3 + 2, 3))
664657
self.assertTrue(faces_padded.shape == (2, 3, 3))
665-
self.assertTrue(verts_packed.shape == (9 + 6, 2))
666658
self.assertTrue(verts_padded.shape == (2, 9, 2))
667659

668660
# Case where num_faces_per_mesh is not set and faces_verts_uvs
@@ -672,16 +664,9 @@ def test_padded_to_packed(self):
672664
verts_uvs=verts_padded,
673665
faces_uvs=faces_padded,
674666
)
675-
faces_packed = tex2.faces_uvs_packed()
676667
faces_list = tex2.faces_uvs_list()
677-
verts_packed = tex2.verts_uvs_packed()
678668
verts_list = tex2.verts_uvs_list()
679669

680-
# Packed is just flattened padded as num_faces_per_mesh
681-
# has not been provided.
682-
self.assertTrue(faces_packed.shape == (3 * 2, 3))
683-
self.assertTrue(verts_packed.shape == (9 * 2, 2))
684-
685670
for i, (f1, f2) in enumerate(zip(faces_list, faces_uvs_list)):
686671
n = num_faces_per_mesh[i]
687672
self.assertTrue((f1[:n] == f2).all().item())

0 commit comments

Comments
 (0)