Skip to content

Commit

Permalink
init-repo-structure-code-broken
Browse files Browse the repository at this point in the history
  • Loading branch information
JonasFrey96 committed Feb 19, 2024
1 parent 363abd1 commit 7cf5b80
Show file tree
Hide file tree
Showing 29 changed files with 816 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .header.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Copyright (c) 2024, MPI-IS, Jonas Frey, Rene Geist, Mikel Zhobro.
All rights reserved. Licensed under the MIT license.
See LICENSE file in the project root for details.
2 changes: 1 addition & 1 deletion LICENSE
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2024 Autonomous Learning Group
Copyright (c) 2024 Max Planck Institute for Intelligent Systems, Autonomous Learning Group

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
58 changes: 58 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Hitchhiking Rotations

<h4 align="center">
Code for ICML 2024 "Position Paper: Learning with 3D rotations, a hitchhiker’s guide to SO(3)" <a href="some_ariv_link" target="_blank">Paper</a>.</h4>

<p align="center">
<a href="#overview">Overview</a> •
<a href="#results">Results</a> •
<a href="#installation">Installation</a> •
<a href="#experiments">Experiments</a> •
<a href="#development">Development</a> •
<a href="#credits">Credits</a>
</p>


# Overview
(repository overview)

<object data="https://github.com/martius-lab/hitchhiking-rotations/blob/main/assets/docs/torus_v5.pdf" type="application/pdf" width="700px" height="700px">
<embed src="https://github.com/martius-lab/hitchhiking-rotations/blob/main/assets/docs/torus_v5.pdf>
<p>This browser does not support PDFs. Please download the PDF to view it: <a href="https://github.com/martius-lab/hitchhiking-rotations/blob/main/assets/docs/torus_v5.pdf">Download PDF</a>.</p>
</embed>
</object>

# Results
(this may be optional)

# Installation
(virtual environment or just list of dependencies)
(using git lsf to get datasets and our checkpoints/models)

```
git clone git@github.com:martius-lab/hitchhiking-rotations.git
pip3 install -e ./
```

# Experiments
List of each experiment as in paper and how to reproduce it

# Development
### Code Formatting
```shell
pip3 install black==23.10
cd hitchhiking_rotations && black --line-length 120 ./
```
### Add License Headers
```shell
pip3 install adheader
# If your are using zsh otherwise remove \
addheader hitchhiking_rotations -t .header.txt -p \*.py --sep-len 79 --comment='#' --sep=' '
```

## TODO
- Add the headers
- Change version to 1.0.0 if done

# Credits

Binary file added assets/docs/torus_v5.pdf
Binary file not shown.
4 changes: 4 additions & 0 deletions hitchhiking_rotations/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import os

HITCHHIKING_ROOT_DIR = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
"""Absolute path to the hitchhiking repository."""
1 change: 1 addition & 0 deletions hitchhiking_rotations/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from cube_dataset import CubeImageToPoseDataset, PoseToCubeImageDataset
48 changes: 48 additions & 0 deletions hitchhiking_rotations/datasets/cube_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import os
import pickle
from scipy.spatial.transform import Rotation
import torch
import roma
from torch.utils.data import Dataset


class CubeImageToPoseDataset(Dataset):
def __init__(self, args, device, dataset_file, name):
rots = Rotation.random(args.dataset_size)
quats = rots.as_quat()

self.quats = torch.from_numpy(quats)
self.imgs = []
dataset_file = dataset_file + "_" + name + ".pkl"

if os.path.exists(dataset_file):
dic = pickle.load(open(dataset_file, "rb"))
self.imgs, self.quats = dic["imgs"], dic["quats"]
print("Dataset file exists -> loaded")
else:
from .dataset_generation import DataGenerator

dg = DataGenerator(height=args.height, width=args.width)
for i in range(args.dataset_size):
# TODO normalize data
self.imgs.append(torch.from_numpy(dg.render_img(quats[i])))
dic = {"imgs": self.imgs, "quats": self.quats}
pickle.dump(dic, open(dataset_file, "wb"))
print("Dataset file was created and saved")

self.imgs = [i.to(device) for i in self.imgs]
self.quats = self.quats.to(device)

def __len__(self):
return len(self.imgs)

def __getitem__(self, idx):
return self.imgs[idx].type(torch.float32) / 255, roma.unitquat_to_rotmat(self.quats[idx]).type(torch.float32)


