1
1
#!/usr/bin/env python
2
2
3
3
import copy
4
+ import os
5
+ import pprint
6
+ from datetime import datetime
7
+ from pathlib import Path
4
8
9
+ import b3d
5
10
import b3d .chisight .gen3d .inference as inference
6
11
import b3d .chisight .gen3d .settings as settings
12
+ import b3d .chisight .gen3d .visualization as viz
7
13
import fire
8
14
import jax
9
15
import jax .numpy as jnp
16
+ import rerun as rr
10
17
from b3d import Pose
11
18
from b3d .chisight .gen3d .dataloading import (
12
19
get_initial_state ,
13
20
load_object_given_scene ,
14
21
load_scene ,
15
22
)
23
+ from b3d .chisight .gen3d .model import viz_trace as rr_viz_trace
16
24
from tqdm import tqdm
17
25
18
26
19
- def run_tracking (scene = None , object = None , debug = False ):
20
- import b3d
27
+ def setup_save_directory ():
28
+ # Make a folder, stamped with the current time.
29
+ current_time = datetime .now ().strftime ("%Y-%m-%d--%H:%M" )
30
+ folder_name = (
31
+ b3d .get_root_path () / "test_results" / "gen3d" / f"gen3d_{ current_time } "
32
+ )
33
+ Path (folder_name ).mkdir (parents = True , exist_ok = True )
34
+ video_folder_name = folder_name / "mp4"
35
+ npy_folder_name = folder_name / "npy"
36
+ rr_folder_name = folder_name / "rr"
37
+ os .mkdir (rr_folder_name )
38
+ os .mkdir (video_folder_name )
39
+ os .mkdir (npy_folder_name )
40
+ return folder_name , video_folder_name , npy_folder_name , rr_folder_name
41
+
42
+
43
+ def save_hyperparams (folder_name , hyperparams , inference_hyperparams ):
44
+ hyperparams_file = folder_name / "hyperparams.txt"
45
+ with open (hyperparams_file , "w" ) as f :
46
+ f .write ("Hyperparameters:\n " )
47
+ f .write (pprint .pformat (hyperparams ))
48
+ f .write ("\n \n \n Inference Hyperparameters:\n " )
49
+ f .write (pprint .pformat (inference_hyperparams ))
50
+
51
+
52
+ def run_tracking (scene = None , object = None , save_rerun = False , max_n_frames = None ):
53
+ folder_name , video_folder_name , npy_folder_name , rr_folder_name = (
54
+ setup_save_directory ()
55
+ )
21
56
22
- FRAME_RATE = 50
57
+ hyperparams = copy .deepcopy (settings .hyperparams )
58
+ inference_hyperparams = b3d .chisight .gen3d .settings .inference_hyperparams # noqa
59
+ save_hyperparams (folder_name , hyperparams , inference_hyperparams )
23
60
24
- b3d . utils . rr_init ( "run_ycbv_evaluation" )
61
+ FRAME_RATE = 50
25
62
26
63
if scene is None :
27
64
scenes = range (48 , 60 )
@@ -30,8 +67,6 @@ def run_tracking(scene=None, object=None, debug=False):
30
67
elif isinstance (scene , list ):
31
68
scenes = scene
32
69
33
- hyperparams = copy .deepcopy (settings .hyperparams )
34
-
35
70
for scene_id in scenes :
36
71
all_data , meshes , renderer , intrinsics , initial_object_poses = load_scene (
37
72
scene_id , FRAME_RATE
@@ -53,15 +88,25 @@ def run_tracking(scene=None, object=None, debug=False):
53
88
)
54
89
55
90
tracking_results = {}
56
- inference_hyperparams = b3d .chisight .gen3d .settings .inference_hyperparams # noqa
57
91
58
92
### Run inference ###
59
93
key = jax .random .PRNGKey (156 )
60
94
trace = inference .get_initial_trace (
61
95
key , hyperparams , initial_state , all_data [0 ]["rgbd" ]
62
96
)
63
97
64
- for T in tqdm (range (len (all_data ))):
98
+ if save_rerun :
99
+ rr .init (f"SCENE_{ scene_id } _OBJECT_INDEX_{ OBJECT_INDEX } " )
100
+ rr .save (
101
+ rr_folder_name / f"SCENE_{ scene_id } _OBJECT_INDEX_{ OBJECT_INDEX } .rrd"
102
+ )
103
+
104
+ if max_n_frames is not None :
105
+ maxT = min (max_n_frames , len (all_data ))
106
+ else :
107
+ maxT = len (all_data )
108
+
109
+ for T in tqdm (range (maxT )):
65
110
key = b3d .split_key (key )
66
111
trace = inference .inference_step_c2f (
67
112
key ,
@@ -76,8 +121,8 @@ def run_tracking(scene=None, object=None, debug=False):
76
121
)
77
122
tracking_results [T ] = trace
78
123
79
- if debug :
80
- b3d . chisight . gen3d . model . viz_trace (
124
+ if save_rerun :
125
+ rr_viz_trace (
81
126
trace ,
82
127
T ,
83
128
ground_truth_vertices = meshes [OBJECT_INDEX ].vertices ,
@@ -86,25 +131,25 @@ def run_tracking(scene=None, object=None, debug=False):
86
131
)
87
132
88
133
inferred_poses = Pose .stack_poses (
89
- [
90
- tracking_results [t ].get_choices ()["pose" ]
91
- for t in range (len (all_data ))
92
- ]
134
+ [tracking_results [t ].get_choices ()["pose" ] for t in range (maxT )]
93
135
)
94
136
jnp .savez (
95
- f"SCENE_{ scene_id } _OBJECT_INDEX_{ OBJECT_INDEX } _POSES.npy" ,
137
+ npy_folder_name
138
+ / f"SCENE_{ scene_id } _OBJECT_INDEX_{ OBJECT_INDEX } _POSES.npy" ,
96
139
position = inferred_poses .position ,
97
140
quaternion = inferred_poses .quat ,
98
141
)
99
142
100
- import b3d .chisight .gen3d .visualization as viz
101
-
102
143
viz .make_video_from_traces (
103
- [tracking_results [t ] for t in range (len ( all_data ) )],
104
- f"SCENE_{ scene_id } _OBJECT_INDEX_{ OBJECT_INDEX } .mp4" ,
144
+ [tracking_results [t ] for t in range (maxT )],
145
+ video_folder_name / f"SCENE_{ scene_id } _OBJECT_INDEX_{ OBJECT_INDEX } .mp4" ,
105
146
scale = 0.25 ,
106
147
)
107
148
149
+ if save_rerun :
150
+ rr .disconnect ()
151
+ print ("rerun disconnected" )
152
+
108
153
109
154
if __name__ == "__main__" :
110
155
fire .Fire (run_tracking )
0 commit comments