-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
363abd1
commit 7cf5b80
Showing
29 changed files
with
816 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
Copyright (c) 2024, MPI-IS, Jonas Frey, Rene Geist, Mikel Zhobro. | ||
All rights reserved. Licensed under the MIT license. | ||
See LICENSE file in the project root for details. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Hitchhiking Rotations | ||
|
||
<h4 align="center"> | ||
Code for ICML 2024 "Position Paper: Learning with 3D rotations, a hitchhiker’s guide to SO(3)" <a href="some_ariv_link" target="_blank">Paper</a>.</h4> | ||
|
||
<p align="center"> | ||
<a href="#overview">Overview</a> • | ||
<a href="#results">Results</a> • | ||
<a href="#installation">Installation</a> • | ||
<a href="#experiments">Experiments</a> • | ||
<a href="#development">Development</a> • | ||
<a href="#credits">Credits</a> | ||
</p> | ||
|
||
|
||
# Overview | ||
(repository overview) | ||
|
||
<object data="https://github.com/martius-lab/hitchhiking-rotations/blob/main/assets/docs/torus_v5.pdf" type="application/pdf" width="700px" height="700px"> | ||
<embed src="https://github.com/martius-lab/hitchhiking-rotations/blob/main/assets/docs/torus_v5.pdf"> | ||
<p>This browser does not support PDFs. Please download the PDF to view it: <a href="https://github.com/martius-lab/hitchhiking-rotations/blob/main/assets/docs/torus_v5.pdf">Download PDF</a>.</p> | ||
</embed> | ||
</object> | ||
|
||
# Results | ||
(this may be optional) | ||
|
||
# Installation | ||
(virtual environment or just list of dependencies) | ||
(using git lfs to get datasets and our checkpoints/models) | ||
|
||
``` | ||
git clone git@github.com:martius-lab/hitchhiking-rotations.git | ||
pip3 install -e ./ | ||
``` | ||
|
||
# Experiments | ||
List of each experiment as in paper and how to reproduce it | ||
|
||
# Development | ||
### Code Formatting | ||
```shell | ||
pip3 install black==23.10 | ||
cd hitchhiking_rotations && black --line-length 120 ./ | ||
``` | ||
### Add License Headers | ||
```shell | ||
pip3 install addheader | ||
# If you are using zsh, keep the backslash; otherwise remove \ | ||
addheader hitchhiking_rotations -t .header.txt -p \*.py --sep-len 79 --comment='#' --sep=' ' | ||
``` | ||
|
||
## TODO | ||
- Add the headers | ||
- Change version to 1.0.0 if done | ||
|
||
# Credits | ||
|
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
import os | ||
|
||
# Directory that contains this module, with symlinks resolved.
_PACKAGE_DIR = os.path.dirname(os.path.realpath(__file__))
HITCHHIKING_ROOT_DIR = os.path.dirname(_PACKAGE_DIR)
"""Absolute path to the hitchhiking repository."""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# Explicit relative import so the sibling module is resolved inside the package
# (an implicit `from cube_dataset import ...` is not package-relative in Python 3).
from .cube_dataset import CubeImageToPoseDataset, PoseToCubeImageDataset
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import os | ||
import pickle | ||
from scipy.spatial.transform import Rotation | ||
import torch | ||
import roma | ||
from torch.utils.data import Dataset | ||
|
||
|
||
class CubeImageToPoseDataset(Dataset):
    """Dataset of rendered cube images paired with the rotation used to render them.

    Each item is ``(image, rotation_matrix)``. The data is cached in a pickle
    file; it is rendered and saved the first time and loaded afterwards.
    """

    def __init__(self, args, device, dataset_file, name):
        """Load the cached dataset, or render and cache it if no cache file exists.

        Args:
            args: object providing ``dataset_size``, ``height`` and ``width``.
            device: device the images and quaternions are moved to.
            dataset_file: path prefix of the cache file; ``"_" + name + ".pkl"`` is appended.
            name: identifier used in the cache file name.
        """
        # Sampled before the cache check; replaced below when a cache file is loaded.
        rots = Rotation.random(args.dataset_size)
        quats = rots.as_quat()

        self.quats = torch.from_numpy(quats)
        self.imgs = []
        dataset_file = dataset_file + "_" + name + ".pkl"

        if os.path.exists(dataset_file):
            # NOTE: unpickling can execute arbitrary code; only load cache files you trust.
            with open(dataset_file, "rb") as f:
                dic = pickle.load(f)
            self.imgs, self.quats = dic["imgs"], dic["quats"]
            print("Dataset file exists -> loaded")
        else:
            # Imported lazily so the renderer is only required when data has to be generated.
            from .dataset_generation import DataGenerator

            dg = DataGenerator(height=args.height, width=args.width)
            # TODO normalize data
            self.imgs = [torch.from_numpy(dg.render_img(quat)) for quat in quats]
            dic = {"imgs": self.imgs, "quats": self.quats}
            with open(dataset_file, "wb") as f:
                pickle.dump(dic, f)
            print("Dataset file was created and saved")

        self.imgs = [i.to(device) for i in self.imgs]
        self.quats = self.quats.to(device)

    def __len__(self):
        """Return the number of stored images."""
        return len(self.imgs)

    def __getitem__(self, idx):
        """Return ``(image / 255, rotation matrix)`` for sample ``idx``, both as float32."""
        return self.imgs[idx].type(torch.float32) / 255, roma.unitquat_to_rotmat(self.quats[idx]).type(torch.float32)
