Skip to content

Feat/fancy gym #18

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions demos/fancy_gym/BoxPushingDense.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import fancy_gym
import time

from simpub.sim.fancy_gym import FancyGymPublisher

env_name = "BoxPushingDense-v0"

env = fancy_gym.make(env_name, seed=1)
obs = env.reset()

publisher = FancyGymPublisher(env_name, env, "127.0.0.1")

for i in range(1000):
action = env.action_space.sample()
obs, reward, done, info = env.step(action)
env.render()
time.sleep(0.01)
if done:
obs = env.reset()
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name='simpub',
version='0.1',
install_requires=["zmq", "trimesh", "pillow", "numpy", "scipy"],
install_requires=["zmq", "trimesh", "pillow", "numpy", "scipy", "colorama"],
include_package_data=True,
packages=['simpub', 'simpub.parser', 'simpub.sim']
)
16 changes: 10 additions & 6 deletions simpub/core/net_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@


class ServerPort(int, enum.Enum):
# ServerPort and ClientPort need using .value to get the port number
# which is not supposed to be used in this way
DISCOVERY = 7720
SERVICE = 7721
TOPIC = 7722
Expand Down Expand Up @@ -169,10 +171,10 @@ def __init__(
self.sub_socket_dict: Dict[IPAddress, zmq.Socket] = {}
# publisher
self.pub_socket = self.zmq_context.socket(zmq.PUB)
self.pub_socket.bind(f"tcp://{host_ip}:{ServerPort.TOPIC}")
self.pub_socket.bind(f"tcp://{host_ip}:{ServerPort.TOPIC.value}")
# service
self.service_socket = self.zmq_context.socket(zmq.REP)
self.service_socket.bind(f"tcp://{host_ip}:{ServerPort.SERVICE}")
self.service_socket.bind(f"tcp://{host_ip}:{ServerPort.SERVICE.value}")
self.service_list: Dict[str, Service] = {}
# message for broadcasting
self.local_info = HostInfo()
Expand All @@ -192,7 +194,7 @@ def start_server_thread(self) -> None:
Start a thread for service.

Args:
block (bool, optional): main thread stop running and
block (bool, optional): main thread stop running and
wait for server thread. Defaults to False.
"""
self.server_thread = threading.Thread(target=self.start_event_loop)
Expand All @@ -209,8 +211,8 @@ def start_event_loop(self):
self.submit_task(self.service_loop)
self.loop.run_forever()

def submit_task(self, task: Callable, *args):
asyncio.run_coroutine_threadsafe(task(*args), self.loop)
def submit_task(self, task: Callable, *args) -> asyncio.Future:
return asyncio.run_coroutine_threadsafe(task(*args), self.loop)

def stop_server(self):
if self.loop.is_running():
Expand Down Expand Up @@ -260,7 +262,9 @@ async def broadcast_loop(self):
broadcast_ip = socket.inet_ntoa(struct.pack('!I', broadcast_bin))
while self.running:
msg = f"SimPub:{_id}:{json.dumps(local_info)}"
_socket.sendto(msg.encode(), (broadcast_ip, ServerPort.DISCOVERY))
_socket.sendto(
msg.encode(), (broadcast_ip, ServerPort.DISCOVERY.value)
)
await asycnc_sleep(0.1)
logger.info("Broadcasting has been stopped")

Expand Down
19 changes: 14 additions & 5 deletions simpub/parser/mjcf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,15 +132,24 @@ def sphere2unity_scale(scale: List[float]) -> List[float]:


def cylinder2unity_scale(scale: List[float]) -> List[float]:
# if len(scale) == 3:
# return list(map(abs, [scale[0], scale[1], scale[0]]))
# else:
# return list(map(abs, [scale[0] * 2, scale[1], scale[0] * 2]))
if len(scale) == 3:
return list(map(abs, [scale[0], scale[1], scale[0]]))
else:
return list(map(abs, [scale[0] * 2, scale[1], scale[0] * 2]))

if len(scale) == 2:
return list(map(abs, [scale[0], scale[1], scale[0]]))
elif len(scale) == 1:
return list(map(abs, [scale[0] * 2, scale[0] * 2, scale[0] * 2]))

def capsule2unity_scale(scale: List[float]) -> List[float]:
assert len(scale) == 3, "Only support scale with three components."
return list(map(abs, [scale[0], scale[1], scale[0]]))
# assert len(scale) == 3, "Only support scale with three components."
# return list(map(abs, [scale[0], scale[1], scale[0]]))
if len(scale) == 2:
return list(map(abs, [scale[0], scale[1], scale[0]]))
elif len(scale) == 1:
return list(map(abs, [scale[0] * 2, scale[0] * 2, scale[0] * 2]))


ScaleMap: Dict[str, Callable] = {
Expand Down
36 changes: 36 additions & 0 deletions simpub/sim/fancy_gym.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import List
from gym.envs.mujoco.mujoco_env import MujocoEnv
import fancy_gym
import os

from .mj_publisher import MujocoPublisher

FancyGymPath = os.path.dirname(fancy_gym.__file__)
FancyGymEnvPathDict = {
"BoxPushingDense-v0":
os.path.join(
FancyGymPath,
"envs/mujoco/box_pushing/assets/box_pushing.xml"
)
}


class FancyGymPublisher(MujocoPublisher):

def __init__(
self,
env_name: str,
mj_env: MujocoEnv,
host: str = "127.0.0.1",
no_rendered_objects: List[str] = None,
no_tracked_objects: List[str] = None,
) -> None:

super().__init__(
mj_env.model,
mj_env.data,
FancyGymEnvPathDict[env_name],
host,
no_rendered_objects,
no_tracked_objects,
)
2 changes: 1 addition & 1 deletion simpub/sim/mj_publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ def __init__(
self.tracked_obj_trans: Dict[str, np.ndarray] = dict()
super().__init__(
sim_scene,
host,
no_rendered_objects,
no_tracked_objects,
host,
)
for child in self.sim_scene.root.children:
self.set_update_objects(child)
Expand Down
19 changes: 15 additions & 4 deletions simpub/xr_device/xr_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,24 +43,35 @@ async def wait_for_connection(self):
break
await asycnc_sleep(0.01)
self.req_socket.connect(
f"tcp://{self.client_info['ip']}:{ClientPort.SERVICE}")
f"tcp://{self.client_info['ip']}:{ClientPort.SERVICE.value}")
self.sub_socket.connect(
f"tcp://{self.client_info['ip']}:{ClientPort.TOPIC}")
f"tcp://{self.client_info['ip']}:{ClientPort.TOPIC.value}")
self.sub_socket.setsockopt_string(zmq.SUBSCRIBE, "")
self.manager.submit_task(self.subscribe_loop)

def register_topic_callback(self, topic: str, callback: Callable):
self.sub_topic_callback[topic] = callback

def request(self, service: str, req: str) -> str:
future = self.manager.submit_task(
self.request_async, service, req
)
try:
result = future.result()
return result
except Exception as e:
logger.error(f"Find a new when waiting for a response: {e}")
return ""

async def request_async(self, service: str, req: str) -> str:
if self.client_info is None:
logger.error(f"Device {self.device} is not connected")
return ""
if service not in self.client_info["services"]:
logger.error(f"\"{service}\" Service is not available")
return ""
self.req_socket.send_string(f"{service}:{req}")
return self.req_socket.recv_string()
await self.req_socket.send_string(f"{service}:{req}")
return await self.req_socket.recv_string()

async def subscribe_loop(self):
try:
Expand Down