Skip to content

Commit

Permalink
Add demo in colab and update README
Browse files Browse the repository at this point in the history
  • Loading branch information
nv-nguyen committed May 17, 2023
1 parent 1b0db7c commit 78eab15
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 56 deletions.
41 changes: 27 additions & 14 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,15 +1,30 @@
## template-pose (CVPR 2022) <br><sub>Official PyTorch implementation </sub>
<div align="center">
<h2>
Templates for 3D Object Pose Estimation Revisited:<br> Generalization to New objects and Robustness to Occlusions
<p></p>

<a href="https://nv-nguyen.github.io/" target="_blank"><nobr>Van Nguyen Nguyen</nobr></a> &emsp;
<a href="https://yinlinhu.github.io/" target="_blank"><nobr>Yinlin Hu</nobr></a> &emsp;
<a href="https://youngxiao13.github.io/" target="_blank"><nobr>Yang Xiao</nobr></a> &emsp;
<a href="https://people.epfl.ch/mathieu.salzmann" target="_blank"><nobr>Mathieu Salzmann</nobr></a> &emsp;
<a href="https://vincentlepetit.github.io/" target="_blank"><nobr>Vincent Lepetit</nobr></a>

<p></p>

<a href="https://nv-nguyen.github.io/template-pose/"><img
src="https://img.shields.io/badge/-Webpage-blue.svg?colorA=333&logo=html5" height=35em></a>
<a href="https://arxiv.org/abs/2203.17234"><img
src="https://img.shields.io/badge/-Paper-blue.svg?colorA=333&logo=arxiv" height=35em></a>
<a href="https://colab.research.google.com/drive/18Si4X7fcKFHvFuMS-FRVkDyTvlOsr78H?usp=sharing"><img
src="https://img.shields.io/badge/-Demo-blue.svg?colorA=333&logo=googlecolab" height=35em></a>
<p></p>

![Teaser image](./media/method.png)
<p align="center">
<img src=./media/qualitative.gif width="80%"/>
</p>