class PoseToCubeImageDataset(CubeImageToPoseDataset):
def __init__(self, args, device, dataset_file, name):
super(PoseToCubeImageDataset, self).__init__(args, device, dataset_file, name)

def __getitem__(self, idx):
return roma.unitquat_to_rotmat(self.quats[idx]).type(torch.float32), self.imgs[idx].type(torch.float32) / 255
67 changes: 67 additions & 0 deletions hitchhiking_rotations/datasets/data_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import mujoco
import torch
import numpy as np
from PIL import Image


class DataGenerator:
def __init__(self, height: int, width: int):
xml = """
<mujoco>
<worldbody>
<light name="top" pos="0 0 0"/>
<body name="cube" euler="0 0 0">
<joint type="ball" stiffness="0" damping="0" frictionloss="0" armature="0"/>
<geom type="box" size="0.1 0.1 0.1" pos="0 0 0" rgba="0.5 0.5 0.5 1"/>
<geom type="box" size="1 1 0.01" pos="0 0 0.9" rgba="1 0 0 1"/>
<geom type="box" size="1 1 0.01" pos="0 0 -0.99" rgba="0 0 1 1"/>
<geom type="box" size="0.01 1 1" pos="0.99 0 0" rgba="0 1 0 1"/>
<geom type="box" size="0.01 1 1" pos="-0.99 0 0" rgba="0 0.6 0.6 1"/>
<geom type="box" size="1 0.01 1" pos="0 0.99 0" rgba="0.6 0.6 0 1"/>
<geom type="box" size="1 0.01 1" pos="0 -0.99 0" rgba="0.6 0 0.6 1"/>
</body>
</worldbody>
</mujoco>
"""
# Make model, data, and renderer
self.mj_model = mujoco.MjModel.from_xml_string(xml)
self.mj_data = mujoco.MjData(self.mj_model)
self.renderer = mujoco.Renderer(self.mj_model, height=height, width=width)

# enable joint visualization option:
self.scene_option = mujoco.MjvOption()
self.scene_option.flags[mujoco.mjtVisFlag.mjVIS_JOINT] = False

def render_img(self, quat: np.array) -> np.array:
"""
Returns image for the body with the specified rotation.
Args:
quat (np.array, shape:=(4) ): scipy format x,y,z,w
"""
mujoco.mj_resetData(self.mj_model, self.mj_data)

# mj_data.qpos = np.random.rand(4)
self.mj_data.qpos = quat

mujoco.mj_forward(self.mj_model, self.mj_data)
self.renderer.update_scene(self.mj_data, scene_option=self.scene_option)
img = self.renderer.render()

return img

def __del__(self):
self.renderer.close()


if __name__ == "__main__":
dg = DataGenerator(64, 64)
img = dg.render_img(np.array([0, 0, 0, 1]))

i1 = Image.fromarray(img)
i1.show()

img = dg.render_img(np.array([0, 1, 0, 1]))

i1 = Image.fromarray(img)
i1.show()
Empty file.
53 changes: 53 additions & 0 deletions hitchhiking_rotations/models/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import torch
from torch import nn


class MLP(nn.Module):
def __init__(self, input_dim, output_dim):
super(MLP, self).__init__()
self.model = nn.Sequential(
nn.Linear(input_dim, 256),
nn.ReLU(),
nn.Linear(256, 256),
nn.ReLU(),
nn.Linear(256, output_dim),
)

def forward(x):
return self.model(x)


class CNN(nn.Module):
def __init__(self, rotation_representation_dim, width, height):
super(CNN, self).__init__()
Z_DIM = rotation_representation_dim
IMAGE_CHANNEL = 3
Z_DIM = 10
G_HIDDEN = 64
X_DIM = 64
D_HIDDEN = 64

