Skip to content

Commit

Permalink
Save UV texture with obj mesh
Browse files Browse the repository at this point in the history
Summary: Add functionality to to save an `.obj` file with associated UV textures: `.png` image and `.mtl` file as well as saving verts_uvs and faces_uvs to the `.obj` file.

Reviewed By: bottler

Differential Revision: D29337562

fbshipit-source-id: 86829b40dae9224088b328e7f5a16eacf8582eb5
  • Loading branch information
nikhilaravi authored and facebook-github-bot committed Jun 24, 2021
1 parent 64289a4 commit 542e2e7
Show file tree
Hide file tree
Showing 3 changed files with 549 additions and 233 deletions.
128 changes: 116 additions & 12 deletions pytorch3d/io/obj_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import numpy as np
import torch
from iopath.common.file_io import PathManager
from PIL import Image
from pytorch3d.common.types import Device
from pytorch3d.io.mtl_io import load_mtl, make_mesh_texture_atlas
from pytorch3d.io.utils import _check_faces_indices, _make_tensor, _open_file
Expand Down Expand Up @@ -649,42 +650,118 @@ def _load_obj(


def save_obj(
f,
f: Union[str, os.PathLike],
verts,
faces,
decimal_places: Optional[int] = None,
path_manager: Optional[PathManager] = None,
):
*,
verts_uvs: Optional[torch.Tensor] = None,
faces_uvs: Optional[torch.Tensor] = None,
texture_map: Optional[torch.Tensor] = None,
) -> None:
"""
Save a mesh to an .obj file.
Args:
f: File (or path) to which the mesh should be written.
f: File (str or path) to which the mesh should be written.
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
faces: LongTensor of shape (F, 3) giving faces.
decimal_places: Number of decimal places for saving.
path_manager: Optional PathManager for interpreting f if
it is a str.
verts_uvs: FloatTensor of shape (V, 2) giving the uv coordinate per vertex.
faces_uvs: LongTensor of shape (F, 3) giving the index into verts_uvs for
each vertex in the face.
texture_map: FloatTensor of shape (H, W, 3) representing the texture map
for the mesh which will be saved as an image. The values are expected
to be in the range [0, 1],
"""
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
message = "Argument 'verts' should either be empty or of shape (num_verts, 3)."
if len(verts) and (verts.dim() != 2 or verts.size(1) != 3):
message = "'verts' should either be empty or of shape (num_verts, 3)."
raise ValueError(message)

if len(faces) and (faces.dim() != 2 or faces.size(1) != 3):
message = "'faces' should either be empty or of shape (num_faces, 3)."
raise ValueError(message)

if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3):
message = "'faces_uvs' should either be empty or of shape (num_faces, 3)."
raise ValueError(message)

if len(faces) and not (faces.dim() == 2 and faces.size(1) == 3):
message = "Argument 'faces' should either be empty or of shape (num_faces, 3)."
if verts_uvs is not None and (verts_uvs.dim() != 2 or verts_uvs.size(1) != 2):
message = "'verts_uvs' should either be empty or of shape (num_verts, 2)."
raise ValueError(message)

if texture_map is not None and (texture_map.dim() != 3 or texture_map.size(2) != 3):
message = "'texture_map' should either be empty or of shape (H, W, 3)."
raise ValueError(message)

if path_manager is None:
path_manager = PathManager()

save_texture = all([t is not None for t in [faces_uvs, verts_uvs, texture_map]])
output_path = Path(f)

# Save the .obj file
with _open_file(f, path_manager, "w") as f:
return _save(f, verts, faces, decimal_places)
if save_texture:
# Add the header required for the texture info to be loaded correctly
obj_header = "\nmtllib {0}.mtl\nusemtl mesh\n\n".format(output_path.stem)
f.write(obj_header)
_save(
f,
verts,
faces,
decimal_places,
verts_uvs=verts_uvs,
faces_uvs=faces_uvs,
save_texture=save_texture,
)

