Skip to content

Commit

Permalink
facebookresearch#355 Add device parameter to UrdfRobotModel (facebook…
Browse files Browse the repository at this point in the history
…research#356)

* Add device parameter to UrdfRobotModel
  • Loading branch information
thomasweng15 authored Nov 8, 2022
1 parent bf0c815 commit 530c871
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
22 changes: 15 additions & 7 deletions tests/embodied/kinematics/test_urdf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@ class VectorType(Enum):
TH_VECTOR = 2


device = "cuda:0" if torch.cuda.is_available() else "cpu"


@pytest.fixture
def robot_model():
urdf_path = os.path.join(os.path.dirname(__file__), URDF_REL_PATH)
return th.eb.UrdfRobotModel(urdf_path)
return th.eb.UrdfRobotModel(urdf_path, device=device)


@pytest.fixture(params=[VectorType.TORCH_TENSOR, VectorType.TH_VECTOR])
Expand All @@ -36,15 +39,17 @@ def dataset(request):

# Input vector type
if request.param == VectorType.TORCH_TENSOR:
joint_states_input = torch.Tensor(data["joint_states"])
joint_states_input = torch.tensor(data["joint_states"], device=device)
elif request.param == VectorType.TH_VECTOR:
joint_states_input = th.Vector(tensor=torch.Tensor(data["joint_states"]))
joint_states_input = th.Vector(
tensor=torch.tensor(data["joint_states"], device=device)
)
else:
raise Exception("Invalid vector type specified.")

# Convert ee poses (from xyzw to wxyz, then from list to tensor)
ee_poses = torch.Tensor(
[pos + quat[3:] + quat[:3] for pos, quat in data["ee_poses"]]
ee_poses = torch.tensor(
[pos + quat[3:] + quat[:3] for pos, quat in data["ee_poses"]], device=device
)

return {
Expand All @@ -65,7 +70,10 @@ def test_forward_kinematics_seq(robot_model, dataset):
ee_se3_computed = robot_model.forward_kinematics(joint_state)[ee_name]

assert torch.allclose(
ee_se3_target.local(ee_se3_computed), torch.zeros(6), atol=1e-5, rtol=1e-4
ee_se3_target.local(ee_se3_computed),
torch.zeros(6, device=device),
atol=1e-5,
rtol=1e-4,
)


Expand All @@ -77,7 +85,7 @@ def test_forward_kinematics_batched(robot_model, dataset):

assert torch.allclose(
ee_se3_target.local(ee_se3_computed),
torch.zeros(dataset["num_data"], 6),
torch.zeros(dataset["num_data"], 6, device=device),
atol=1e-5,
rtol=1e-4,
)
Expand Down
4 changes: 2 additions & 2 deletions theseus/embodied/kinematics/kinematics_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def forward_kinematics(self, robot_pose: RobotModelInput) -> Dict[str, LieGroup]


class UrdfRobotModel(KinematicsModel):
def __init__(self, urdf_path: str):
def __init__(self, urdf_path: str, device: Optional[str] = None):
try:
import differentiable_robot_model as drm
except ModuleNotFoundError as e:
Expand All @@ -46,7 +46,7 @@ def __init__(self, urdf_path: str):
)
raise e

self.drm_model = drm.DifferentiableRobotModel(urdf_path)
self.drm_model = drm.DifferentiableRobotModel(urdf_path, device=device)

def _postprocess_quaternion(self, quat):
# Convert quaternion convention (DRM uses xyzw, Theseus uses wxyz)
Expand Down

0 comments on commit 530c871

Please sign in to comment.