Skip to content

Commit f8d14d2

Browse files
committed
renderer surface norm init
add normal rerun viz
1 parent 77bd39e commit f8d14d2

File tree

6 files changed

+274
-2
lines changed

6 files changed

+274
-2
lines changed

b3d/model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -200,4 +200,4 @@ def rerun_visualize_trace_t(trace, t, modes=["rgb", "depth", "inliers"]):
200200
pose.apply(vertices),
201201
colors=(attributes * 255).astype(jnp.uint8),
202202
),
203-
)
203+
)

b3d/pose.py

+47
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,53 @@ def camera_from_position_and_target(
9090
rotation_matrix = jnp.hstack([x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)])
9191
return Pose(position, Rot.from_matrix(rotation_matrix).as_quat())
9292

93+
def rotation_from_axis_angle(axis, angle):
    """Build the rotation matrix for a turn of `angle` about `axis`.

    Evaluates Rodrigues' formula
    ``R = cos(a) * I + (1 - cos(a)) * u u^T + sin(a) * [u]_x``
    where ``u`` is `axis` rescaled to unit length and ``[u]_x`` is its
    skew-symmetric cross-product matrix.

    Args:
        axis (jnp.ndarray): The axis vector. Shape (3,). It is normalised
            here, so it does not need to be unit length (a zero vector
            yields NaNs).
        angle (float): The angle in radians.
    Returns:
        jnp.ndarray: The rotation matrix. Shape (3, 3)
    """
    unit_axis = axis / jnp.linalg.norm(axis)
    cos_angle = jnp.cos(angle)
    sin_angle = jnp.sin(angle)

    # Isotropic part: cos(a) on the diagonal.
    isotropic = cos_angle * jnp.eye(3)
    # Component along the axis, which the rotation leaves unchanged.
    along_axis = (1.0 - cos_angle) * jnp.outer(unit_axis, unit_axis)
    # Cross-product (skew-symmetric) part, already scaled by sin(a).
    scaled_axis = unit_axis * sin_angle
    skew = jnp.array(
        [
            [0.0, -scaled_axis[2], scaled_axis[1]],
            [scaled_axis[2], 0.0, -scaled_axis[0]],
            [-scaled_axis[1], scaled_axis[0], 0.0],
        ]
    )
    return (isotropic + along_axis) + skew
117+
118+
def from_rot(rotation):
    """Wrap a pure rotation in a Pose with zero translation.

    The rotation is embedded as the upper-left 3x3 block of a 4x4
    homogeneous transform whose translation column is zero and whose last
    row is ``[0, 0, 0, 1]``; that matrix is handed to ``Pose.from_matrix``.

    Args:
        rotation (jnp.ndarray): The rotation matrix. Shape (3, 3)
    Returns:
        Pose object
    """
    zero_translation = jnp.zeros((3, 1))
    upper_rows = jnp.concatenate([rotation, zero_translation], axis=1)
    last_row = jnp.array([[0.0, 0.0, 0.0, 1.0]])
    homogeneous = jnp.concatenate([upper_rows, last_row], axis=0)
    return Pose.from_matrix(homogeneous)
129+
130+
def from_axis_angle(axis, angle):
    """Create a translation-free Pose rotating by `angle` about `axis`.

    Args:
        axis (jnp.ndarray): The axis vector. Shape (3,)
        angle (float): The angle in radians.
    Returns:
        Pose object
    """
    rotation = rotation_from_axis_angle(axis, angle)
    return from_rot(rotation)
93140

94141
@register_pytree_node_class
95142
class Pose:

b3d/renderer.py

+89
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,95 @@ def render_attribute(self, pose, vertices, faces, ranges, attributes):
276276
return image[0], zs[0]
277277

278278

