Skip to content

Commit

Permalink
align_corners and padding for TexturesUV
Browse files Browse the repository at this point in the history
Summary:
Allow, and make default, align_corners=True for texture maps. Allow changing the padding_mode and set the default to be "border" which produces more logical results. Some new documentation.

The previous behavior corresponds to padding_mode="zeros" and align_corners=False.

Reviewed By: gkioxari

Differential Revision: D23268775

fbshipit-source-id: 58d6229baa591baa69705bcf97471c80ba3651de
  • Loading branch information
bottler authored and facebook-github-bot committed Aug 25, 2020
1 parent d0cec02 commit e25ccab
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 32 deletions.
84 changes: 74 additions & 10 deletions pytorch3d/renderer/mesh/textures.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,14 @@ def _padded_to_list_wrapper(


def _pad_texture_maps(
images: Union[Tuple[torch.Tensor], List[torch.Tensor]]
images: Union[Tuple[torch.Tensor], List[torch.Tensor]], align_corners: bool
) -> torch.Tensor:
"""
Pad all texture images so they have the same height and width.
Args:
images: list of N tensors of shape (H, W, 3)
images: list of N tensors of shape (H_i, W_i, 3)
align_corners: used for interpolation
Returns:
tex_maps: Tensor of shape (N, max_H, max_W, 3)
Expand All @@ -125,7 +126,7 @@ def _pad_texture_maps(
if image.shape[:2] != max_shape:
image_BCHW = image.permute(2, 0, 1)[None]
new_image_BCHW = interpolate(
image_BCHW, size=max_shape, mode="bilinear", align_corners=False
image_BCHW, size=max_shape, mode="bilinear", align_corners=align_corners
)
tex_maps[i] = new_image_BCHW[0].permute(1, 2, 0)
tex_maps = torch.stack(tex_maps, dim=0) # (num_tex_maps, max_H, max_W, 3)
Expand Down Expand Up @@ -535,6 +536,8 @@ def __init__(
maps: Union[torch.Tensor, List[torch.Tensor]],
faces_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
verts_uvs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
padding_mode: str = "border",
align_corners: bool = True,
):
"""
Textures are represented as a per mesh texture map and uv coordinates for each
Expand All @@ -543,11 +546,42 @@ def __init__(
Args:
maps: texture map per mesh. This can either be a list of maps
[(H, W, 3)] or a padded tensor of shape (N, H, W, 3)
faces_uvs: (N, F, 3) tensor giving the index into verts_uvs for each face
faces_uvs: (N, F, 3) LongTensor giving the index into verts_uvs
for each face
verts_uvs: (N, V, 2) tensor giving the uv coordinates per vertex
(a FloatTensor with values between 0 and 1)
(a FloatTensor with values between 0 and 1).
align_corners: If true, the extreme values 0 and 1 for verts_uvs
indicate the centers of the edge pixels in the maps.
padding_mode: padding mode for outside grid values
("zeros", "border" or "reflection").
The align_corners and padding_mode arguments correspond to the arguments
of the `grid_sample` torch function. There is an informative illustration of
the two align_corners options at
https://discuss.pytorch.org/t/22663/9 .
An example of how the indexing into the maps, with align_corners=True
works is as follows.
If maps[i] has shape [101, 1001] and the value of verts_uvs[i][j]
is [0.4, 0.3], then a value of j in faces_uvs[i] means a vertex
whose color is given by maps[i][700, 40]. padding_mode affects what
happens if a value in verts_uvs is less than 0 or greater than 1.
Note that increasing a value in verts_uvs[..., 0] increases an index
in maps, whereas increasing a value in verts_uvs[..., 1] _decreases_
an _earlier_ index in maps.
If align_corners=False, an example would be as follows.
If maps[i] has shape [100, 1000] and the value of verts_uvs[i][j]
is [0.405, 0.2995], then a value of j in faces_uvs[i] means a vertex
whose color is given by maps[i][700, 40].
In this case, padding_mode even matters for values in verts_uvs
slightly above 0 or slightly below 1. In this case, it matters if the
first value is outside the interval [0.0005, 0.9995] or if the second
is outside the interval [0.005, 0.995].
"""
super().__init__()
self.padding_mode = padding_mode
self.align_corners = align_corners
if isinstance(faces_uvs, (list, tuple)):
for fv in faces_uvs:
# pyre-fixme[16]: `Tensor` has no attribute `ndim`.
Expand Down Expand Up @@ -632,7 +666,7 @@ def __init__(
raise ValueError("Expected one texture map per mesh in the batch.")
self._maps_list = maps
if self._N > 0:
maps = _pad_texture_maps(maps)
maps = _pad_texture_maps(maps, align_corners=self.align_corners)
else:
maps = torch.empty(
(self._N, 0, 0, 3), dtype=torch.float32, device=self.device
Expand Down Expand Up @@ -698,11 +732,19 @@ def __getitem__(self, index):
# if index has multiple values then faces/verts/maps may be a list of tensors
if all(isinstance(f, (list, tuple)) for f in [faces_uvs, verts_uvs, maps]):
new_tex = self.__class__(
faces_uvs=faces_uvs, verts_uvs=verts_uvs, maps=maps
faces_uvs=faces_uvs,
verts_uvs=verts_uvs,
maps=maps,
padding_mode=self.padding_mode,
align_corners=self.align_corners,
)
elif all(torch.is_tensor(f) for f in [faces_uvs, verts_uvs, maps]):
new_tex = self.__class__(
faces_uvs=[faces_uvs], verts_uvs=[verts_uvs], maps=[maps]
faces_uvs=[faces_uvs],
verts_uvs=[verts_uvs],
maps=[maps],
padding_mode=self.padding_mode,
align_corners=self.align_corners,
)
else:
raise ValueError("Not all values are provided in the correct format")
Expand Down Expand Up @@ -785,6 +827,8 @@ def extend(self, N: int) -> "TexturesUV":
maps=new_props["maps_padded"],
faces_uvs=new_props["faces_uvs_padded"],
verts_uvs=new_props["verts_uvs_padded"],
padding_mode=self.padding_mode,
align_corners=self.align_corners,
)

new_tex._num_faces_per_mesh = new_props["_num_faces_per_mesh"]
Expand Down Expand Up @@ -859,7 +903,12 @@ def sample_textures(self, fragments, **kwargs) -> torch.Tensor:
texture_maps = torch.flip(texture_maps, [2]) # flip y axis of the texture map
if texture_maps.device != pixel_uvs.device:
texture_maps = texture_maps.to(pixel_uvs.device)
texels = F.grid_sample(texture_maps, pixel_uvs, align_corners=False)
texels = F.grid_sample(
texture_maps,
pixel_uvs,
align_corners=self.align_corners,
padding_mode=self.padding_mode,
)
# texels now has shape (NK, C, H_out, W_out)
texels = texels.reshape(N, K, C, H_out, W_out).permute(0, 3, 4, 1, 2)
return texels
Expand All @@ -881,6 +930,17 @@ def join_batch(self, textures: List["TexturesUV"]) -> "TexturesUV":
if not tex_types_same:
raise ValueError("All textures must be of type TexturesUV.")

padding_modes_same = all(
tex.padding_mode == self.padding_mode for tex in textures
)
if not padding_modes_same:
raise ValueError("All textures must have the same padding_mode.")
align_corners_same = all(
tex.align_corners == self.align_corners for tex in textures
)
if not align_corners_same:
raise ValueError("All textures must have the same align_corners value.")

verts_uvs_list = []
faces_uvs_list = []
maps_list = []
Expand All @@ -896,7 +956,11 @@ def join_batch(self, textures: List["TexturesUV"]) -> "TexturesUV":
maps_list += tex_map_list

new_tex = self.__class__(
maps=maps_list, verts_uvs=verts_uvs_list, faces_uvs=faces_uvs_list
maps=maps_list,
verts_uvs=verts_uvs_list,
faces_uvs=faces_uvs_list,
padding_mode=self.padding_mode,
align_corners=self.align_corners,
)
new_tex._num_faces_per_mesh = num_faces_per_mesh
return new_tex
Expand Down
6 changes: 5 additions & 1 deletion pytorch3d/structures/meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1505,9 +1505,13 @@ def join_meshes_as_batch(meshes: List[Meshes], include_textures: bool = True):
Merge multiple Meshes objects, i.e. concatenate the meshes objects. They
must all be on the same device. If include_textures is true, they must all
be compatible, either all or none having textures, and all the Textures
objects having the same members. If include_textures is False, textures are
objects being the same type. If include_textures is False, textures are
ignored.
If the textures are TexturesAtlas then being the same type includes having
the same resolution. If they are TexturesUV then it includes having the same
align_corners and padding_mode.
Args:
meshes: list of meshes.
include_textures: (bool) whether to try to join the textures.
Expand Down
Binary file modified tests/data/test_blurry_textured_rendering.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/data/test_texture_map_back.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/data/test_texture_map_front.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 2 additions & 2 deletions tests/test_meshes.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,10 +373,10 @@ def test_detach(self):

self.assertFalse(new_mesh.verts_packed().requires_grad)
self.assertClose(new_mesh.verts_packed(), mesh.verts_packed())
self.assertTrue(new_mesh.verts_padded().requires_grad == False)
self.assertFalse(new_mesh.verts_padded().requires_grad)
self.assertClose(new_mesh.verts_padded(), mesh.verts_padded())
for v, newv in zip(mesh.verts_list(), new_mesh.verts_list()):
self.assertTrue(newv.requires_grad == False)
self.assertFalse(newv.requires_grad)
self.assertClose(newv, v)

def test_laplacian_packed(self):
Expand Down
10 changes: 4 additions & 6 deletions tests/test_pointclouds.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,11 +411,11 @@ def test_detach(self):
new_clouds = clouds.detach()

for cloud in new_clouds.points_list():
self.assertTrue(cloud.requires_grad == False)
self.assertFalse(cloud.requires_grad)
for normal in new_clouds.normals_list():
self.assertTrue(normal.requires_grad == False)
self.assertFalse(normal.requires_grad)
for feats in new_clouds.features_list():
self.assertTrue(feats.requires_grad == False)
self.assertFalse(feats.requires_grad)

for attrib in [
"points_packed",
Expand All @@ -425,9 +425,7 @@ def test_detach(self):
"normals_padded",
"features_padded",
]:
self.assertTrue(
getattr(new_clouds, attrib)().requires_grad == False
)
self.assertFalse(getattr(new_clouds, attrib)().requires_grad)

self.assertCloudsEqual(clouds, new_clouds)

Expand Down
33 changes: 20 additions & 13 deletions tests/test_texturing.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,19 +443,26 @@ def test_sample_textures_uv(self):
dists=pix_to_face,
)

tex = TexturesUV(maps=tex_map, faces_uvs=[face_uvs], verts_uvs=[vert_uvs])
meshes = Meshes(verts=[dummy_verts], faces=[face_uvs], textures=tex)
mesh_textures = meshes.textures
texels = mesh_textures.sample_textures(fragments)

# Expected output
pixel_uvs = interpolated_uvs * 2.0 - 1.0
pixel_uvs = pixel_uvs.view(2, 1, 1, 2)
tex_map = torch.flip(tex_map, [1])
tex_map = tex_map.permute(0, 3, 1, 2)
tex_map = torch.cat([tex_map, tex_map], dim=0)
expected_out = F.grid_sample(tex_map, pixel_uvs, align_corners=False)
self.assertTrue(torch.allclose(texels.squeeze(), expected_out.squeeze()))
for align_corners in [True, False]:
tex = TexturesUV(
maps=tex_map,
faces_uvs=[face_uvs],
verts_uvs=[vert_uvs],
align_corners=align_corners,
)
meshes = Meshes(verts=[dummy_verts], faces=[face_uvs], textures=tex)
mesh_textures = meshes.textures
texels = mesh_textures.sample_textures(fragments)

# Expected output
pixel_uvs = interpolated_uvs * 2.0 - 1.0
pixel_uvs = pixel_uvs.view(2, 1, 1, 2)
tex_map_ = torch.flip(tex_map, [1]).permute(0, 3, 1, 2)
tex_map_ = torch.cat([tex_map_, tex_map_], dim=0)
expected_out = F.grid_sample(
tex_map_, pixel_uvs, align_corners=align_corners, padding_mode="border"
)
self.assertTrue(torch.allclose(texels.squeeze(), expected_out.squeeze()))

def test_textures_uv_init_fail(self):
# Maps has wrong shape
Expand Down

0 comments on commit e25ccab

Please sign in to comment.