Skip to content

Commit

Permalink
handle multiple clients
Browse files Browse the repository at this point in the history
  • Loading branch information
qgallouedec committed Jun 29, 2021
1 parent 46b909c commit ec7454a
Showing 1 changed file with 46 additions and 41 deletions.
87 changes: 46 additions & 41 deletions panda_gym/pybullet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
import pybullet as p
import pybullet_utils.bullet_client as bc
import pybullet_data


Expand All @@ -24,18 +25,18 @@ def __init__(self, render=False, n_substeps=20, background_color=(116, 160, 216)
--background_color_blue={}".format(
*self.background_color
)
p.connect(p.GUI, options=options)
p.configureDebugVisualizer(p.COV_ENABLE_GUI, 0)
p.configureDebugVisualizer(p.COV_ENABLE_MOUSE_PICKING, 0)
self.physics_client = bc.BulletClient(connection_mode=p.GUI, options=options)
self.physics_client.configureDebugVisualizer(p.COV_ENABLE_GUI, 0)
self.physics_client.configureDebugVisualizer(p.COV_ENABLE_MOUSE_PICKING, 0)
else:
p.connect(p.DIRECT)
self.physics_client = bc.BulletClient(connection_mode=p.DIRECT)

self.n_substeps = n_substeps
self.timestep = 1.0 / 500
p.setTimeStep(self.timestep)
p.resetSimulation()
p.setAdditionalSearchPath(pybullet_data.getDataPath())
p.setGravity(0, 0, -9.81)
self.physics_client.setTimeStep(self.timestep)
self.physics_client.resetSimulation()
self.physics_client.setAdditionalSearchPath(pybullet_data.getDataPath())
self.physics_client.setGravity(0, 0, -9.81)
self._bodies_idx = {}

@property
Expand All @@ -46,7 +47,7 @@ def dt(self):
def step(self):
"""Step the simulation."""
for _ in range(self.n_substeps):
p.stepSimulation()
self.physics_client.stepSimulation()

