Skip to content

Commit 358e211

Browse files
megluyagaofacebook-github-bot
authored andcommitted
Adding renderer for ShapeNetBase
Summary: Adding a renderer to ShapeNetCore (Note that the lights are currently turned off for the test; will investigate why lighting causes instability in rendering) Reviewed By: nikhilaravi Differential Revision: D22102673 fbshipit-source-id: a704756a1e93b61d5a879f0e5ee14ebcb0df49d7
1 parent 09c1762 commit 358e211

File tree

5 files changed

+194
-43
lines changed

5 files changed

+194
-43
lines changed

pytorch3d/datasets/shapenet/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
23
from .shapenet_core import ShapeNetCore
34

45

pytorch3d/datasets/shapenet/shapenet_core.py

Lines changed: 41 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,16 @@
55
import warnings
66
from os import path
77
from pathlib import Path
8+
from typing import Dict
89

9-
import torch
10+
from pytorch3d.datasets.shapenet_base import ShapeNetBase
1011
from pytorch3d.io import load_obj
1112

1213

1314
SYNSET_DICT_DIR = Path(__file__).resolve().parent
1415

1516

16-
class ShapeNetCore(torch.utils.data.Dataset):
17+
class ShapeNetCore(ShapeNetBase):
1718
"""
1819
This class loads ShapeNetCore from a given directory into a Dataset object.
1920
ShapeNetCore is a subset of the ShapeNet dataset and can be downloaded from
@@ -23,6 +24,7 @@ class ShapeNetCore(torch.utils.data.Dataset):
2324
def __init__(self, data_dir, synsets=None, version: int = 1):
2425
"""
2526
Store each object's synset id and models id from data_dir.
27+
2628
Args:
2729
data_dir: Path to ShapeNetCore data.
2830
synsets: List of synset categories to load from ShapeNetCore in the form of
@@ -38,6 +40,7 @@ def __init__(self, data_dir, synsets=None, version: int = 1):
3840
version 1.
3941
4042
"""
43+
super().__init__()
4144
self.data_dir = data_dir
4245
if version not in [1, 2]:
4346
raise ValueError("Version number must be either 1 or 2.")
@@ -48,7 +51,7 @@ def __init__(self, data_dir, synsets=None, version: int = 1):
4851
with open(path.join(SYNSET_DICT_DIR, dict_file), "r") as read_dict:
4952
self.synset_dict = json.load(read_dict)
5053
# Inverse dicitonary mapping synset labels to corresponding offsets.
51-
synset_inv = {label: offset for offset, label in self.synset_dict.items()}
54+
self.synset_inv = {label: offset for offset, label in self.synset_dict.items()}
5255

5356
# If categories are specified, check if each category is in the form of either
5457
# synset offset or synset label, and if the category exists in the given directory.
@@ -60,62 +63,61 @@ def __init__(self, data_dir, synsets=None, version: int = 1):
6063
path.isdir(path.join(data_dir, synset))
6164
):
6265
synset_set.add(synset)
63-
elif (synset in synset_inv.keys()) and (
64-
(path.isdir(path.join(data_dir, synset_inv[synset])))
66+
elif (synset in self.synset_inv.keys()) and (
67+
(path.isdir(path.join(data_dir, self.synset_inv[synset])))
6568
):
66-
synset_set.add(synset_inv[synset])
69+
synset_set.add(self.synset_inv[synset])
6770
else:
68-
msg = """Synset category %s either not part of ShapeNetCore dataset
69-
or cannot be found in %s.""" % (
70-
synset,
71-
data_dir,
72-
)
71+
msg = (
72+
"Synset category %s either not part of ShapeNetCore dataset "
73+
"or cannot be found in %s."
74+
) % (synset, data_dir)
7375
warnings.warn(msg)
7476
# If no category is given, load every category in the given directory.
77+
# Ignore synset folders not included in the official mapping.
7578
else:
7679
synset_set = {
7780
synset
7881
for synset in os.listdir(data_dir)
7982
if path.isdir(path.join(data_dir, synset))
83+
and synset in self.synset_dict
8084
}
81-
for synset in synset_set:
82-
if synset not in self.synset_dict.keys():
83-
msg = """Synset category %s(%s) is part of ShapeNetCore ver.%s
84-
but not found in %s.""" % (
85-
synset,
86-
self.synset_dict[synset],
87-
version,
88-
data_dir,
89-
)
90-
warnings.warn(msg)
85+
86+
# Check if there are any categories in the official mapping that are not loaded.
87+
# Update self.synset_inv so that it only includes the loaded categories.
88+
synset_not_present = set(self.synset_dict.keys()).difference(synset_set)
89+
[self.synset_inv.pop(self.synset_dict[synset]) for synset in synset_not_present]
90+
91+
if len(synset_not_present) > 0:
92+
msg = (
93+
"The following categories are included in ShapeNetCore ver.%d's "
94+
"official mapping but not found in the dataset location %s: %s"
95+
""
96+
) % (version, data_dir, ", ".join(synset_not_present))
97+
warnings.warn(msg)
9198

9299
# Extract model_id of each object from directory names.
93100
# Each grandchildren directory of data_dir contains an object, and the name
94101
# of the directory is the object's model_id.
95-
self.synset_ids = []
96-
self.model_ids = []
97102
for synset in synset_set:
98103
for model in os.listdir(path.join(data_dir, synset)):
99104
if not path.exists(path.join(data_dir, synset, model, self.model_dir)):
100-
msg = """ Object file not found in the model directory %s
101-
under synset directory %s.""" % (
102-
model,
103-
synset,
104-
)
105+
msg = (
106+
"Object file not found in the model directory %s "
107+
"under synset directory %s."
108+
) % (model, synset)
105109
warnings.warn(msg)
106-
else:
107-
self.synset_ids.append(synset)
108-
self.model_ids.append(model)
110+
continue
111+
self.synset_ids.append(synset)
112+
self.model_ids.append(model)
109113

110-
def __len__(self):
111-
"""
112-
Return number of total models in shapenet core.
113-
"""
114-
return len(self.model_ids)
115-
116-
def __getitem__(self, idx):
114+
def __getitem__(self, idx: int) -> Dict:
117115
"""
118116
Read a model by the given index.
117+
118+
Args:
119+
idx: The idx of the model to be retrieved in the dataset.
120+
119121
Returns:
120122
dictionary with following keys:
121123
- verts: FloatTensor of shape (V, 3).
@@ -124,9 +126,7 @@ def __getitem__(self, idx):
124126
- model_id (str): model id
125127
- label (str): synset label.
126128
"""
127-
model = {}
128-
model["synset_id"] = self.synset_ids[idx]
129-
model["model_id"] = self.model_ids[idx]
129+
model = self._get_item_ids(idx)
130130
model_path = path.join(
131131
self.data_dir, model["synset_id"], model["model_id"], self.model_dir
132132
)

pytorch3d/datasets/shapenet_base.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
3+
from typing import Dict
4+
5+
import torch
6+
from pytorch3d.renderer import (
7+
HardPhongShader,
8+
MeshRasterizer,
9+
MeshRenderer,
10+
OpenGLPerspectiveCameras,
11+
PointLights,
12+
RasterizationSettings,
13+
)
14+
from pytorch3d.structures import Meshes, Textures
15+
16+
17+
class ShapeNetBase(torch.utils.data.Dataset):
18+
"""
19+
'ShapeNetBase' implements a base Dataset for ShapeNet and R2N2 with helper methods.
20+
It is not intended to be used on its own as a Dataset for a Dataloader. Both __init__
21+
and __getitem__ need to be implemented.
22+
"""
23+
24+
def __init__(self):
25+
"""
26+
Set up lists of synset_ids and model_ids.
27+
"""
28+
self.synset_ids = []
29+
self.model_ids = []
30+
31+
def __len__(self):
32+
"""
33+
Return number of total models in the loaded dataset.
34+
"""
35+
return len(self.model_ids)
36+
37+
def __getitem__(self, idx) -> Dict:
38+
"""
39+
Read a model by the given index. Need to be implemented for every child class
40+
of ShapeNetBase.
41+
42+
Args:
43+
idx: The idx of the model to be retrieved in the dataset.
44+
45+
Returns:
46+
dictionary containing information about the model.
47+
"""
48+
raise NotImplementedError(
49+
"__getitem__ should be implemented in the child class of ShapeNetBase"
50+
)
51+
52+
def _get_item_ids(self, idx) -> Dict:
53+
"""
54+
Read a model by the given index.
55+
56+
Args:
57+
idx: The idx of the model to be retrieved in the dataset.
58+
59+
Returns:
60+
dictionary with following keys:
61+
- synset_id (str): synset id
62+
- model_id (str): model id
63+
"""
64+
model = {}
65+
model["synset_id"] = self.synset_ids[idx]
66+
model["model_id"] = self.model_ids[idx]
67+
return model
68+
69+
def render(
70+
self, idx: int = 0, shader_type=HardPhongShader, device="cpu", **kwargs
71+
) -> torch.Tensor:
72+
"""
73+
Renders a model by the given index.
74+
75+
Args:
76+
idx: The index of model to be rendered in the dataset.
77+
shader_type: select shading. Valid options include HardPhongShader (default),
78+
SoftPhongShader, HardGouraudShader, SoftGouraudShader, HardFlatShader,
79+
SoftSilhouetteShader.
80+
device: torch.device on which the tensors should be located.
81+
**kwargs: Accepts any of the kwargs that the renderer supports.
82+
83+
Returns:
84+
Rendered image of shape (1, H, W, 3).
85+
"""
86+
87+
model = self.__getitem__(idx)
88+
verts, faces = model["verts"], model["faces"]
89+
verts_rgb = torch.ones_like(verts, device=device)[None]
90+
mesh = Meshes(
91+
verts=[verts.to(device)],
92+
faces=[faces.to(device)],
93+
textures=Textures(verts_rgb=verts_rgb.to(device)),
94+
)
95+
cameras = kwargs.get("cameras", OpenGLPerspectiveCameras()).to(device)
96+
renderer = MeshRenderer(
97+
rasterizer=MeshRasterizer(
98+
cameras=cameras,
99+
raster_settings=kwargs.get("raster_settings", RasterizationSettings()),
100+
),
101+
shader=shader_type(
102+
device=device,
103+
cameras=cameras,
104+
lights=kwargs.get("lights", PointLights()).to(device),
105+
),
106+
)
107+
return renderer(mesh)
3.19 KB
Loading

tests/test_shapenet_core.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,32 @@
66
import random
77
import unittest
88
import warnings
9+
from pathlib import Path
910

11+
import numpy as np
1012
import torch
11-
from common_testing import TestCaseMixin
13+
from common_testing import TestCaseMixin, load_rgb_image
14+
from PIL import Image
1215
from pytorch3d.datasets import ShapeNetCore
16+
from pytorch3d.renderer import (
17+
OpenGLPerspectiveCameras,
18+
PointLights,
19+
RasterizationSettings,
20+
look_at_view_transform,
21+
)
1322

1423

1524
SHAPENET_PATH = None
25+
# If DEBUG=True, save out images generated in the tests for debugging.
26+
# All saved images have prefix DEBUG_
27+
DEBUG = False
28+
DATA_DIR = Path(__file__).resolve().parent / "data"
1629

1730

1831
class TestShapenetCore(TestCaseMixin, unittest.TestCase):
1932
def test_load_shapenet_core(self):
33+
# Setup
34+
device = torch.device("cuda:0")
2035

2136
# The ShapeNet dataset is not provided in the repo.
2237
# Download this separately and update the `shapenet_path`
@@ -31,7 +46,7 @@ def test_load_shapenet_core(self):
3146
warnings.warn(msg)
3247
return True
3348

34-
# Try load ShapeNetCore with an invalid version number and catch error.
49+
# Try loading ShapeNetCore with an invalid version number and catch error.
3550
with self.assertRaises(ValueError) as err:
3651
ShapeNetCore(SHAPENET_PATH, version=3)
3752
self.assertTrue("Version number must be either 1 or 2." in str(err.exception))
@@ -93,3 +108,31 @@ def test_load_shapenet_core(self):
93108
for offset in subset_offsets
94109
]
95110
self.assertEqual(len(shapenet_subset), sum(subset_model_nums))
111+
112+
# Render the first image in the piano category.
113+
R, T = look_at_view_transform(1.0, 1.0, 90)
114+
piano_dataset = ShapeNetCore(SHAPENET_PATH, synsets=["piano"])
115+
116+
cameras = OpenGLPerspectiveCameras(R=R, T=T, device=device)
117+
raster_settings = RasterizationSettings(image_size=512)
118+
lights = PointLights(
119+
location=torch.tensor([0.0, 1.0, -2.0], device=device)[None],
120+
# TODO: debug the source of the discrepancy in two images when rendering on GPU.
121+
diffuse_color=((0, 0, 0),),
122+
specular_color=((0, 0, 0),),
123+
device=device,
124+
)
125+
images = piano_dataset.render(
126+
0,
127+
device=device,
128+
cameras=cameras,
129+
raster_settings=raster_settings,
130+
lights=lights,
131+
)
132+
rgb = images[0, ..., :3].squeeze().cpu()
133+
if DEBUG:
134+
Image.fromarray((rgb.numpy() * 255).astype(np.uint8)).save(
135+
DATA_DIR / "DEBUG_shapenet_core_render_piano.png"
136+
)
137+
image_ref = load_rgb_image("test_shapenet_core_render_piano.png", DATA_DIR)
138+
self.assertClose(rgb, image_ref, atol=0.05)

0 commit comments

Comments
 (0)