**Templates for 3D Object Pose Estimation Revisited: Generalization to New objects and Robustness to Occlusions**<br>
[Van Nguyen Nguyen](https://nv-nguyen.github.io/),
[Yinlin Hu](https://yinlinhu.github.io/),
[Yang Xiao](https://youngxiao13.github.io/),
[Mathieu Salzmann](https://people.epfl.ch/mathieu.salzmann) and
[Vincent Lepetit](https://vincentlepetit.github.io/) <br>
**[Paper](https://openaccess.thecvf.com/content/CVPR2022/papers/Nguyen_Templates_for_3D_Object_Pose_Estimation_Revisited_Generalization_to_New_CVPR_2022_paper.pdf)
, [Project Page](https://nv-nguyen.github.io/template-pose/)**
</h2>
</div>

If our project is helpful for your research, please consider citing:
``` Bash
Expand All @@ -27,9 +42,7 @@ If you like this project, check out related works from our group:
(3DV 2022)](https://github.com/nv-nguyen/pizza)
- [BOP visualization toolkit](https://github.com/nv-nguyen/bop_viz_kit)

<p align="center">
<img src=./media/qualitative.gif width="80%"/>
</p>
![Teaser image](./media/method.png)

## Updates (WIP)
We have introduced additional features and updates to the codebase:
Expand Down Expand Up @@ -208,7 +221,7 @@ python gradio_demo.py

## Acknowledgement

The code is adapted from [Nope](https://github.com/nv-nguyen/nope), [Temos](https://github.com/Mathux/Temos), [PoseContrast](https://github.com/YoungXIAO13/PoseContrast), [CosyPose](https://github.com/ylabbe/cosypose) and [BOP Toolkit](https://github.com/thodan/bop_toolkit).
The code is adapted from [Nope](https://github.com/nv-nguyen/nope), [Temos](https://github.com/Mathux/Temos), [Unicorn](https://github.com/monniert/unicorn), [PoseContrast](https://github.com/YoungXIAO13/PoseContrast), [CosyPose](https://github.com/ylabbe/cosypose) and [BOP Toolkit](https://github.com/thodan/bop_toolkit).

The authors thank Martin Sundermeyer, Paul Wohlhart and Shreyas Hampali for their fast replies and feedback!

Expand Down
93 changes: 67 additions & 26 deletions gradio_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from functools import partial
from PIL import Image
import numpy as np
from src.utils.gradio_utils import CameraVisualizer, calc_cam_cone_pts_3d
import os

import glob
import argparse
from omegaconf import DictConfig, OmegaConf

WEBSITE = """
<h1 style='text-align: center'>Templates for 3D Object Pose Estimation Revisited: <br>
Expand Down Expand Up @@ -37,49 +38,89 @@
"""


def main(model, device, cam_vis, reference_image, query_image, num_neighbors):
if num_neighbors is None: # dirty fix for examples
num_neighbors = 3
# update the number of neighbors
cam_vis.neighbors_change(num_neighbors)
def get_examples(dir):
    """Collect one ``[query_image_path, cad_model_path]`` pair per example folder.

    Every sub-directory of ``dir`` that contains at least one ``query*.png``
    contributes a single example. The folder name must end in ``_<int>``
    (e.g. ``tless_21``); that integer selects the CAD file
    ``obj_<id zero-padded to 6 digits>.ply`` inside the same folder. The CAD
    path is built from the name only; its existence is not checked.

    Args:
        dir: Directory holding the example folders. (The name shadows the
            builtin but is kept so existing keyword callers keep working.)

    Returns:
        A list of ``[query_path, cad_path]`` lists, ordered by folder name.
        Within a folder the alphabetically first query image is used.

    Raises:
        ValueError: If a folder that contains query images does not end in
            an integer suffix.
    """
    examples = []  # [query, cad] pairs
    # Sort so the example list (and anything cached from it) is reproducible;
    # os.listdir and glob both return entries in arbitrary order.
    for entry in sorted(os.listdir(dir)):
        example_dir = os.path.join(dir, entry)
        if not os.path.isdir(example_dir):
            continue
        # Escape the folder part so "[", "*" or "?" in a path are taken
        # literally instead of being interpreted as glob wildcards.
        pattern = os.path.join(glob.escape(example_dir), "query*.png")
        query_paths = sorted(glob.glob(pattern))
        if not query_paths:
            # The id is only parsed for folders that have a query image, so
            # unrelated folders never raise.
            continue
        obj_id = int(entry.split("_")[-1])
        cad_path = os.path.join(example_dir, f"obj_{obj_id:06d}.ply")
        examples.append([query_paths[0], cad_path])
    return examples


# Render posed templates of a CAD model through the project's pyrender wrapper.
# cad_model: the CAD model value forwarded by main().
# is_top_sphere: when truthy, template poses are requested with
# pose_distribution="upper", otherwise with "all".
# NOTE(review): there is no return statement, so main() receives None from
# this call.
def call_pyrender(cad_model, is_top_sphere):
# Project helpers are imported inside the function rather than at module import.
from src.poses.pyrender import render
# get template position on the sphere
from src.poses.utils import get_obj_poses_from_template_level
from src.utils.trimesh_utils import get_obj_diameter
poses = get_obj_poses_from_template_level(
level=2, pose_distribution="upper" if is_top_sphere else "all"
)
# normalize meshes
# NOTE(review): get_obj_diameter() is called with no arguments and its result
# replaces the cad_model argument, which is never read afterwards. This looks
# unfinished; confirm the helper's signature and the intended normalisation.
cad_model = get_obj_diameter()
# NOTE(review): mesh, output_dir, obj_poses, img_size and intrinsic are not
# defined anywhere in this function, and `poses` computed above is never used.
# Unless those names exist at module level outside this view, this call raises
# NameError.
render(
mesh,
output_dir,
obj_poses,
img_size,
intrinsic,
is_tless=False,
)


# calculate predicted pose distribution and neighbors
proba = Image.open("media/demo/proba.png")
elevations = np.random.uniform(-60, 0, size=num_neighbors)
azimuths = np.random.uniform(0, 360, size=num_neighbors)
# Gradio callback: prints its inputs, calls call_pyrender() on the CAD model
# and returns whatever that call produced.
# NOTE(review): `model` and `device` are never read in the lines visible here,
# and `query_image` / `num_neighbors` are only printed. Steps 2 and 3 of the
# pipeline described in the docstring are not implemented in this body.
def main(model, device, query_image, cad_model, is_top_sphere, num_neighbors):
"""
The pipeline is:
1. Rendering posed templates given CAD model
2. Compute descriptors of these templates
3. For each query image, compute its features and find nearest neighbors
"""
print(query_image, cad_model, is_top_sphere, num_neighbors)
# render images from CAD model
templates = call_pyrender(cad_model, is_top_sphere)

# NOTE(review): the lines from here down to `return [new_fig, proba]` use
# `cam_vis`, `elevations`, `azimuths`, `reference_image` and `proba`, none of
# which are parameters or locals of this function. This view is a rendered
# commit diff, so they are presumably the removed side of the hunk. Confirm
# against the post-commit file before relying on them.
# update the figure
cam_vis.polar_change(elevations)
cam_vis.azimuth_change(azimuths)
tmp = np.array(reference_image.convert("RGB"))
cam_vis.encode_image(np.uint8(tmp))
new_fig = cam_vis.update_figure()
return [new_fig, proba]
return templates


# Assemble the Gradio Interface for the demo (input widgets, output widget,
# bundled examples) and launch it with share=True.
# NOTE(review): this view is a rendered commit diff with old and new lines
# interleaved. The pairs flagged below cannot both be in the real file, so
# confirm against the post-commit source.
def run_demo():
# Input widgets; Gradio passes one value per widget to `fn`, in list order.
inputs = [
# NOTE(review): six widgets are listed, but main() has only four parameters
# left after the partial below binds two. The first two entries look like the
# pre-commit inputs.
gr.Image(label="query image", type="pil", image_mode="RGB"),
gr.Textbox(label="CAD model", lines=2, placeholder="Path to CAD model (i.e /home/nguyen/Documents/obj_000001.ply"),
gr.Image(label="cropped query image", type="pil", image_mode="RGB"),
gr.Model3D(label="CAD model"),
gr.inputs.Checkbox(label="Templates only on top sphere?", default=False),
gr.Slider(0, 5, value=3, step=1, label="Number of neighbors to show"),
]
# NOTE(review): vis_output and neighbors_output are only used by
# CameraVisualizer and by the first `outputs=` keyword below.
vis_output = gr.Plot(label="Predictions")
neighbors_output = gr.Image(label="Nearest neighbor", type="pil", image_mode="RGBA")
output = gr.Gallery(label="Nearest neighbors")
output.style(grid=5)

cam_vis = CameraVisualizer(vis_output)
# Bind the leading parameters of main() (model, device) to None so that only
# the widget values remain to be supplied.
# NOTE(review): the first assignment is immediately overwritten by the second.
fn_with_model = partial(main, None, None, cam_vis)
fn_with_model = partial(main, None, None)
# functools.partial objects have no __name__ of their own; one is assigned
# here, presumably because Gradio looks it up (TODO confirm).
fn_with_model.__name__ = "fn_with_model"

# Example rows are built from the sub-folders of ./media/demo/.
examples = get_examples("./media/demo/")
demo = gr.Interface(
fn=fn_with_model,
title=WEBSITE,
inputs=inputs,
# NOTE(review): `outputs` is given twice; a repeated keyword argument is a
# SyntaxError, so only one of the next two lines can be live.
outputs=[vis_output, neighbors_output],
outputs=output,
allow_flagging="never",
examples=examples,
cache_examples=True,
)
demo.launch(share=True)


if __name__ == "__main__":
fire.Fire(run_demo)
parser = argparse.ArgumentParser()
parser.add_argument("checkpoint_path", nargs="?", help="Path to the checkpoint")
args = parser.parse_args()
config = OmegaConf.load("configs/model/resnet50.yaml")
print(config)
fire.Fire(run_demo)
# device
device = torch.
33 changes: 17 additions & 16 deletions src/dataloader/bop.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ def __getitem__(self, idx):
from src.dataloader.lm_utils import query_real_ids
from torchvision.utils import make_grid, save_image

root_dir = "/gpfsscratch/rech/tvi/uyb58rn/datasets/template-pose-released/datasets"
root_dir = "/home/nguyen/Documents/datasets/template-pose-released/datasets/"
transform_inverse = transforms.Compose(
[
transforms.Normalize(
Expand All @@ -468,35 +468,36 @@ def __getitem__(self, idx):
]
os.makedirs("./tmp", exist_ok=True)
for idx_dataset in range(len(root_dirs)):
for obj_id in range(1, 31):
for mode in ["query", "template"]:
for obj_id in [21]:
for mode in ["query"]:
dataset = BOPDatasetTest(
root_dir=root_dirs[idx_dataset],
template_dir=os.path.join(root_dir, f"templates_pyrender/tless"),
split="test_primesense",
obj_id=obj_id,
img_size=256,
reset_metaData=False,
linemod_setting=False,
reset_metaData=True,
linemod_setting=True,
mode=mode,
)

train_data = DataLoader(
dataset, batch_size=36, shuffle=True, num_workers=8
)
train_size, train_loader = len(train_data), iter(train_data)
logging.info(f"object {obj_id}, mode {mode}, length {train_size}")
for idx in tqdm(range(train_size)):
batch = next(train_loader)
# train_data = DataLoader(
# dataset, batch_size=36, shuffle=True, num_workers=8
# )
# train_size, train_loader = len(train_data), iter(train_data)
# logging.info(f"object {obj_id}, mode {mode}, length {train_size}")
for idx in tqdm(range(len(dataset))):
# batch = next(train_loader)
save_image_path = os.path.join(
f"./tmp/obj{obj_id}_{mode}_batch{idx}.png"
f"./media/demo/tless_{obj_id:02d}/query_{idx}.png"
)
rgb = batch[mode]
sample = dataset[idx]
rgb = sample["query"]
save_image(
transform_inverse(rgb),
save_image_path,
nrow=6,
nrow=1,
)
print(save_image_path)
if idx == 2:
if idx == 5:
break

0 comments on commit 78eab15

Please sign in to comment.