279+
def render_attribute_normal_many(self, poses, vertices, faces, ranges, attributes):
    """
    Render many scenes to an image by rasterizing and then interpolating attributes,
    and additionally return a per-pixel flat (per-triangle) surface normal image.

    Parameters:
        poses: batch of object poses, leading dims (num_scenes, num_objects)
            NOTE(review): the normal computation calls ``pose.apply`` on each
            scene's entry, so this must be an object exposing ``apply`` (e.g. a
            Pose), not a raw (num_scenes, num_objects, 4, 4) matrix array.
            NOTE(review): every triangle is transformed by the scene's pose
            regardless of ``ranges``, so the normals presumably only make sense
            with a single object per scene - confirm before multi-object use.
        vertices: float array, shape (num_vertices, 3)
            Vertex position matrix.
        faces: int array, shape (num_triangles, 3)
            Faces Triangle matrix. The integers correspond to rows in the vertices matrix.
        ranges: int array, shape (num_objects, 2)
            Ranges matrix with the 2 elements specify start indices and counts into faces.
        attributes: float array, shape (num_vertices, num_attributes)
            Attributes corresponding to the vertices

    Outputs:
        image: float array, shape (num_scenes, height, width, num_attributes)
            At each pixel the value is the barycentric interpolation of the attributes corresponding to the
            3 vertices of the triangle with which the pixel's ray intersected. If the pixel's ray does not intersect
            any triangle the value at that pixel will be 0s.
        zs: float array, shape (num_scenes, height, width)
            Depth of the intersection point. Zero if the pixel ray doesn't collide a triangle.
        norm_im: float array, shape (num_scenes, height, width, 3)
            Approximate surface normal image: the unit normal of the posed
            triangle whose id was rasterized at each pixel. Pixels with
            triangle id 0 and degenerate (zero-area) triangles get a zero vector.
    """
    uvs, object_ids, triangle_ids, zs = self.rasterize_many(
        poses, vertices, faces, ranges
    )
    hit_mask = object_ids > 0

    interpolated_values = self.interpolate_many(
        attributes, uvs, triangle_ids, faces
    )
    image = interpolated_values * hit_mask[..., None]

    # Corner positions of every triangle after applying each scene's pose.
    posed_triangles = jax.vmap(
        lambda pose, points: pose.apply(points), in_axes=(0, None)
    )(poses, vertices[faces])

    corner_a = posed_triangles[..., 0, :]
    corner_b = posed_triangles[..., 1, :]
    corner_c = posed_triangles[..., 2, :]

    # Face normal = normalised cross product of two triangle edges.
    raw_normals = jnp.cross(corner_b - corner_a, corner_c - corner_a)
    lengths = jnp.linalg.norm(raw_normals, axis=-1, keepdims=True)
    # Degenerate triangles have a zero-length cross product; dividing by it
    # would put NaNs in the lookup table, so they get a zero normal instead.
    has_area = lengths > 0
    unit_normals = jnp.where(
        has_area, raw_normals / jnp.where(has_area, lengths, 1.0), 0.0
    )

    # Prepend a zero row so that triangle id 0 maps to a zero normal and ids
    # k >= 1 map to the normal of faces[k - 1].
    normal_table = jnp.concatenate(
        (jnp.zeros((len(unit_normals), 1, 3)), unit_normals), axis=1
    )

    # Per scene, look up the normal of the triangle id stored at each pixel.
    norm_im = jax.vmap(lambda table, ids: table[ids])(normal_table, triangle_ids)

    return image, zs, norm_im
336+
337+
def render_attribute_normal(self, pose, vertices, faces, ranges, attributes):
    """
    Render a single scene to an image by rasterizing and then interpolating attributes,
    also returning a surface normal image.

    Thin wrapper: adds a leading scene axis to `pose`, delegates to
    `render_attribute_normal_many`, and strips that axis from each output.

    Parameters:
        pose: object pose(s) for one scene, leading dim (num_objects,)
            Indexed with ``pose[None, ...]`` to form a batch of one scene.
        vertices: float array, shape (num_vertices, 3)
            Vertex position matrix.
        faces: int array, shape (num_triangles, 3)
            Faces Triangle matrix. The integers correspond to rows in the vertices matrix.
        ranges: int array, shape (num_objects, 2)
            Ranges matrix with the 2 elements specify start indices and counts into faces.
        attributes: float array, shape (num_vertices, num_attributes)
            Attributes corresponding to the vertices

    Outputs:
        image: float array, shape (height, width, num_attributes)
            At each pixel the value is the barycentric interpolation of the attributes corresponding to the
            3 vertices of the triangle with which the pixel's ray intersected. If the pixel's ray does not intersect
            any triangle the value at that pixel will be 0s.
        zs: float array, shape (height, width)
            Depth of the intersection point. Zero if the pixel ray doesn't collide a triangle.
        norm_im: approximate surface normal image (height, width, 3)
    """
    single_scene_batch = pose[None, ...]
    images, depths, normal_images = self.render_attribute_normal_many(
        single_scene_batch, vertices, faces, ranges, attributes
    )
    return images[0], depths[0], normal_images[0]
366+
367+
279368
# XLA array layout in memory
280369
def default_layouts(*shapes):
    """Return, for each shape, its axis indices counted down from the last axis to 0.

    A shape of rank ``n`` yields ``range(n - 1, -1, -1)``.
    """
    return [range(len(shape))[::-1] for shape in shapes]

