Skip to content

Commit 659ad34

Browse files
gkioxarifacebook-github-bot
authored andcommitted
load texture flag
Summary: Add flag for loading textures Reviewed By: nikhilaravi Differential Revision: D19664437 fbshipit-source-id: 3cc4e6179df9b7e24efff9e7da3b164253f1d775
1 parent 244b7eb commit 659ad34

File tree

2 files changed

+67
-12
lines changed

2 files changed

+67
-12
lines changed

pytorch3d/io/obj_io.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def _open_file(f):
8888
return f, new_f
8989

9090

91-
def load_obj(f_obj):
91+
def load_obj(f_obj, load_textures=True):
9292
"""
9393
Load a mesh and textures from a .obj and .mtl file.
9494
Currently this handles verts, faces, vertex texture uv coordinates, normals,
@@ -146,6 +146,7 @@ def load_obj(f_obj):
146146
Args:
147147
f: A file-like object (with methods read, readline, tell, and seek),
148148
a pathlib path or a string containing a file name.
149+
load_textures: Boolean indicating whether material files are loaded
149150
150151
Returns:
151152
6-element tuple containing
@@ -201,7 +202,7 @@ def load_obj(f_obj):
201202
data_dir = os.path.dirname(f_obj)
202203
f_obj, new_f = _open_file(f_obj)
203204
try:
204-
return _load(f_obj, data_dir)
205+
return _load(f_obj, data_dir, load_textures=load_textures)
205206
finally:
206207
if new_f:
207208
f_obj.close()
@@ -273,7 +274,7 @@ def _parse_face(
273274
faces_materials_idx.append(material_idx)
274275

275276

276-
def _load(f_obj, data_dir):
277+
def _load(f_obj, data_dir, load_textures=True):
277278
"""
278279
Load a mesh from a file-like object. See load_obj function more details.
279280
Any material files associated with the obj are expected to be in the
@@ -362,15 +363,16 @@ def _load(f_obj, data_dir):
362363

363364
# Load materials
364365
material_colors, texture_images = None, None
365-
if (len(material_names) > 0) and (f_mtl is not None):
366-
if os.path.isfile(f_mtl):
367-
material_colors, texture_images = load_mtl(
368-
f_mtl, material_names, data_dir
369-
)
370-
else:
371-
warnings.warn(f"Mtl file does not exist: {f_mtl}")
372-
elif len(material_names) > 0:
373-
warnings.warn("No mtl file provided")
366+
if load_textures:
367+
if (len(material_names) > 0) and (f_mtl is not None):
368+
if os.path.isfile(f_mtl):
369+
material_colors, texture_images = load_mtl(
370+
f_mtl, material_names, data_dir
371+
)
372+
else:
373+
warnings.warn(f"Mtl file does not exist: {f_mtl}")
374+
elif len(material_names) > 0:
375+
warnings.warn("No mtl file provided")
374376

375377
faces = _Faces(
376378
verts_idx=faces_verts_idx,

tests/test_obj_io.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,17 @@ def test_load_mtl(self):
390390
)
391391
)
392392

393+
def test_load_mtl_noload(self):
394+
DATA_DIR = (
395+
Path(__file__).resolve().parent.parent / "docs/tutorials/data"
396+
)
397+
obj_filename = "cow_mesh/cow.obj"
398+
filename = os.path.join(DATA_DIR, obj_filename)
399+
verts, faces, aux = load_obj(filename, load_textures=False)
400+
401+
self.assertTrue(aux.material_colors is None)
402+
self.assertTrue(aux.texture_images is None)
403+
393404
def test_load_mtl_fail(self):
394405
# Faces have a material
395406
obj_file = "\n".join(
@@ -444,6 +455,27 @@ def test_load_obj_missing_texture(self):
444455
self.assertTrue(torch.allclose(verts, expected_verts))
445456
self.assertTrue(torch.allclose(faces.verts_idx, expected_faces))
446457

458+
def test_load_obj_missing_texture_noload(self):
459+
DATA_DIR = Path(__file__).resolve().parent / "data"
460+
obj_filename = "missing_files_obj/model.obj"
461+
filename = os.path.join(DATA_DIR, obj_filename)
462+
verts, faces, aux = load_obj(filename, load_textures=False)
463+
464+
expected_verts = torch.tensor(
465+
[
466+
[0.1, 0.2, 0.3],
467+
[0.2, 0.3, 0.4],
468+
[0.3, 0.4, 0.5],
469+
[0.4, 0.5, 0.6],
470+
],
471+
dtype=torch.float32,
472+
)
473+
expected_faces = torch.tensor([[0, 1, 2], [0, 1, 3]], dtype=torch.int64)
474+
self.assertTrue(torch.allclose(verts, expected_verts))
475+
self.assertTrue(torch.allclose(faces.verts_idx, expected_faces))
476+
self.assertTrue(aux.material_colors is None)
477+
self.assertTrue(aux.texture_images is None)
478+
447479
def test_load_obj_missing_mtl(self):
448480
DATA_DIR = Path(__file__).resolve().parent / "data"
449481
obj_filename = "missing_files_obj/model2.obj"
@@ -464,6 +496,27 @@ def test_load_obj_missing_mtl(self):
464496
self.assertTrue(torch.allclose(verts, expected_verts))
465497
self.assertTrue(torch.allclose(faces.verts_idx, expected_faces))
466498

499+
def test_load_obj_missing_mtl_noload(self):
500+
DATA_DIR = Path(__file__).resolve().parent / "data"
501+
obj_filename = "missing_files_obj/model2.obj"
502+
filename = os.path.join(DATA_DIR, obj_filename)
503+
verts, faces, aux = load_obj(filename, load_textures=False)
504+
505+
expected_verts = torch.tensor(
506+
[
507+
[0.1, 0.2, 0.3],
508+
[0.2, 0.3, 0.4],
509+
[0.3, 0.4, 0.5],
510+
[0.4, 0.5, 0.6],
511+
],
512+
dtype=torch.float32,
513+
)
514+
expected_faces = torch.tensor([[0, 1, 2], [0, 1, 3]], dtype=torch.int64)
515+
self.assertTrue(torch.allclose(verts, expected_verts))
516+
self.assertTrue(torch.allclose(faces.verts_idx, expected_faces))
517+
self.assertTrue(aux.material_colors is None)
518+
self.assertTrue(aux.texture_images is None)
519+
467520
@staticmethod
468521
def save_obj_with_init(V: int, F: int):
469522
verts_list = torch.tensor(V * [[0.11, 0.22, 0.33]]).view(-1, 3)

0 commit comments

Comments
 (0)