Skip to content

bind host to zmq services #7

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from setuptools import setup, find_packages
from setuptools import setup

setup(
name='simpub',
Expand Down
147 changes: 91 additions & 56 deletions simpub/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
import zmq
import time
from time import sleep
from socket import socket, AF_INET, SOCK_DGRAM
from socket import SOL_SOCKET, SO_BROADCAST
import socket
from socket import AF_INET, SOCK_DGRAM, SOL_SOCKET, SO_BROADCAST
import json
import threading
from enum import Enum
import struct
from concurrent.futures import ThreadPoolExecutor, Future

from simpub.simdata import SimScene
Expand All @@ -23,10 +24,10 @@ class PortSet(int, Enum):
class TaskBase(abc.ABC):

def __init__(self):
self._running: bool = False
self.running: bool = False

def shutdown(self):
self._running = False
self.running = False
self.on_shutdown()

@abc.abstractmethod
Expand All @@ -43,56 +44,31 @@ class BroadcastTask(TaskBase):
def __init__(
self,
discovery_message: str,
host: str = "127.0.0.1",
mask: str = "255.255.255.0",
port: int = PortSet.DISCOVERY,
intervall: float = 0.5,
intervall: float = 1.0,
):
self._port = port
self._running = True
self.running = True
self._intervall = intervall
self._message = discovery_message.encode()
# calculate broadcast ip
ip_bin = struct.unpack('!I', socket.inet_aton(host))[0]
netmask_bin = struct.unpack('!I', socket.inet_aton(mask))[0]
broadcast_bin = ip_bin | ~netmask_bin & 0xFFFFFFFF
self.broadcast_ip = socket.inet_ntoa(struct.pack('!I', broadcast_bin))

def execute(self):
self._running = True
self.conn = socket(AF_INET, SOCK_DGRAM) # create UDP socket
self.running = True
self.conn = socket.socket(AF_INET, SOCK_DGRAM)
self.conn.setsockopt(SOL_SOCKET, SO_BROADCAST, 1)
while self._running:
self.conn.sendto(self._message, ("255.255.255.255", self._port))
while self.running:
self.conn.sendto(self._message, (self.broadcast_ip, self._port))
sleep(self._intervall)

def on_shutdown(self):
self._running = False


class ServerBase(abc.ABC):

def __init__(self):
self._running: bool = False
self.executor: ThreadPoolExecutor
self.futures: List[Future] = []
self.tasks: List[TaskBase] = []
self.zmqContext = zmq.Context()
self.initialize_task()

@abc.abstractmethod
def initialize_task(self):
raise NotImplementedError

def start(self):
self._running = True
self.thread = threading.Thread(target=self.thread_task)
self.thread.start()

def thread_task(self):
print("Server Tasks has been started")
with ThreadPoolExecutor(max_workers=5) as executor:
self.executor = executor
for task in self.tasks:
self.futures.append(executor.submit(task.execute))
self._running = False

@abc.abstractmethod
def shutdown(self):
raise NotImplementedError
self.running = False


class StreamTask(TaskBase):
Expand All @@ -101,24 +77,24 @@ def __init__(
self,
context: zmq.Context,
update_func: Callable[[], Dict],
host: str = "127.0.0.1",
port: int = PortSet.STREAMING,
topic: str = "SceneUpdate",
fps: int = 45,
):
self._context: zmq.Context = context
self._update_func = update_func
self._port: int = port
self._topic: str = topic
self._running: bool = False
self.running: bool = False
self._dt: float = 1 / fps
self.pub_socket: zmq.Socket = self._context.socket(zmq.PUB)
self.pub_socket.bind(f"tcp://{host}:{port}")

def execute(self):
print("Stream task has been started")
self._running = True
self.pub_socket: zmq.Socket = self._context.socket(zmq.PUB)
self.pub_socket.bind(f"tcp://*:{self._port}")
self.running = True
last = 0.0
while self._running:
while self.running:
diff = time.monotonic() - last
if diff < self._dt:
time.sleep(self._dt - diff)
Expand All @@ -138,18 +114,19 @@ class MsgService(TaskBase):
def __init__(
self,
context: zmq.Context,
host: str = "127.0.0.1",
port: int = PortSet.SERVICE,
):
self._context: zmq.Context = context
self._port: int = port
self._running: bool = False
self.running: bool = False
self._actions: Dict[str, Callable[[zmq.Socket, str], None]] = {}
self.reply_socket: zmq.Socket = self._context.socket(zmq.REP)
self.reply_socket.bind(f"tcp://{host}:{port}")

def execute(self):
self._running = True
self.reply_socket: zmq.Socket = self._context.socket(zmq.REP)
self.reply_socket.bind(f"tcp://*:{self._port}")
while self._running:
self.running = True
while self.running:
message = self.reply_socket.recv().decode()
tag, *args = message.split(":", 1)
if tag == "END":
Expand All @@ -170,14 +147,72 @@ def register_action(
self._actions[tag] = action

def on_shutdown(self):
self._running = False
self.running = False


class ServerBase(abc.ABC):

def __init__(self, host: str = "127.0.0.1"):
self.host: str = host
self.running: bool = False
self.executor: ThreadPoolExecutor
self.futures: List[Future] = []
self.tasks: List[TaskBase] = []
self.zmqContext = zmq.Context()
self.initialize_task()

@abc.abstractmethod
def initialize_task(self):
raise NotImplementedError

def start(self):
self.running = True
self.thread = threading.Thread(target=self.thread_task)
self.thread.start()

def join(self):
self.thread.join()

def thread_task(self):
print("Server Tasks has been started")
with ThreadPoolExecutor(max_workers=5) as executor:
self.executor = executor
for task in self.tasks:
self.futures.append(executor.submit(task.execute))
self.running = False

def shutdown(self):
print("Trying to shutdown server")
for task in self.tasks:
task.shutdown()
self.thread.join()
print("All the threads have been stopped")


class MsgServer(ServerBase):
def __init__(self, host: str = "127.0.0.1"):
super().__init__(host)

def initialize_task(self):
self.tasks: List[TaskBase] = []
discovery_data = {
"SERVICE": PortSet.SERVICE,
}
time_stamp = str(time.time())
discovery_message = f"SimPub:{time_stamp}:{json.dumps(discovery_data)}"
self.broadcast_task = BroadcastTask(discovery_message)
self.tasks.append(self.broadcast_task)

self.msg_service = MsgService(self.zmqContext)
self.tasks.append(self.msg_service)


class SimPublisher(ServerBase):

def __init__(
self,
sim_scene: SimScene,
host: str = "127.0.0.1",
no_rendered_objects: List[str] = None,
no_tracked_objects: List[str] = None,
) -> None:
Expand All @@ -190,7 +225,7 @@ def __init__(
self.no_tracked_objects = []
else:
self.no_tracked_objects = no_tracked_objects
super().__init__()
super().__init__(host)

def initialize_task(self):
self.tasks: List[TaskBase] = []
Expand Down
4 changes: 3 additions & 1 deletion simpub/sim/mj_publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def __init__(
mj_model,
mj_data,
mjcf_path: str,
host: str = "localhost",
no_rendered_objects: List[str] = None,
no_tracked_objects: List[str] = None,
) -> None:
Expand All @@ -26,7 +27,8 @@ def __init__(
self,
sim_scene,
no_rendered_objects,
no_tracked_objects
no_tracked_objects,
host,
)
for child in self.sim_scene.root.children:
self.set_update_objects(child)
Expand Down
2 changes: 2 additions & 0 deletions simpub/sim/sf_publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class SFPublisher(MujocoPublisher):
def __init__(
self,
sf_mj_sim: MjScene,
host: str = "localhost",
no_rendered_objects: List[str] = None,
no_tracked_objects: List[str] = None,
) -> None:
Expand All @@ -74,6 +75,7 @@ def __init__(
SimPublisher.__init__(
self,
self.parser.parse(),
host,
no_rendered_objects,
no_tracked_objects
)
Expand Down