b3d/utils.py

+16
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,22 @@ def update_choices_get_score(trace, key, addr_const, *values):
395395
enumerate_choices_get_scores, static_argnums=(2,)
396396
)
397397

398+
def unproject_depth(depth, renderer):
    """Unprojects a depth image into a point cloud.

    Each pixel (row, col) is back-projected through a pinhole camera:
    ``x = (col - cx) / fx * z``, ``y = (row - cy) / fy * z``, with ``z`` the
    pixel's depth. Depths that are not strictly between ``renderer.near`` and
    ``renderer.far`` (including inf and NaN) are replaced by ``renderer.far``.

    Args:
        depth (jnp.ndarray): The depth image. Shape (H, W)
        renderer: Object exposing the camera parameters as attributes
            ``fx``, ``fy``, ``cx``, ``cy``, ``near`` and ``far``.
    Returns:
        jnp.ndarray: The point cloud. Shape (H, W, 3)
    """
    in_range = (depth > renderer.near) & (depth < renderer.far)
    # Select rather than multiply by the mask: inf * 0 and nan * 0 are NaN, so
    # a multiplicative mask would leak NaNs for invalid pixels.
    depth = jnp.where(in_range, depth, renderer.far)

    rows, cols = jnp.mgrid[: depth.shape[0], : depth.shape[1]]
    x = (cols - renderer.cx) / renderer.fx
    y = (rows - renderer.cy) / renderer.fy
    # Ray through each pixel at unit depth, then scaled by the pixel's depth.
    unit_depth_rays = jnp.stack([x, y, jnp.ones_like(x)], axis=-1)
    return unit_depth_rays * depth[:, :, None]
398414

399415
def nn_background_segmentation(images):
400416
import torch

test/test_likelihood_invariances.py

+86
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,89 @@ def test_distance_to_camera_invarance(renderer):
169169

170170
assert jnp.isclose(near_score, far_score, rtol=0.03)
171171

172+
def test_patch_orientation_invariance(renderer):
    """Check that the RGB-D likelihood is insensitive to patch orientation.

    A thin box (a nearly planar patch) is rendered once with the base camera
    pose and once with an extra rotation about the z axis composed onto that
    pose. Each rendering is scored against itself with
    ``b3d.rgbd_sensor_model``; the two log-densities must agree to within 5%.
    """
    from b3d.pose import from_axis_angle

    # Thin box acting as a planar patch, uniformly light grey.
    occluder = trimesh.creation.box(extents=jnp.array([0.0001, 0.1, 0.1]))
    occluder_colors = jnp.tile(
        jnp.array([0.8, 0.8, 0.8])[None, ...], (occluder.vertices.shape[0], 1)
    )
    object_library = b3d.MeshLibrary.make_empty_library()
    object_library.add_object(
        occluder.vertices, occluder.faces, attributes=occluder_colors
    )

    image_width = 200
    image_height = 200
    fx = 200.0
    fy = 200.0
    cx = 100.0
    cy = 100.0
    near = 0.001
    far = 16.0
    renderer.set_intrinsics(image_width, image_height, fx, fy, cx, cy, near, far)

    flat_pose = b3d.Pose.from_position_and_target(
        jnp.array([0.3, 0.0, 0.0]), jnp.array([0.0, 0.0, 0.0]), jnp.array([0.0, 0.0, 1.0])
    ).inv()

    # Ten rotations about z spanning [0, pi/4]; one from the middle of the
    # sweep is composed onto the base pose to tilt the patch.
    transform_vec = jax.vmap(from_axis_angle, (None, 0))
    in_place_rots = transform_vec(jnp.array([0, 0, 1]), jnp.linspace(0, jnp.pi / 4, 10))
    tilt_pose = flat_pose @ in_place_rots[5]

    rgb_flat, depth_flat = renderer.render_attribute(
        flat_pose[None, ...],
        object_library.vertices,
        object_library.faces,
        object_library.ranges,
        object_library.attributes,
    )

    rgb_tilt, depth_tilt = renderer.render_attribute(
        tilt_pose[None, ...],
        object_library.vertices,
        object_library.faces,
        object_library.ranges,
        object_library.attributes,
    )

    color_error, depth_error = (50.0, 0.01)
    inlier_score, outlier_prob = (4.0, 0.000001)
    color_multiplier, depth_multiplier = (100.0, 1.0)
    model_args = b3d.ModelArgs(
        color_error,
        depth_error,
        inlier_score,
        outlier_prob,
        color_multiplier,
        depth_multiplier,
    )

    # Label the images by what they show; the previous "img_near"/"img_far"
    # paths were carried over from the distance-invariance test.
    rr.log("img_flat", rr.Image(rgb_flat))
    rr.log("img_tilt", rr.Image(rgb_tilt))

    # Diagnostic only: summed per-pixel (depth / fx) * (depth / fy) of each rendering.
    area_flat = ((depth_flat / fx) * (depth_flat / fy)).sum()
    area_tilt = ((depth_tilt / fx) * (depth_tilt / fy)).sum()
    print(area_flat, area_tilt)

    # Score each rendering against itself so only the orientation differs.
    flat_score = b3d.rgbd_sensor_model.logpdf(
        (rgb_flat, depth_flat), rgb_flat, depth_flat, model_args, fx, fy, 0.0
    )
    tilt_score = b3d.rgbd_sensor_model.logpdf(
        (rgb_tilt, depth_tilt), rgb_tilt, depth_tilt, model_args, fx, fy, 0.0
    )
    print(flat_score, tilt_score)
    print(b3d.normalize_log_scores(jnp.array([flat_score, tilt_score])))

    assert jnp.isclose(flat_score, tilt_score, rtol=0.05)