|
||
|
||
class PoseToCubeImageDataset(CubeImageToPoseDataset):
    """Same data as ``CubeImageToPoseDataset`` with the item order swapped: ``(rotation_matrix, image)``."""

    def __init__(self, args, device, dataset_file, name):
        """Forward all arguments unchanged to the parent constructor."""
        super().__init__(args, device, dataset_file, name)

    def __getitem__(self, idx):
        """Return ``(rotation matrix, image / 255)`` for sample ``idx``, both as float32."""
        rotmat = roma.unitquat_to_rotmat(self.quats[idx]).type(torch.float32)
        image = self.imgs[idx].type(torch.float32) / 255
        return rotmat, image
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import mujoco | ||
import torch | ||
import numpy as np | ||
from PIL import Image | ||
|
||
|
||
class DataGenerator:
    """Renders images of a MuJoCo scene whose single body is rotated via a ball joint."""

    def __init__(self, height: int, width: int):
        """Build the MuJoCo model, its data buffer and an offscreen renderer.

        Args:
            height: height passed to ``mujoco.Renderer``.
            width: width passed to ``mujoco.Renderer``.
        """
        xml = """
        <mujoco>
        <worldbody>
        <light name="top" pos="0 0 0"/>
        <body name="cube" euler="0 0 0">
        <joint type="ball" stiffness="0" damping="0" frictionloss="0" armature="0"/>
        <geom type="box" size="0.1 0.1 0.1" pos="0 0 0" rgba="0.5 0.5 0.5 1"/>
        <geom type="box" size="1 1 0.01" pos="0 0 0.9" rgba="1 0 0 1"/>
        <geom type="box" size="1 1 0.01" pos="0 0 -0.99" rgba="0 0 1 1"/>
        <geom type="box" size="0.01 1 1" pos="0.99 0 0" rgba="0 1 0 1"/>
        <geom type="box" size="0.01 1 1" pos="-0.99 0 0" rgba="0 0.6 0.6 1"/>
        <geom type="box" size="1 0.01 1" pos="0 0.99 0" rgba="0.6 0.6 0 1"/>
        <geom type="box" size="1 0.01 1" pos="0 -0.99 0" rgba="0.6 0 0.6 1"/>
        </body>
        </worldbody>
        </mujoco>
        """
        # Make model, data, and renderer
        self.mj_model = mujoco.MjModel.from_xml_string(xml)
        self.mj_data = mujoco.MjData(self.mj_model)
        self.renderer = mujoco.Renderer(self.mj_model, height=height, width=width)

        # Turn the joint-visualization flag off in the scene options used for rendering.
        self.scene_option = mujoco.MjvOption()
        self.scene_option.flags[mujoco.mjtVisFlag.mjVIS_JOINT] = False

    def render_img(self, quat: np.ndarray) -> np.ndarray:
        """
        Returns image for the body with the specified rotation.

        Args:
            quat (np.ndarray, shape:=(4) ): scipy format x,y,z,w

        Returns:
            np.ndarray: the image produced by ``mujoco.Renderer.render``.
        """
        mujoco.mj_resetData(self.mj_model, self.mj_data)

        # NOTE(review): the quaternion is written to qpos unchanged. MuJoCo documents
        # quaternions as (w, x, y, z), while the caller passes scipy's (x, y, z, w).
        # Confirm this ordering is intended before changing it: existing cached
        # datasets depend on the current behavior.
        self.mj_data.qpos = quat

        mujoco.mj_forward(self.mj_model, self.mj_data)
        self.renderer.update_scene(self.mj_data, scene_option=self.scene_option)
        img = self.renderer.render()

        return img

    def __del__(self):
        """Release the renderer, if one was created."""
        # __del__ also runs when __init__ raised before `renderer` was assigned.
        renderer = getattr(self, "renderer", None)
        if renderer is not None:
            renderer.close()
