Skip to content

Gm/gen3d/eval script improvements #185

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ assets/kitti/*
__pycache__/
*.py[cod]
docs/*
test_results/
12 changes: 9 additions & 3 deletions scripts/get_ycbv_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,21 +84,27 @@ def get_fp_pred_pose(test_scene_id: int, obj_id: int):

def get_b3d_pred_pose(result_dir: Path, test_scene_id: int, obj_id: int):
poses = np.load(
result_dir / f"SCENE_{test_scene_id}_OBJECT_INDEX_{obj_id}_POSES.npy",
result_dir / f"SCENE_{test_scene_id}_OBJECT_INDEX_{obj_id}_POSES.npy.npz",
)
poses = b3d.Pose(poses["position"], poses["quaternion"])
return poses.as_matrix()


def main(b3d_result_dir: str, output_dir: str | None = None):
def main(b3d_result_dir: str, output_dir: str | None = None, get_fp_pose: bool = False):
"""
Call this with `b3d_result_dir` as the directory containing `.npy.npz` files.
"""
result_dir = Path(b3d_result_dir)
if output_dir is None:
output_dir = b3d_result_dir
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

pred_score_getter = partial(get_b3d_pred_pose, result_dir)
# pred_score_getter = get_fp_pred_pose

if get_fp_pose:
pred_score_getter = get_fp_pred_pose

results_summary, _ = collect_all_scores(pred_score_getter)
print(results_summary)
if output_dir is not None:
Expand Down
83 changes: 64 additions & 19 deletions scripts/run_ycbv_evaluation.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,64 @@
#!/usr/bin/env python

import copy
import os
import pprint
from datetime import datetime
from pathlib import Path

import b3d
import b3d.chisight.gen3d.inference as inference
import b3d.chisight.gen3d.settings as settings
import b3d.chisight.gen3d.visualization as viz
import fire
import jax
import jax.numpy as jnp
import rerun as rr
from b3d import Pose
from b3d.chisight.gen3d.dataloading import (
get_initial_state,
load_object_given_scene,
load_scene,
)
from b3d.chisight.gen3d.model import viz_trace as rr_viz_trace
from tqdm import tqdm


def run_tracking(scene=None, object=None, debug=False):
import b3d
def setup_save_directory():
# Make a folder, stamped with the current time.
current_time = datetime.now().strftime("%Y-%m-%d--%H:%M")
folder_name = (
b3d.get_root_path() / "test_results" / "gen3d" / f"gen3d_{current_time}"
)
Path(folder_name).mkdir(parents=True, exist_ok=True)
video_folder_name = folder_name / "mp4"
npy_folder_name = folder_name / "npy"
rr_folder_name = folder_name / "rr"
os.mkdir(rr_folder_name)
os.mkdir(video_folder_name)
os.mkdir(npy_folder_name)
return folder_name, video_folder_name, npy_folder_name, rr_folder_name


def save_hyperparams(folder_name, hyperparams, inference_hyperparams):
hyperparams_file = folder_name / "hyperparams.txt"
with open(hyperparams_file, "w") as f:
f.write("Hyperparameters:\n")
f.write(pprint.pformat(hyperparams))
f.write("\n\n\nInference Hyperparameters:\n")
f.write(pprint.pformat(inference_hyperparams))


def run_tracking(scene=None, object=None, save_rerun=False, max_n_frames=None):
folder_name, video_folder_name, npy_folder_name, rr_folder_name = (
setup_save_directory()
)

FRAME_RATE = 50
hyperparams = copy.deepcopy(settings.hyperparams)
inference_hyperparams = b3d.chisight.gen3d.settings.inference_hyperparams # noqa
save_hyperparams(folder_name, hyperparams, inference_hyperparams)

b3d.utils.rr_init("run_ycbv_evaluation")
FRAME_RATE = 50

if scene is None:
scenes = range(48, 60)
Expand All @@ -30,8 +67,6 @@ def run_tracking(scene=None, object=None, debug=False):
elif isinstance(scene, list):
scenes = scene

hyperparams = copy.deepcopy(settings.hyperparams)

for scene_id in scenes:
all_data, meshes, renderer, intrinsics, initial_object_poses = load_scene(
scene_id, FRAME_RATE
Expand All @@ -53,15 +88,25 @@ def run_tracking(scene=None, object=None, debug=False):
)

tracking_results = {}
inference_hyperparams = b3d.chisight.gen3d.settings.inference_hyperparams # noqa

### Run inference ###
key = jax.random.PRNGKey(156)
trace = inference.get_initial_trace(
key, hyperparams, initial_state, all_data[0]["rgbd"]
)

for T in tqdm(range(len(all_data))):
if save_rerun:
rr.init(f"SCENE_{scene_id}_OBJECT_INDEX_{OBJECT_INDEX}")
rr.save(
rr_folder_name / f"SCENE_{scene_id}_OBJECT_INDEX_{OBJECT_INDEX}.rrd"
)

if max_n_frames is not None:
maxT = min(max_n_frames, len(all_data))
else:
maxT = len(all_data)

for T in tqdm(range(maxT)):
key = b3d.split_key(key)
trace = inference.inference_step_c2f(
key,
Expand All @@ -76,8 +121,8 @@ def run_tracking(scene=None, object=None, debug=False):
)
tracking_results[T] = trace

if debug:
b3d.chisight.gen3d.model.viz_trace(
if save_rerun:
rr_viz_trace(
trace,
T,
ground_truth_vertices=meshes[OBJECT_INDEX].vertices,
Expand All @@ -86,25 +131,25 @@ def run_tracking(scene=None, object=None, debug=False):
)

inferred_poses = Pose.stack_poses(
[
tracking_results[t].get_choices()["pose"]
for t in range(len(all_data))
]
[tracking_results[t].get_choices()["pose"] for t in range(maxT)]
)
jnp.savez(
f"SCENE_{scene_id}_OBJECT_INDEX_{OBJECT_INDEX}_POSES.npy",
npy_folder_name
/ f"SCENE_{scene_id}_OBJECT_INDEX_{OBJECT_INDEX}_POSES.npy",
position=inferred_poses.position,
quaternion=inferred_poses.quat,
)

import b3d.chisight.gen3d.visualization as viz

viz.make_video_from_traces(
[tracking_results[t] for t in range(len(all_data))],
f"SCENE_{scene_id}_OBJECT_INDEX_{OBJECT_INDEX}.mp4",
[tracking_results[t] for t in range(maxT)],
video_folder_name / f"SCENE_{scene_id}_OBJECT_INDEX_{OBJECT_INDEX}.mp4",
scale=0.25,
)

if save_rerun:
rr.disconnect()
print("rerun disconnected")


if __name__ == "__main__":
fire.Fire(run_tracking)
14 changes: 10 additions & 4 deletions src/b3d/chisight/gen3d/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,13 @@ def get_observed_rgbd(trace):
### Visualization Code ###


def viz_trace(trace, t=0, ground_truth_vertices=None, ground_truth_pose=None):
def viz_trace(
trace,
t=0,
ground_truth_vertices=None,
ground_truth_pose=None,
log_blueprint=True,
):
b3d.rr_set_time(t)
hyperparams, _ = trace.get_args()
new_state = trace.get_retval()["new_state"]
Expand Down Expand Up @@ -206,9 +212,9 @@ def viz_trace(trace, t=0, ground_truth_vertices=None, ground_truth_pose=None):
b3d.rr_log_pose(ground_truth_pose, "scene/ground_truth_pose")
b3d.rr_log_pose(trace.get_choices()["pose"], "scene/inferred_pose")

# if not b3d.get_blueprint_logged():
# rr.send_blueprint(get_blueprint())
# b3d.set_blueprint_logged(True)
if not b3d.get_blueprint_logged() and log_blueprint:
rr.send_blueprint(get_blueprint())
b3d.set_blueprint_logged(True)


def get_blueprint():
Expand Down
Loading