257+

test/test_render_ycb_model.py

+35-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22
import jax.numpy as jnp
33
import trimesh
44
import b3d
5+
import rerun as rr
56

7+
# Port on localhost that the rerun connection below targets.
PORT = 8812
8+
# NOTE: these run at module import time, i.e. whenever this test module is loaded.
rr.init("real")
9+
rr.connect(addr=f"127.0.0.1:{PORT}")
610

711
def test_renderer_full(renderer):
812
mesh_path = os.path.join(
@@ -15,7 +19,7 @@ def test_renderer_full(renderer):
1519
object_library.add_trimesh(mesh)
1620

1721
pose = b3d.Pose.from_position_and_target(
18-
jnp.array([0.2, 0.2, 0.0]), jnp.array([0.0, 0.0, 0.0])
22+
jnp.array([0.2, 0.2, 0.2]), jnp.array([0.0, 0.0, 0.0])
1923
).inv()
2024

2125
rgb, depth = renderer.render_attribute(
@@ -27,3 +31,33 @@ def test_renderer_full(renderer):
2731
)
2832
b3d.get_rgb_pil_image(rgb).save(b3d.get_root_path() / "assets/test_results/test_ycb.png")
2933
assert rgb.sum() > 0
34+
35+
def test_renderer_normal_full(renderer):
    """Render the YCB cracker box with surface normals and sanity-check them.

    Saves the normal image (remapped with ``(n + 1) / 2``) to the test
    results folder, logs the unprojected point cloud and a subsampled set of
    normal arrows to rerun, and asserts the normal image is not all zeros.
    """
    model_file = os.path.join(
        b3d.get_root_path(),
        "assets/shared_data_bucket/ycb_video_models/models/003_cracker_box/textured_simple.obj",
    )

    library = b3d.MeshLibrary.make_empty_library()
    library.add_trimesh(trimesh.load(model_file))

    camera_position = jnp.array([0.2, 0.2, 0.2])
    look_at = jnp.array([0.0, 0.0, 0.0])
    pose = b3d.Pose.from_position_and_target(camera_position, look_at).inv()

    # A single range covering every face of the loaded mesh.
    whole_mesh_range = jnp.array([[0, len(library.faces)]])
    rgb, depth, normal = renderer.render_attribute_normal(
        pose[None, ...],
        library.vertices,
        library.faces,
        whole_mesh_range,
        library.attributes,
    )

    # Remap normal components with (n + 1) / 2 before saving as an image.
    normal_as_rgb = (normal + 1) / 2
    output_path = b3d.get_root_path() / "assets/test_results/test_ycb_normal.png"
    b3d.get_rgb_pil_image(normal_as_rgb).save(output_path)

    points = b3d.utils.unproject_depth(depth, renderer)
    rr.log("pc", rr.Points3D(points.reshape(-1, 3), colors=rgb.reshape(-1, 3)))

    # Draw an arrow only at every 5th pixel in each direction, scaled by 1/100.
    arrow_origins = points[::5, ::5, :].reshape(-1, 3)
    arrow_vectors = normal[::5, ::5, :].reshape(-1, 3) / 100
    rr.log("arrows", rr.Arrows3D(origins=arrow_origins, vectors=arrow_vectors))

    assert jnp.abs(normal).sum() > 0

0 commit comments

Comments
 (0)