# Save the .mtl and .png files associated with the texture
if save_texture:
image_path = output_path.with_suffix(".png")
mtl_path = output_path.with_suffix(".mtl")
if isinstance(f, str):
# Back to str for iopath interpretation.
image_path = str(image_path)
mtl_path = str(mtl_path)

# Save texture map to output folder
# pyre-fixme[16] # undefined attribute cpu
texture_map = texture_map.detach().cpu() * 255.0
image = Image.fromarray(texture_map.numpy().astype(np.uint8))
with _open_file(image_path, path_manager, "wb") as im_f:
# pyre-fixme[6] # incompatible parameter type
image.save(im_f)

# Create .mtl file with the material name and texture map filename
# TODO: enable material properties to also be saved.
with _open_file(mtl_path, path_manager, "w") as f_mtl:
lines = f"newmtl mesh\n" f"map_Kd {output_path.stem}.png\n"
f_mtl.write(lines)


# TODO (nikhilar) Speed up this function.
def _save(f, verts, faces, decimal_places: Optional[int] = None) -> None:
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
assert not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
def _save(
f,
verts,
faces,
decimal_places: Optional[int] = None,
*,
verts_uvs: Optional[torch.Tensor] = None,
faces_uvs: Optional[torch.Tensor] = None,
save_texture: bool = False,
) -> None:

if len(verts) and (verts.dim() != 2 or verts.size(1) != 3):
message = "'verts' should either be empty or of shape (num_verts, 3)."
raise ValueError(message)

if len(faces) and (faces.dim() != 2 or faces.size(1) != 3):
message = "'faces' should either be empty or of shape (num_faces, 3)."
raise ValueError(message)

if not (len(verts) or len(faces)):
warnings.warn("Empty 'verts' and 'faces' arguments provided")
Expand All @@ -705,15 +782,42 @@ def _save(f, verts, faces, decimal_places: Optional[int] = None) -> None:
vert = [float_str % verts[i, j] for j in range(D)]
lines += "v %s\n" % " ".join(vert)

if save_texture:
if faces_uvs is not None and (faces_uvs.dim() != 2 or faces_uvs.size(1) != 3):
message = "'faces_uvs' should either be empty or of shape (num_faces, 3)."
raise ValueError(message)

if verts_uvs is not None and (verts_uvs.dim() != 2 or verts_uvs.size(1) != 2):
message = "'verts_uvs' should either be empty or of shape (num_verts, 2)."
raise ValueError(message)

# pyre-fixme[16] # undefined attribute cpu
verts_uvs, faces_uvs = verts_uvs.cpu(), faces_uvs.cpu()

# Save verts uvs after verts
if len(verts_uvs):
uV, uD = verts_uvs.shape
for i in range(uV):
uv = [float_str % verts_uvs[i, j] for j in range(uD)]
lines += "vt %s\n" % " ".join(uv)

if torch.any(faces >= verts.shape[0]) or torch.any(faces < 0):
warnings.warn("Faces have invalid indices")

if len(faces):
F, P = faces.shape
for i in range(F):
face = ["%d" % (faces[i, j] + 1) for j in range(P)]
if save_texture:
# Format faces as {verts_idx}/{verts_uvs_idx}
face = [
"%d/%d" % (faces[i, j] + 1, faces_uvs[i, j] + 1) for j in range(P)
]
else:
face = ["%d" % (faces[i, j] + 1) for j in range(P)]

if i + 1 < F:
lines += "f %s\n" % " ".join(face)

elif i + 1 == F:
# No newline at the end of the file.
lines += "f %s" % " ".join(face)
Expand Down
2 changes: 1 addition & 1 deletion tests/common_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def get_pytorch3d_dir() -> Path:


def load_rgb_image(filename: str, data_dir: Union[str, Path]):
filepath = data_dir / filename
filepath = os.path.join(data_dir, filename)
with Image.open(filepath) as raw_image:
image = torch.from_numpy(np.array(raw_image) / 255.0)
image = image.to(dtype=torch.float32)
Expand Down
Loading

0 comments on commit 542e2e7

Please sign in to comment.