self.INP_SIZE = 5
self.rotation_representation_dim = rotation_representation_dim
self.inp = nn.Linear(self.rotation_representation_dim, self.INP_SIZE * self.INP_SIZE * 10)
self.seq = nn.Sequential(
# input layer
nn.ConvTranspose2d(Z_DIM, G_HIDDEN * 8, 4, 1, 0, bias=False),
nn.BatchNorm2d(G_HIDDEN * 8),
nn.ReLU(True),
# 1st hidden layer
nn.ConvTranspose2d(G_HIDDEN * 8, G_HIDDEN * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(G_HIDDEN * 4),
nn.ReLU(True),
# 2nd hidden layer
nn.ConvTranspose2d(G_HIDDEN * 4, G_HIDDEN * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(G_HIDDEN * 2),
nn.ReLU(True),
# 3rd hidden layer
nn.ConvTranspose2d(G_HIDDEN * 2, IMAGE_CHANNEL, 4, 2, 1, bias=False),
)

def forward(self, x):
x = self.inp(x)
x = self.seq(x.reshape(-1, 10, self.INP_SIZE, self.INP_SIZE))
return x.permute(0, 2, 3, 1)
3 changes: 3 additions & 0 deletions hitchhiking_rotations/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .conversions import get_rotation_representation_dim, to_rotmat, to_rotation_representation
from .euler_helper import euler_angles_to_matrix, matrix_to_euler_angles
from .metrics import chordal_distance, l2_dp_loss, cosine_similarity_loss, chordal_loss, mse_loss
118 changes: 118 additions & 0 deletions hitchhiking_rotations/utils/conversions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from pose_estimation import euler_angles_to_matrix, matrix_to_euler_angles
import roma
import torch


def get_rotation_representation_dim(rotation_representation: str) -> int:
"""
Return dimensionality of rotation representation
Args:
rotation_representation (str): rotation representation identifier
Returns:
int: dimensionality of rotation representation
"""
if rotation_representation == "euler":
rotation_representation_dim = 3
elif rotation_representation == "rotvec":
rotation_representation_dim = 3
elif (
rotation_representation == "quaternion"
or rotation_representation == "quaternion_canonical"
or rotation_representation == "quaternion_rand_flip"
):
rotation_representation_dim = 4

elif rotation_representation == "procrustes":
rotation_representation_dim = 9
elif rotation_representation == "gramschmidt":
rotation_representation_dim = 6
else:
raise ValueError("Unknown rotation representation" + rotation_representation)

return rotation_representation_dim


def to_rotmat(inp: torch.Tensor, rotation_representation: str) -> torch.Tensor:
"""
Supported representations and shapes:
quaternion: N,4 - comment: XYZW
quaternion_canonical: N,4 - comment: XYZW
gramschmidt: N,3,2 -
procrustes: N,3,3 -
rotvec: N,3 -
Args:
inp (torch.tensor, shape=(N,..), dtype=torch.float32): specified rotation representation
rotation_representation (string): rotation representation identifier
Returns:
(torch.tensor, shape=(N,...): SO3 Rotation Matrix
"""

if rotation_representation == "euler":
base = euler_angles_to_matrix(inp.reshape(-1, 3), convention="XZY")

elif (
rotation_representation == "quaternion"
or rotation_representation == "quaternion_canonical"
or rotation_representation == "quaternion_rand_flip"
):
inp = inp.reshape(-1, 4)
# normalize
inp = inp / torch.norm(inp, dim=1, keepdim=True)
base = roma.unitquat_to_rotmat(inp.reshape(-1, 4))

elif rotation_representation == "gramschmidt":
base = roma.special_gramschmidt(inp.reshape(-1, 3, 2))

elif rotation_representation == "procrustes":
base = roma.special_procrustes(inp.reshape(-1, 3, 3))

elif rotation_representation == "rotvec":
base = roma.rotvec_to_rotmat(inp.reshape(-1, 3))

return base


def to_rotation_representation(base: torch.Tensor, rotation_representation: str) -> torch.Tensor:
"""
Quaternion representation is always XYZW
For Euler uses XZY
Args:
base (torch.tensor, shape=(N,3,3), dtype=torch.float32): SO3 Rotation Matrix
rotation_representation (string): rotation representation identifier
Returns:
(torch.tensor, shape=(N,...): Returns selected rotation representation
"""

rotation_representation_dim = get_rotation_representation_dim(rotation_representation)
if rotation_representation == "euler":
rep = matrix_to_euler_angles(base, convention="XZY")

elif rotation_representation == "quaternion":
rep = roma.rotmat_to_unitquat(base)

elif rotation_representation == "quaternion_rand_flip":
rep = roma.rotmat_to_unitquat(base)
rand_flipping = torch.rand(base.shape[0]) > 0.5
rep[rand_flipping] *= -1

elif rotation_representation == "quaternion_canonical":
rep = roma.rotmat_to_unitquat(base)
rep[rep[:, 3] < 0] *= -1

elif rotation_representation == "gramschmidt":
rep = base[:, :, :2]

elif rotation_representation == "procrustes":
rep = base

elif rotation_representation == "rotvec":
rep = roma.rotmat_to_rotvec(base)

return rep.reshape(-1, rotation_representation_dim)
Loading

0 comments on commit 7cf5b80

Please sign in to comment.