-
Notifications
You must be signed in to change notification settings - Fork 46
/
Copy pathvis_utils.py
77 lines (65 loc) · 2.58 KB
/
vis_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# Copied from nerfstudio and 2DGS.
import torch
from matplotlib import cm
import open3d as o3d
import matplotlib.pyplot as plt
import numpy as np
def apply_colormap(image, cmap="viridis"):
    """Map values in ``[0, 1]`` to colors using a matplotlib colormap.

    Args:
        image: Tensor whose values, after scaling by 255 and truncation to
            integers, must lie in ``[0, 255]``. Only index 0 of the last
            dimension is used for the lookup (presumably a trailing
            single-channel axis -- confirm against callers).
        cmap: Name of a matplotlib colormap.

    Returns:
        Tensor on ``image.device`` holding one color-table row per element
        of ``image[..., 0]``.

    Raises:
        AssertionError: If a scaled value falls outside ``[0, 255]``.
    """
    # ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9, whereas
    # ``pyplot.get_cmap`` is available in both old and new releases.
    colormap = plt.get_cmap(cmap)

    table = getattr(colormap, "colors", None)
    if table is None:
        # Segmented colormaps (e.g. "jet") expose no ``colors`` table, so
        # sample 256 entries and drop the alpha channel. ``tolist()`` keeps
        # the tensor dtype inference the same as for a listed colormap.
        table = colormap(np.linspace(0.0, 1.0, 256))[:, :3].tolist()
    colormap = torch.tensor(table).to(image.device)  # type: ignore

    image_long = (image * 255).long()
    image_long_min = torch.min(image_long)
    image_long_max = torch.max(image_long)
    # Indices outside the table would otherwise wrap around or raise an
    # opaque indexing error, so fail early with the offending value.
    assert image_long_min >= 0, f"the min value is {image_long_min}"
    assert image_long_max <= 255, f"the max value is {image_long_max}"
    return colormap[image_long[..., 0]]
def apply_depth_colormap(
    depth,
    accumulation,
    near_plane=2.0,
    far_plane=6.0,
    cmap="turbo",
):
    """Colorize a depth tensor after normalizing it to ``[0, 1]``.

    Args:
        depth: Depth tensor; forwarded to ``apply_colormap`` after
            normalization, so it must satisfy that function's layout
            expectations.
        accumulation: Optional tensor used to blend the colored image; pass
            ``None`` to skip blending.
        near_plane: Depth mapped to 0. ``None`` selects ``depth.min()``.
        far_plane: Depth mapped to 1. ``None`` selects ``depth.max()``.
        cmap: Name of the matplotlib colormap to apply.

    Returns:
        The colored image tensor produced by ``apply_colormap``, blended
        with ``accumulation`` when it is given.
    """
    # Compare against None explicitly: with ``near_plane or ...`` an explicit
    # plane of 0.0 is falsy and was silently replaced by the depth minimum.
    if near_plane is None:
        near_plane = float(torch.min(depth))
    if far_plane is None:
        far_plane = float(torch.max(depth))

    # The epsilon avoids a division by zero when both planes coincide.
    depth = (depth - near_plane) / (far_plane - near_plane + 1e-10)
    depth = torch.clip(depth, 0, 1)
    # depth = torch.nan_to_num(depth, nan=0.0) # TODO(ethan): remove this

    colored_image = apply_colormap(depth, cmap=cmap)

    if accumulation is not None:
        # Where accumulation is 0 the result is 1 (white); where it is 1 the
        # colormap value is kept unchanged.
        colored_image = colored_image * accumulation + (1 - accumulation)

    return colored_image
def save_points(path_save, pts, colors=None, normals=None, BRG2RGB=False):
    """Write points to ``path_save`` as an Open3D point cloud.

    Args:
        path_save: Destination path handed to ``o3d.io.write_point_cloud``.
        pts: Non-empty array of points with three columns.
        colors: Optional array of per-point colors with three columns. When
            its maximum exceeds 1 it is divided by that maximum, since
            Open3D expects float colors in ``[0, 1]``.
        normals: Optional per-point normals stored on the cloud as given.
        BRG2RGB: When true, the color columns are reordered to (2, 1, 0)
            before being stored.

    Raises:
        AssertionError: If ``pts`` is empty, or ``pts`` / ``colors`` do not
            have exactly three columns.
    """
    has_colors = colors is not None

    # Validate inputs before building any Open3D object.
    assert len(pts) > 0
    if has_colors:
        assert colors.shape[1] == 3
    assert pts.shape[1] == 3

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(pts)

    if has_colors:
        # Open3D assumes the color values are of float type and in range [0, 1]
        peak = np.max(colors)
        if peak > 1:
            colors = colors / peak
        if BRG2RGB:
            # Swap the first and last channels.
            colors = np.stack(
                (colors[:, 2], colors[:, 1], colors[:, 0]), axis=-1
            )
        pcd.colors = o3d.utility.Vector3dVector(colors)

    if normals is not None:
        pcd.normals = o3d.utility.Vector3dVector(normals)

    o3d.io.write_point_cloud(path_save, pcd)
def colormap(img, cmap='jet'):
    """Render ``img`` with matplotlib, colorbar included, as an RGB tensor.

    Args:
        img: Image-like array accepted by ``Axes.imshow``; its first two
            dimensions are taken as (rows, cols).
        cmap: Name of the matplotlib colormap passed to ``imshow``.

    Returns:
        Float tensor of shape ``(3, rows, cols)`` with values in ``[0, 1]``
        holding the rendered figure (image plus colorbar).
    """
    # ``shape[:2]`` is (rows, cols), i.e. (height, width).
    rows, cols = img.shape[:2]
    dpi = 300
    # figsize is (width, height) in inches; at this dpi the canvas is
    # roughly cols x rows pixels.
    fig, ax = plt.subplots(1, figsize=(cols / dpi, rows / dpi), dpi=dpi)
    try:
        im = ax.imshow(img, cmap=cmap)
        ax.set_axis_off()
        fig.colorbar(im, ax=ax)
        fig.tight_layout()
        fig.canvas.draw()
        # ``tostring_rgb`` was deprecated in Matplotlib 3.8 and later
        # removed; ``buffer_rgba`` is the supported accessor. Dividing
        # copies the pixels, so nothing refers to the canvas once the
        # figure is closed.
        rgb = np.asarray(fig.canvas.buffer_rgba())[..., :3] / 255.
    finally:
        # Close this specific figure (not merely the current one), even if
        # rendering failed, so figures do not accumulate.
        plt.close(fig)

    out = torch.from_numpy(rgb).float().permute(2, 0, 1)
    # The canvas size can differ from the requested one after rounding, so
    # resample to the input resolution when needed.
    if tuple(out.shape[1:]) != (rows, cols):
        out = torch.nn.functional.interpolate(
            out[None], (rows, cols), mode='bilinear', align_corners=False
        )[0]
    return out