def close(self):
"""Close the simulation."""
Expand Down Expand Up @@ -85,19 +86,21 @@ def render(
An RGB array if mode is 'rgb_array'.
"""
if mode == "human":
p.configureDebugVisualizer(p.COV_ENABLE_SINGLE_STEP_RENDERING)
self.physics_client.configureDebugVisualizer(self.physics_client.COV_ENABLE_SINGLE_STEP_RENDERING)
time.sleep(self.dt) # wait to seems like real speed
if mode == "rgb_array":
view_matrix = p.computeViewMatrixFromYawPitchRoll(
view_matrix = self.physics_client.computeViewMatrixFromYawPitchRoll(
cameraTargetPosition=target_position,
distance=distance,
yaw=yaw,
pitch=pitch,
roll=roll,
upAxisIndex=2,
)
proj_matrix = p.computeProjectionMatrixFOV(fov=60, aspect=float(width) / height, nearVal=0.1, farVal=100.0)
(_, _, px, depth, _) = p.getCameraImage(
proj_matrix = self.physics_client.computeProjectionMatrixFOV(
fov=60, aspect=float(width) / height, nearVal=0.1, farVal=100.0
)
(_, _, px, depth, _) = self.physics_client.getCameraImage(
width=width,
height=height,
viewMatrix=view_matrix,
Expand Down Expand Up @@ -130,7 +133,7 @@ def get_base_position(self, body):
Returns:
(x, y, z): The cartesian position.
"""
return p.getBasePositionAndOrientation(self._bodies_idx[body])[0]
return self.physics_client.getBasePositionAndOrientation(self._bodies_idx[body])[0]

def get_base_orientation(self, body):
"""Get the orientation of the body.
Expand All @@ -141,7 +144,7 @@ def get_base_orientation(self, body):
Returns:
(x, y, z, w): The orientation as quaternion.
"""
return p.getBasePositionAndOrientation(self._bodies_idx[body])[1]
return self.physics_client.getBasePositionAndOrientation(self._bodies_idx[body])[1]

def get_base_rotation(self, body):
"""Get the rotation of the body.
Expand All @@ -152,7 +155,7 @@ def get_base_rotation(self, body):
Returns:
(rx, ry, rz): The rotation.
"""
return p.getEulerFromQuaternion(self.get_base_orientation(body))
return self.physics_client.getEulerFromQuaternion(self.get_base_orientation(body))

def get_base_velocity(self, body):
"""Get the velocity of the body.
Expand All @@ -163,7 +166,7 @@ def get_base_velocity(self, body):
Returns:
(vx, vy, vz): The cartesian velocity.
"""
return p.getBaseVelocity(self._bodies_idx[body])[0]
return self.physics_client.getBaseVelocity(self._bodies_idx[body])[0]

def get_base_angular_velocity(self, body):
"""Get the angular velocity of the body.
Expand All @@ -174,7 +177,7 @@ def get_base_angular_velocity(self, body):
Returns:
(wx, wy, wz): The angular velocity.
"""
return p.getBaseVelocity(self._bodies_idx[body])[1]
return self.physics_client.getBaseVelocity(self._bodies_idx[body])[1]

def get_link_position(self, body, link):
"""Get the position of the link of the body.
Expand All @@ -186,7 +189,7 @@ def get_link_position(self, body, link):
Returns:
(x, y, z): The cartesian position.
"""
return p.getLinkState(self._bodies_idx[body], link)[0]
return self.physics_client.getLinkState(self._bodies_idx[body], link)[0]

def get_link_orientation(self, body, link):
"""Get the orientation of the link of the body.
Expand All @@ -198,7 +201,7 @@ def get_link_orientation(self, body, link):
Returns:
(x, y, z, w): The orientation as quaternion.
"""
return p.getLinkState(self._bodies_idx[body], link)[1]
return self.physics_client.getLinkState(self._bodies_idx[body], link)[1]

def get_link_velocity(self, body, link):
"""Get the velocity of the link of the body.
Expand All @@ -210,7 +213,7 @@ def get_link_velocity(self, body, link):
Returns:
(vx, vy, vz): The cartesian velocity.
"""
return p.getLinkState(self._bodies_idx[body], link, computeLinkVelocity=True)[6]
return self.physics_client.getLinkState(self._bodies_idx[body], link, computeLinkVelocity=True)[6]

def get_link_angular_velocity(self, body, link):
"""Get the angular velocity of the link of the body.
Expand All @@ -222,7 +225,7 @@ def get_link_angular_velocity(self, body, link):
Returns:
(wx, wy, wz): The angular velocity.
"""
return p.getLinkState(self._bodies_idx[body], link, computeLinkVelocity=True)[7]
return self.physics_client.getLinkState(self._bodies_idx[body], link, computeLinkVelocity=True)[7]

def get_joint_angle(self, body, joint):
"""Get the angle of the joint of the body.
Expand All @@ -234,7 +237,7 @@ def get_joint_angle(self, body, joint):
Returns:
float: The angle.
"""
return p.getJointState(self._bodies_idx[body], joint)[0]
return self.physics_client.getJointState(self._bodies_idx[body], joint)[0]

def set_base_pose(self, body, position, orientation):
"""Set the position of the body.
Expand All @@ -244,7 +247,9 @@ def set_base_pose(self, body, position, orientation):
position (x, y, z): The target cartesian position.
orientation (x, y, z, w): The target orientation as quaternion.
"""
p.resetBasePositionAndOrientation(bodyUniqueId=self._bodies_idx[body], posObj=position, ornObj=orientation)
self.physics_client.resetBasePositionAndOrientation(
bodyUniqueId=self._bodies_idx[body], posObj=position, ornObj=orientation
)

def set_joint_angles(self, body, joints, angles):
"""Set the angles of the joints of the body.
Expand All @@ -265,7 +270,7 @@ def set_joint_angle(self, body, joint, angle):
joint (int): Joint index in the body.
angle (float): Target angle.
"""
p.resetJointState(bodyUniqueId=self._bodies_idx[body], jointIndex=joint, targetValue=angle)
self.physics_client.resetJointState(bodyUniqueId=self._bodies_idx[body], jointIndex=joint, targetValue=angle)

def control_joints(
self,
Expand All @@ -282,10 +287,10 @@ def control_joints(
target_angles (List[float]): List of target angles.
forces (List[float]): Forces to apply.
"""
p.setJointMotorControlArray(
self.physics_client.setJointMotorControlArray(
self._bodies_idx[body],
jointIndices=joints,
controlMode=p.POSITION_CONTROL,
controlMode=self.physics_client.POSITION_CONTROL,
targetPositions=target_angles,
forces=forces,
)
Expand All @@ -302,7 +307,7 @@ def inverse_kinematics(self, body, ee_link, position, orientation):
Returns:
List[float]: The new joint state.
"""
return p.calculateInverseKinematics(
return self.physics_client.calculateInverseKinematics(
bodyIndex=self._bodies_idx[body],
endEffectorLinkIndex=ee_link,
targetPosition=position,
Expand All @@ -318,7 +323,7 @@ def place_visualizer(self, target, distance, yaw, pitch):
yaw (float): Yaw.
pitch (float): Pitch.
"""
p.resetDebugVisualizerCamera(
self.physics_client.resetDebugVisualizerCamera(
cameraDistance=distance,
cameraYaw=yaw,
cameraPitch=pitch,
Expand All @@ -328,17 +333,17 @@ def place_visualizer(self, target, distance, yaw, pitch):
@contextmanager
def no_rendering(self):
"""Disable rendering within this context."""
p.configureDebugVisualizer(p.COV_ENABLE_RENDERING, 0)
self.physics_client.configureDebugVisualizer(self.physics_client.COV_ENABLE_RENDERING, 0)
yield
p.configureDebugVisualizer(p.COV_ENABLE_RENDERING, 1)
self.physics_client.configureDebugVisualizer(self.physics_client.COV_ENABLE_RENDERING, 1)

def loadURDF(self, body_name, **kwargs):
"""Load URDF file.
Args:
body_name (str): The name of the body. Must be unique in the sim.
"""
self._bodies_idx[body_name] = p.loadURDF(**kwargs)
self._bodies_idx[body_name] = self.physics_client.loadURDF(**kwargs)

def create_box(
self,
Expand Down Expand Up @@ -372,7 +377,7 @@ def create_box(
collision_kwargs = {"halfExtents": half_extents}
return self._create_geometry(
body_name,
geom_type=p.GEOM_BOX,
geom_type=self.physics_client.GEOM_BOX,
mass=mass,
position=position,
ghost=ghost,
Expand Down Expand Up @@ -416,7 +421,7 @@ def create_cylinder(
collision_kwargs = {"radius": radius, "height": height}
self._create_geometry(
body_name,
geom_type=p.GEOM_CYLINDER,
geom_type=self.physics_client.GEOM_CYLINDER,
mass=mass,
position=position,
ghost=ghost,
Expand Down Expand Up @@ -457,7 +462,7 @@ def create_sphere(
collision_kwargs = {"radius": radius}
self._create_geometry(
body_name,
geom_type=p.GEOM_SPHERE,
geom_type=self.physics_client.GEOM_SPHERE,
mass=mass,
position=position,
ghost=ghost,
Expand All @@ -481,7 +486,7 @@ def _create_geometry(
Args:
body_name (str): The name of the body. Must be unique in the sim.
geom_type (int): The geometry type. See p.GEOM_<shape>.
geom_type (int): The geometry type. See self.physics_client.GEOM_<shape>.
mass (float, optional): The mass in kg. Defaults to 0.
position (x, y, z): The position of the geom. Defaults to (0, 0, 0)
ghost (bool, optional): Whether the geometry can collide. Defaults
Expand All @@ -490,20 +495,20 @@ def _create_geometry(
visual_kwargs (dict, optional): Visual kwargs. Defaults to {}.
collision_kwargs (dict, optional): Collision kwargs. Defaults to {}.
"""
baseVisualShapeIndex = p.createVisualShape(geom_type, **visual_kwargs)
baseVisualShapeIndex = self.physics_client.createVisualShape(geom_type, **visual_kwargs)
if not ghost:
baseCollisionShapeIndex = p.createCollisionShape(geom_type, **collision_kwargs)
baseCollisionShapeIndex = self.physics_client.createCollisionShape(geom_type, **collision_kwargs)
else:
baseCollisionShapeIndex = -1
self._bodies_idx[body_name] = p.createMultiBody(
self._bodies_idx[body_name] = self.physics_client.createMultiBody(
baseVisualShapeIndex=baseVisualShapeIndex,
baseCollisionShapeIndex=baseCollisionShapeIndex,
baseMass=mass,
basePosition=position,
)

if friction is not None:
p.changeDynamics(
self.physics_client.changeDynamics(
bodyUniqueId=self._bodies_idx[body_name],
linkIndex=-1,
lateralFriction=friction,
Expand Down Expand Up @@ -544,7 +549,7 @@ def set_friction(self, body, link, friction):
link (int): Link index in the body.
friction (float): Lateral friction.
"""
p.changeDynamics(
self.physics_client.changeDynamics(
bodyUniqueId=self._bodies_idx[body],
linkIndex=link,
lateralFriction=friction,
Expand Down

0 comments on commit ec7454a

Please sign in to comment.