Skip to content

Commit f4a001c

Browse files
Gm/gen3d/eval script improvements (#185)
1 parent 90602d6 commit f4a001c

File tree

4 files changed

+84
-26
lines changed

4 files changed

+84
-26
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,4 @@ assets/kitti/*
2424
__pycache__/
2525
*.py[cod]
2626
docs/*
27+
test_results/

scripts/get_ycbv_metrics.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -84,21 +84,27 @@ def get_fp_pred_pose(test_scene_id: int, obj_id: int):
8484

8585
def get_b3d_pred_pose(result_dir: Path, test_scene_id: int, obj_id: int):
8686
poses = np.load(
87-
result_dir / f"SCENE_{test_scene_id}_OBJECT_INDEX_{obj_id}_POSES.npy",
87+
result_dir / f"SCENE_{test_scene_id}_OBJECT_INDEX_{obj_id}_POSES.npy.npz",
8888
)
8989
poses = b3d.Pose(poses["position"], poses["quaternion"])
9090
return poses.as_matrix()
9191

9292

93-
def main(b3d_result_dir: str, output_dir: str | None = None):
93+
def main(b3d_result_dir: str, output_dir: str | None = None, get_fp_pose: bool = False):
94+
"""
95+
Call this with `b3d_result_dir` as the directory containing `.npy.npz` files.
96+
"""
9497
result_dir = Path(b3d_result_dir)
9598
if output_dir is None:
9699
output_dir = b3d_result_dir
97100
output_dir = Path(output_dir)
98101
output_dir.mkdir(parents=True, exist_ok=True)
99102

100103
pred_score_getter = partial(get_b3d_pred_pose, result_dir)
101-
# pred_score_getter = get_fp_pred_pose
104+
105+
if get_fp_pose:
106+
pred_score_getter = get_fp_pred_pose
107+
102108
results_summary, _ = collect_all_scores(pred_score_getter)
103109
print(results_summary)
104110
if output_dir is not None:

scripts/run_ycbv_evaluation.py

+64-19
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,64 @@
11
#!/usr/bin/env python
22

33
import copy
4+
import os
5+
import pprint
6+
from datetime import datetime
7+
from pathlib import Path
48

9+
import b3d
510
import b3d.chisight.gen3d.inference as inference
611
import b3d.chisight.gen3d.settings as settings
12+
import b3d.chisight.gen3d.visualization as viz
713
import fire
814
import jax
915
import jax.numpy as jnp
16+
import rerun as rr
1017
from b3d import Pose
1118
from b3d.chisight.gen3d.dataloading import (
1219
get_initial_state,
1320
load_object_given_scene,
1421
load_scene,
1522
)
23+
from b3d.chisight.gen3d.model import viz_trace as rr_viz_trace
1624
from tqdm import tqdm
1725

1826

19-
def run_tracking(scene=None, object=None, debug=False):
20-
import b3d
27+
def setup_save_directory():
28+
# Make a folder, stamped with the current time.
29+
current_time = datetime.now().strftime("%Y-%m-%d--%H:%M")
30+
folder_name = (
31+
b3d.get_root_path() / "test_results" / "gen3d" / f"gen3d_{current_time}"
32+
)
33+
Path(folder_name).mkdir(parents=True, exist_ok=True)
34+
video_folder_name = folder_name / "mp4"
35+
npy_folder_name = folder_name / "npy"
36+
rr_folder_name = folder_name / "rr"
37+
os.mkdir(rr_folder_name)
38+
os.mkdir(video_folder_name)
39+
os.mkdir(npy_folder_name)
40+
return folder_name, video_folder_name, npy_folder_name, rr_folder_name
41+
42+
43+
def save_hyperparams(folder_name, hyperparams, inference_hyperparams):
44+
hyperparams_file = folder_name / "hyperparams.txt"
45+
with open(hyperparams_file, "w") as f:
46+
f.write("Hyperparameters:\n")
47+
f.write(pprint.pformat(hyperparams))
48+
f.write("\n\n\nInference Hyperparameters:\n")
49+
f.write(pprint.pformat(inference_hyperparams))
50+
51+
52+
def run_tracking(scene=None, object=None, save_rerun=False, max_n_frames=None):
53+
folder_name, video_folder_name, npy_folder_name, rr_folder_name = (
54+
setup_save_directory()
55+
)
2156

22-
FRAME_RATE = 50
57+
hyperparams = copy.deepcopy(settings.hyperparams)
58+
inference_hyperparams = b3d.chisight.gen3d.settings.inference_hyperparams # noqa
59+
save_hyperparams(folder_name, hyperparams, inference_hyperparams)
2360

24-
b3d.utils.rr_init("run_ycbv_evaluation")
61+
FRAME_RATE = 50
2562

2663
if scene is None:
2764
scenes = range(48, 60)
@@ -30,8 +67,6 @@ def run_tracking(scene=None, object=None, debug=False):
3067
elif isinstance(scene, list):
3168
scenes = scene
3269

33-
hyperparams = copy.deepcopy(settings.hyperparams)
34-
3570
for scene_id in scenes:
3671
all_data, meshes, renderer, intrinsics, initial_object_poses = load_scene(
3772
scene_id, FRAME_RATE
@@ -53,15 +88,25 @@ def run_tracking(scene=None, object=None, debug=False):
5388
)
5489

5590
tracking_results = {}
56-
inference_hyperparams = b3d.chisight.gen3d.settings.inference_hyperparams # noqa
5791

5892
### Run inference ###
5993
key = jax.random.PRNGKey(156)
6094
trace = inference.get_initial_trace(
6195
key, hyperparams, initial_state, all_data[0]["rgbd"]
6296
)
6397

64-
for T in tqdm(range(len(all_data))):
98+
if save_rerun:
99+
rr.init(f"SCENE_{scene_id}_OBJECT_INDEX_{OBJECT_INDEX}")
100+
rr.save(
101+
rr_folder_name / f"SCENE_{scene_id}_OBJECT_INDEX_{OBJECT_INDEX}.rrd"
102+
)
103+
104+
if max_n_frames is not None:
105+
maxT = min(max_n_frames, len(all_data))
106+
else:
107+
maxT = len(all_data)
108+
109+
for T in tqdm(range(maxT)):
65110
key = b3d.split_key(key)
66111
trace = inference.inference_step_c2f(
67112
key,
@@ -76,8 +121,8 @@ def run_tracking(scene=None, object=None, debug=False):
76121
)
77122
tracking_results[T] = trace
78123

79-
if debug:
80-
b3d.chisight.gen3d.model.viz_trace(
124+
if save_rerun:
125+
rr_viz_trace(
81126
trace,
82127
T,
83128
ground_truth_vertices=meshes[OBJECT_INDEX].vertices,
@@ -86,25 +131,25 @@ def run_tracking(scene=None, object=None, debug=False):
86131
)
87132

88133
inferred_poses = Pose.stack_poses(
89-
[
90-
tracking_results[t].get_choices()["pose"]
91-
for t in range(len(all_data))
92-
]
134+
[tracking_results[t].get_choices()["pose"] for t in range(maxT)]
93135
)
94136
jnp.savez(
95-
f"SCENE_{scene_id}_OBJECT_INDEX_{OBJECT_INDEX}_POSES.npy",
137+
npy_folder_name
138+
/ f"SCENE_{scene_id}_OBJECT_INDEX_{OBJECT_INDEX}_POSES.npy",
96139
position=inferred_poses.position,
97140
quaternion=inferred_poses.quat,
98141
)
99142

100-
import b3d.chisight.gen3d.visualization as viz
101-
102143
viz.make_video_from_traces(
103-
[tracking_results[t] for t in range(len(all_data))],
104-
f"SCENE_{scene_id}_OBJECT_INDEX_{OBJECT_INDEX}.mp4",
144+
[tracking_results[t] for t in range(maxT)],
145+
video_folder_name / f"SCENE_{scene_id}_OBJECT_INDEX_{OBJECT_INDEX}.mp4",
105146
scale=0.25,
106147
)
107148

149+
if save_rerun:
150+
rr.disconnect()
151+
print("rerun disconnected")
152+
108153

109154
if __name__ == "__main__":
110155
fire.Fire(run_tracking)

src/b3d/chisight/gen3d/model.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,13 @@ def get_observed_rgbd(trace):
9999
### Visualization Code ###
100100

101101

102-
def viz_trace(trace, t=0, ground_truth_vertices=None, ground_truth_pose=None):
102+
def viz_trace(
103+
trace,
104+
t=0,
105+
ground_truth_vertices=None,
106+
ground_truth_pose=None,
107+
log_blueprint=True,
108+
):
103109
b3d.rr_set_time(t)
104110
hyperparams, _ = trace.get_args()
105111
new_state = trace.get_retval()["new_state"]
@@ -206,9 +212,9 @@ def viz_trace(trace, t=0, ground_truth_vertices=None, ground_truth_pose=None):
206212
b3d.rr_log_pose(ground_truth_pose, "scene/ground_truth_pose")
207213
b3d.rr_log_pose(trace.get_choices()["pose"], "scene/inferred_pose")
208214

209-
# if not b3d.get_blueprint_logged():
210-
# rr.send_blueprint(get_blueprint())
211-
# b3d.set_blueprint_logged(True)
215+
if not b3d.get_blueprint_logged() and log_blueprint:
216+
rr.send_blueprint(get_blueprint())
217+
b3d.set_blueprint_logged(True)
212218

213219

214220
def get_blueprint():

0 commit comments

Comments
 (0)