|
||
|
||
if __name__ == "__main__":
    # Manual smoke test: render two orientations of the scene and display them.
    dg = DataGenerator(64, 64)
    for quat in ([0, 0, 0, 1], [0, 1, 0, 1]):
        img = dg.render_img(np.array(quat))
        i1 = Image.fromarray(img)
        i1.show()
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import torch | ||
from torch import nn | ||
|
||
|
||
class MLP(nn.Module):
    """Fully connected network with two hidden layers of 256 units and ReLU activations."""

    def __init__(self, input_dim, output_dim):
        """
        Args:
            input_dim (int): size of the last dimension of the input.
            output_dim (int): size of the last dimension of the output.
        """
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, output_dim),
        )

    def forward(self, x):
        """Apply the network to ``x`` (last dimension ``input_dim``) and return the result."""
        # `self` was missing from the signature, so every forward pass raised a TypeError.
        return self.model(x)
|
||
|
||
class CNN(nn.Module):
    """Transposed-convolution decoder mapping a rotation representation to an image.

    A linear layer expands the input to a ``Z_DIM x INP_SIZE x INP_SIZE`` feature
    map, which four transposed convolutions upsample 5 -> 8 -> 16 -> 32 -> 64.
    The output is channels-last with shape ``(N, 64, 64, 3)``.
    """

    def __init__(self, rotation_representation_dim, width, height):
        """
        Args:
            rotation_representation_dim (int): size of the input rotation representation.
            width: accepted for interface compatibility; currently unused (output is always 64 wide).
            height: accepted for interface compatibility; currently unused (output is always 64 high).
        """
        super().__init__()
        IMAGE_CHANNEL = 3
        G_HIDDEN = 64

        # Number of channels of the feature map produced by the linear input layer.
        self.Z_DIM = 10
        self.INP_SIZE = 5
        self.rotation_representation_dim = rotation_representation_dim
        self.inp = nn.Linear(self.rotation_representation_dim, self.INP_SIZE * self.INP_SIZE * self.Z_DIM)
        self.seq = nn.Sequential(
            # input layer
            nn.ConvTranspose2d(self.Z_DIM, G_HIDDEN * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(G_HIDDEN * 8),
            nn.ReLU(True),
            # 1st hidden layer
            nn.ConvTranspose2d(G_HIDDEN * 8, G_HIDDEN * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(G_HIDDEN * 4),
            nn.ReLU(True),
            # 2nd hidden layer
            nn.ConvTranspose2d(G_HIDDEN * 4, G_HIDDEN * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(G_HIDDEN * 2),
            nn.ReLU(True),
            # 3rd hidden layer
            nn.ConvTranspose2d(G_HIDDEN * 2, IMAGE_CHANNEL, 4, 2, 1, bias=False),
        )

    def forward(self, x):
        """Decode ``x`` of shape ``(N, rotation_representation_dim)`` into ``(N, 64, 64, 3)``."""
        x = self.inp(x)
        x = self.seq(x.reshape(-1, self.Z_DIM, self.INP_SIZE, self.INP_SIZE))
        # Channels-first -> channels-last.
        return x.permute(0, 2, 3, 1)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .conversions import get_rotation_representation_dim, to_rotmat, to_rotation_representation | ||
from .euler_helper import euler_angles_to_matrix, matrix_to_euler_angles | ||
from .metrics import chordal_distance, l2_dp_loss, cosine_similarity_loss, chordal_loss, mse_loss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
from pose_estimation import euler_angles_to_matrix, matrix_to_euler_angles | ||
import roma | ||
import torch | ||
|
||
|
||
# Flattened size of each supported rotation representation.
_ROTATION_REPRESENTATION_DIMS = {
    "euler": 3,
    "rotvec": 3,
    "quaternion": 4,
    "quaternion_canonical": 4,
    "quaternion_rand_flip": 4,
    "procrustes": 9,
    "gramschmidt": 6,
}


def get_rotation_representation_dim(rotation_representation: str) -> int:
    """
    Return dimensionality of rotation representation

    Args:
        rotation_representation (str): rotation representation identifier

    Returns:
        int: dimensionality of rotation representation

    Raises:
        ValueError: if the identifier is not a supported representation
    """
    try:
        return _ROTATION_REPRESENTATION_DIMS[rotation_representation]
    except KeyError:
        # The original message glued the identifier onto the text without a separator.
        raise ValueError(f"Unknown rotation representation: {rotation_representation}") from None
|
||
|
||
def to_rotmat(inp: torch.Tensor, rotation_representation: str) -> torch.Tensor:
    """
    Convert a batch of rotation representations to rotation matrices.

    Supported representations and the shapes the input is reshaped to:
        euler:                  N,3     - comment: convention XZY
        quaternion:             N,4     - comment: XYZW
        quaternion_canonical:   N,4     - comment: XYZW
        quaternion_rand_flip:   N,4     - comment: XYZW
        gramschmidt:            N,3,2   -
        procrustes:             N,3,3   -
        rotvec:                 N,3     -

    Args:
        inp (torch.tensor, shape=(N,..), dtype=torch.float32): specified rotation representation
        rotation_representation (string): rotation representation identifier

    Returns:
        (torch.tensor, shape=(N,...): SO3 Rotation Matrix

    Raises:
        ValueError: if the identifier is not a supported representation
    """

    if rotation_representation == "euler":
        base = euler_angles_to_matrix(inp.reshape(-1, 3), convention="XZY")

    elif rotation_representation in ("quaternion", "quaternion_canonical", "quaternion_rand_flip"):
        inp = inp.reshape(-1, 4)
        # normalize, since unitquat_to_rotmat is given unit quaternions
        inp = inp / torch.norm(inp, dim=1, keepdim=True)
        base = roma.unitquat_to_rotmat(inp)

    elif rotation_representation == "gramschmidt":
        base = roma.special_gramschmidt(inp.reshape(-1, 3, 2))

    elif rotation_representation == "procrustes":
        base = roma.special_procrustes(inp.reshape(-1, 3, 3))

    elif rotation_representation == "rotvec":
        base = roma.rotvec_to_rotmat(inp.reshape(-1, 3))

    else:
        # Previously fell through to `return base` and failed with UnboundLocalError.
        raise ValueError(f"Unknown rotation representation: {rotation_representation}")

    return base
|
||
|
||
def to_rotation_representation(base: torch.Tensor, rotation_representation: str) -> torch.Tensor:
    """
    Convert a batch of rotation matrices to the requested representation.

    Quaternions are always XYZW; Euler angles use the XZY convention.

    Args:
        base (torch.tensor, shape=(N,3,3), dtype=torch.float32): SO3 Rotation Matrix
        rotation_representation (string): rotation representation identifier

    Returns:
        (torch.tensor, shape=(N,...): selected rotation representation, flattened per sample
    """

    # Looked up first, so an unknown identifier is rejected before any conversion.
    flat_dim = get_rotation_representation_dim(rotation_representation)

    if rotation_representation in ("quaternion", "quaternion_rand_flip", "quaternion_canonical"):
        rep = roma.rotmat_to_unitquat(base)
        if rotation_representation == "quaternion_rand_flip":
            # Negate a random subset of the quaternions.
            flip_mask = torch.rand(base.shape[0]) > 0.5
            rep[flip_mask] *= -1
        elif rotation_representation == "quaternion_canonical":
            # Negate quaternions whose last (w) component is negative.
            negative_w = rep[:, 3] < 0
            rep[negative_w] *= -1

    elif rotation_representation == "euler":
        rep = matrix_to_euler_angles(base, convention="XZY")

    elif rotation_representation == "gramschmidt":
        # Keep the first two columns of each matrix.
        rep = base[:, :, :2]

    elif rotation_representation == "procrustes":
        rep = base

    elif rotation_representation == "rotvec":
        rep = roma.rotmat_to_rotvec(base)

    return rep.reshape(-1, flat_dim)
Oops, something went wrong.